From cf9f1cf3ae19cf61a2dabd5fd3071d7460f29aa5 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 5 Mar 2025 10:20:25 -0800 Subject: [PATCH 1/4] Add hook functions for hitting/missing the cache Signed-off-by: liamhuber --- pyiron_workflow/node.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index bed659ba0..8073c7c43 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -522,6 +522,7 @@ def _before_run( self.inputs.fetch() if self.use_cache and self.cache_hit: # Read and use cache + self._on_cache_hit() if self.parent is None and emit_ran_signal: self.emit() elif self.parent is not None: @@ -529,13 +530,22 @@ def _before_run( self.parent.register_child_finished(self) if emit_ran_signal: self.parent.register_child_emitting(self) - return True, self._outputs_to_run_return() - elif self.use_cache: # Write cache and continue - self._cached_inputs = self.inputs.to_value_dict() + else: + self._on_cache_miss() + if self.use_cache: # Write cache and continue + self._cached_inputs = self.inputs.to_value_dict() return super()._before_run(check_readiness=check_readiness) + def _on_cache_hit(self) -> None: + """A hook for subclasses to act on cache hits""" + return + + def _on_cache_miss(self) -> None: + """A hook for subclasses to act on cache misses""" + return + def _run( self, executor: Executor | None, From 01814b8a4021ccdb504d6bfbcc8553fba33a99c4 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 5 Mar 2025 10:20:35 -0800 Subject: [PATCH 2/4] Only rebuild for loops on a missed cache Signed-off-by: liamhuber --- pyiron_workflow/nodes/for_loop.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyiron_workflow/nodes/for_loop.py b/pyiron_workflow/nodes/for_loop.py index 7ce94e273..f2346048e 100644 --- a/pyiron_workflow/nodes/for_loop.py +++ b/pyiron_workflow/nodes/for_loop.py @@ -232,9 +232,8 @@ def _setup_node(self) -> None: self.starting_nodes = input_nodes self._input_node_labels = tuple(n.label for n in input_nodes) - def _on_run(self): + def _on_cache_miss(self) -> None: self._build_body() - return super()._on_run() def _build_body(self): """ From 988c67049c2cddcf32800a77addd097c9bdec2a8 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 5 Mar 2025 17:44:53 -0800 Subject: [PATCH 3/4] Only build for ready nodes Signed-off-by: liamhuber --- pyiron_workflow/nodes/for_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyiron_workflow/nodes/for_loop.py b/pyiron_workflow/nodes/for_loop.py index f2346048e..9e6abf661 100644 --- a/pyiron_workflow/nodes/for_loop.py +++ b/pyiron_workflow/nodes/for_loop.py @@ -233,7 +233,8 @@ def _setup_node(self) -> None: self._input_node_labels = tuple(n.label for n in input_nodes) def _on_cache_miss(self) -> None: - self._build_body() + if self.ready: + self._build_body() def _build_body(self): """ From 659c8fbf787d420ba58e570ee40568a70b3b3f17 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 5 Mar 2025 17:45:38 -0800 Subject: [PATCH 4/4] Add test for caching Co-Authored-By: Marvin Poul Signed-off-by: liamhuber --- tests/unit/nodes/test_for_loop.py | 33 +++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/unit/nodes/test_for_loop.py b/tests/unit/nodes/test_for_loop.py index 2efb5e659..c4e056832 100644 --- a/tests/unit/nodes/test_for_loop.py +++ b/tests/unit/nodes/test_for_loop.py @@ -8,6 +8,7 @@ from pyiron_snippets.dotdict import DotDict from pyiron_workflow._tests import ensure_tests_in_python_path +from pyiron_workflow.mixin.run import ReadinessError from pyiron_workflow.nodes.for_loop import ( MapsToNonexistentOutputError, UnmappedConflictError, @@ -613,6 +614,38 @@ def test_executor_deserialization(self): finally: n.delete_storage() + def test_caching(self): + side_effect_counter = 0 + + @as_function_node("m") + def SideEffectNode(n: int): + nonlocal side_effect_counter + side_effect_counter += 1 + return n**2 + + n = [1, 2, 3, 4] + s = SideEffectNode.for_node(iter_on="n") + with self.assertRaises( + ReadinessError, + msg="Without input, we should raise a readiness error before we get to " + "building the body node", + ): + s() + + s.run(n=n) + self.assertEqual( + side_effect_counter, + len(n), + msg="Sanity check, it should have run once for each child node", + ) + s.run() + self.assertEqual( + side_effect_counter, + len(n), + msg="With identical input, children should only actually get run the first " + "time", + ) + if __name__ == "__main__": unittest.main()