From dd635bf9516b81e4baf87c6d9f584987efb7479b Mon Sep 17 00:00:00 2001 From: Rob Cheung Date: Tue, 23 Jul 2024 10:16:46 -0400 Subject: [PATCH 1/3] updates --- substrate/core/corenode.py | 8 ++++++-- substrate/core/future_directive.py | 2 +- substrate/substrate.py | 9 ++++++--- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/substrate/core/corenode.py b/substrate/core/corenode.py index 79edffe..0ee558c 100644 --- a/substrate/core/corenode.py +++ b/substrate/core/corenode.py @@ -24,7 +24,7 @@ def __init__( _cache_age: Optional[int] = None, _cache_keys: Optional[List[str]] = None, _max_retries: Optional[int] = None, - _depends: List["CoreNode"] = [], + _depends=None, **attr, ): self._out_type = out_type @@ -35,7 +35,7 @@ def __init__( self._cache_age = _cache_age self._cache_keys = _cache_keys self._max_retries = _max_retries - self._depends = _depends + self._depends = [] if _depends is None else _depends self._should_output_globally: bool = not hide self.SG = nx.DiGraph() if attr: @@ -54,6 +54,10 @@ def __init__( for referenced_future in depend_node.futures_from_args: self.futures_from_args.append(referenced_future) + @property + def explicit_depends(self) -> List["CoreNode"]: + return self._depends + @property def dependent_futures(self) -> List[Future]: return find_futures_client(self.init_attrs) diff --git a/substrate/core/future_directive.py b/substrate/core/future_directive.py index 349af21..5a9dc27 100644 --- a/substrate/core/future_directive.py +++ b/substrate/core/future_directive.py @@ -89,7 +89,7 @@ class TraceOperation: @dataclass class TraceDirective(BaseDirective): op_stack: List[TraceOperation] - origin_node: Any # Should be CoreNode, but am running into circular import + origin_node: "CoreNode" type: Literal["trace"] = "trace" def to_dict(self) -> Dict: diff --git a/substrate/substrate.py b/substrate/substrate.py index f0f1fb8..865cad8 100644 --- a/substrate/substrate.py +++ b/substrate/substrate.py @@ -1,7 +1,8 @@ import json +from typing import Optional, Any + import zlib import base64 -from typing import Any, Dict from substrate.streaming import SubstrateStreamingResponse @@ -21,11 +22,13 @@ def __init__( api_key: str, base_url: str = "https://api.substrate.run", timeout: float = 60 * 5.0, - additional_headers: Dict[str, Any] = {}, + additional_headers: Optional[str, Any] = None, ): """ Initialize the Substrate SDK. """ + if additional_headers is None: + additional_headers = {} self.api_key = api_key self._client = APIClient( api_key=api_key, @@ -98,7 +101,7 @@ def collect_nodes(node): for node in all_nodes: if not graph.DAG.has_node(node): graph.add_node(node) - for depend_node in node._depends: + for depend_node in node.explicit_depends: graph.add_edge(depend_node, node) graph_serialized = graph.to_dict() return graph_serialized From b7fec82ce934d7571d059a9df5176fbfc06216b9 Mon Sep 17 00:00:00 2001 From: Rob Cheung Date: Tue, 23 Jul 2024 10:19:21 -0400 Subject: [PATCH 2/3] type --- substrate/substrate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/substrate/substrate.py b/substrate/substrate.py index 865cad8..1fa19d8 100644 --- a/substrate/substrate.py +++ b/substrate/substrate.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Any +from typing import Optional, Any, Dict import zlib import base64 @@ -22,7 +22,7 @@ def __init__( api_key: str, base_url: str = "https://api.substrate.run", timeout: float = 60 * 5.0, - additional_headers: Optional[str, Any] = None, + additional_headers: Optional[Dict[str, Any]] = None, ): """ Initialize the Substrate SDK. From 028b6cb3d61bf809e529fcf0ccde4c56315d9d68 Mon Sep 17 00:00:00 2001 From: Rob Cheung Date: Tue, 23 Jul 2024 10:25:18 -0400 Subject: [PATCH 3/3] types --- substrate/core/corenode.py | 2 +- substrate/run_python.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/substrate/core/corenode.py b/substrate/core/corenode.py index 0ee558c..638da64 100644 --- a/substrate/core/corenode.py +++ b/substrate/core/corenode.py @@ -24,7 +24,7 @@ def __init__( _cache_age: Optional[int] = None, _cache_keys: Optional[List[str]] = None, _max_retries: Optional[int] = None, - _depends=None, + _depends: Optional[List["CoreNode"]] = None, **attr, ): self._out_type = out_type diff --git a/substrate/run_python.py b/substrate/run_python.py index 5ef588b..f8eaca7 100644 --- a/substrate/run_python.py +++ b/substrate/run_python.py @@ -11,13 +11,13 @@ class RunPython(CoreNode[RunPythonOut]): def __init__( self, function: Callable, - kwargs: Dict[str, Any] = {}, + kwargs=None, pip_install: Optional[List[str]] = None, hide: bool = False, _cache_age: Optional[int] = None, _cache_keys: Optional[List[str]] = None, _max_retries: Optional[int] = None, - _depends: List[CoreNode] = [], + _depends: Optional[List[CoreNode]] = None, ): """ Args: @@ -41,7 +41,7 @@ def __init__( python_version = sys.version.split()[0] super().__init__( pkl_function=fn_str, - kwargs=kwargs, + kwargs={} if kwargs is None else kwargs, pip_install=pip_install, hide=hide, python_version=python_version,