diff --git a/substrate/core/corenode.py b/substrate/core/corenode.py index 79edffe..638da64 100644 --- a/substrate/core/corenode.py +++ b/substrate/core/corenode.py @@ -24,7 +24,7 @@ def __init__( _cache_age: Optional[int] = None, _cache_keys: Optional[List[str]] = None, _max_retries: Optional[int] = None, - _depends: List["CoreNode"] = [], + _depends: Optional[List["CoreNode"]] = None, **attr, ): self._out_type = out_type @@ -35,7 +35,7 @@ def __init__( self._cache_age = _cache_age self._cache_keys = _cache_keys self._max_retries = _max_retries - self._depends = _depends + self._depends = [] if _depends is None else _depends self._should_output_globally: bool = not hide self.SG = nx.DiGraph() if attr: @@ -54,6 +54,10 @@ def __init__( for referenced_future in depend_node.futures_from_args: self.futures_from_args.append(referenced_future) + @property + def explicit_depends(self) -> List["CoreNode"]: + return self._depends + @property def dependent_futures(self) -> List[Future]: return find_futures_client(self.init_attrs) diff --git a/substrate/core/future_directive.py b/substrate/core/future_directive.py index 349af21..5a9dc27 100644 --- a/substrate/core/future_directive.py +++ b/substrate/core/future_directive.py @@ -89,7 +89,7 @@ class TraceOperation: @dataclass class TraceDirective(BaseDirective): op_stack: List[TraceOperation] - origin_node: Any # Should be CoreNode, but am running into circular import + origin_node: "CoreNode" type: Literal["trace"] = "trace" def to_dict(self) -> Dict: diff --git a/substrate/run_python.py b/substrate/run_python.py index 5ef588b..f8eaca7 100644 --- a/substrate/run_python.py +++ b/substrate/run_python.py @@ -11,13 +11,13 @@ class RunPython(CoreNode[RunPythonOut]): def __init__( self, function: Callable, - kwargs: Dict[str, Any] = {}, + kwargs=None, pip_install: Optional[List[str]] = None, hide: bool = False, _cache_age: Optional[int] = None, _cache_keys: Optional[List[str]] = None, _max_retries: Optional[int] = None, - _depends: List[CoreNode] = [], + _depends: Optional[List[CoreNode]] = None, ): """ Args: @@ -41,7 +41,7 @@ def __init__( python_version = sys.version.split()[0] super().__init__( pkl_function=fn_str, - kwargs=kwargs, + kwargs={} if kwargs is None else kwargs, pip_install=pip_install, hide=hide, python_version=python_version, diff --git a/substrate/substrate.py b/substrate/substrate.py index f0f1fb8..1fa19d8 100644 --- a/substrate/substrate.py +++ b/substrate/substrate.py @@ -1,7 +1,8 @@ import json +from typing import Optional, Any, Dict + import zlib import base64 -from typing import Any, Dict from substrate.streaming import SubstrateStreamingResponse @@ -21,11 +22,13 @@ def __init__( api_key: str, base_url: str = "https://api.substrate.run", timeout: float = 60 * 5.0, - additional_headers: Dict[str, Any] = {}, + additional_headers: Optional[Dict[str, Any]] = None, ): """ Initialize the Substrate SDK. """ + if additional_headers is None: + additional_headers = {} self.api_key = api_key self._client = APIClient( api_key=api_key, @@ -98,7 +101,7 @@ def collect_nodes(node): for node in all_nodes: if not graph.DAG.has_node(node): graph.add_node(node) - for depend_node in node._depends: + for depend_node in node.explicit_depends: graph.add_edge(depend_node, node) graph_serialized = graph.to_dict() return graph_serialized