Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions pyiron_workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,20 +522,30 @@ def _before_run(
self.inputs.fetch()

if self.use_cache and self.cache_hit: # Read and use cache
self._on_cache_hit()
if self.parent is None and emit_ran_signal:
self.emit()
elif self.parent is not None:
self.parent.register_child_starting(self)
self.parent.register_child_finished(self)
if emit_ran_signal:
self.parent.register_child_emitting(self)

return True, self._outputs_to_run_return()
elif self.use_cache: # Write cache and continue
self._cached_inputs = self.inputs.to_value_dict()
else:
self._on_cache_miss()
if self.use_cache: # Write cache and continue
self._cached_inputs = self.inputs.to_value_dict()

return super()._before_run(check_readiness=check_readiness)

def _on_cache_hit(self) -> None:
"""A hook for subclasses to act on cache hits"""
return

def _on_cache_miss(self) -> None:
"""A hook for subclasses to act on cache misses"""
return

def _run(
self,
executor: Executor | None,
Expand Down
6 changes: 3 additions & 3 deletions pyiron_workflow/nodes/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def _setup_node(self) -> None:
self.starting_nodes = input_nodes
self._input_node_labels = tuple(n.label for n in input_nodes)

def _on_run(self):
self._build_body()
return super()._on_run()
def _on_cache_miss(self) -> None:
if self.ready:
self._build_body()

def _build_body(self):
"""
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/nodes/test_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyiron_snippets.dotdict import DotDict

from pyiron_workflow._tests import ensure_tests_in_python_path
from pyiron_workflow.mixin.run import ReadinessError
from pyiron_workflow.nodes.for_loop import (
MapsToNonexistentOutputError,
UnmappedConflictError,
Expand Down Expand Up @@ -613,6 +614,38 @@ def test_executor_deserialization(self):
finally:
n.delete_storage()

def test_caching(self):
side_effect_counter = 0

@as_function_node("m")
def SideEffectNode(n: int):
nonlocal side_effect_counter
side_effect_counter += 1
return n**2

n = [1, 2, 3, 4]
s = SideEffectNode.for_node(iter_on="n")
with self.assertRaises(
ReadinessError,
msg="Without input, we should raise a readiness error before we get to "
"building the body node",
):
s()

s.run(n=n)
self.assertEqual(
side_effect_counter,
len(n),
msg="Sanity check, it should have run once for each child node",
)
s.run()
self.assertEqual(
side_effect_counter,
len(n),
msg="With identical input, children should only actually get run the first "
"time",
)


if __name__ == "__main__":
unittest.main()
Loading