Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions substrate/core/corenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
_cache_age: Optional[int] = None,
_cache_keys: Optional[List[str]] = None,
_max_retries: Optional[int] = None,
_depends: List["CoreNode"] = [],
_depends: Optional[List["CoreNode"]] = None,
**attr,
):
self._out_type = out_type
Expand All @@ -35,7 +35,7 @@ def __init__(
self._cache_age = _cache_age
self._cache_keys = _cache_keys
self._max_retries = _max_retries
self._depends = _depends
self._depends = [] if _depends is None else _depends
self._should_output_globally: bool = not hide
self.SG = nx.DiGraph()
if attr:
Expand All @@ -54,6 +54,10 @@ def __init__(
for referenced_future in depend_node.futures_from_args:
self.futures_from_args.append(referenced_future)

@property
def explicit_depends(self) -> List["CoreNode"]:
return self._depends

@property
def dependent_futures(self) -> List[Future]:
return find_futures_client(self.init_attrs)
Expand Down
2 changes: 1 addition & 1 deletion substrate/core/future_directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class TraceOperation:
@dataclass
class TraceDirective(BaseDirective):
op_stack: List[TraceOperation]
origin_node: Any # Should be CoreNode, but am running into circular import
origin_node: "CoreNode"
type: Literal["trace"] = "trace"

def to_dict(self) -> Dict:
Expand Down
6 changes: 3 additions & 3 deletions substrate/run_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ class RunPython(CoreNode[RunPythonOut]):
def __init__(
self,
function: Callable,
kwargs: Dict[str, Any] = {},
kwargs=None,
pip_install: Optional[List[str]] = None,
hide: bool = False,
_cache_age: Optional[int] = None,
_cache_keys: Optional[List[str]] = None,
_max_retries: Optional[int] = None,
_depends: List[CoreNode] = [],
_depends: Optional[List[CoreNode]] = None,
):
"""
Args:
Expand All @@ -41,7 +41,7 @@ def __init__(
python_version = sys.version.split()[0]
super().__init__(
pkl_function=fn_str,
kwargs=kwargs,
kwargs={} if kwargs is None else kwargs,
pip_install=pip_install,
hide=hide,
python_version=python_version,
Expand Down
9 changes: 6 additions & 3 deletions substrate/substrate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from typing import Optional, Any, Dict

import zlib
import base64
from typing import Any, Dict

from substrate.streaming import SubstrateStreamingResponse

Expand All @@ -21,11 +22,13 @@ def __init__(
api_key: str,
base_url: str = "https://api.substrate.run",
timeout: float = 60 * 5.0,
additional_headers: Dict[str, Any] = {},
additional_headers: Optional[Dict[str, Any]] = None,
):
"""
Initialize the Substrate SDK.
"""
if additional_headers is None:
additional_headers = {}
self.api_key = api_key
self._client = APIClient(
api_key=api_key,
Expand Down Expand Up @@ -98,7 +101,7 @@ def collect_nodes(node):
for node in all_nodes:
if not graph.DAG.has_node(node):
graph.add_node(node)
for depend_node in node._depends:
for depend_node in node.explicit_depends:
graph.add_edge(depend_node, node)
graph_serialized = graph.to_dict()
return graph_serialized