Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 117 additions & 18 deletions notebooks/workflow_example.ipynb

Large diffs are not rendered by default.

81 changes: 75 additions & 6 deletions pyiron_workflow/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,17 +358,86 @@ def _ensure_node_is_not_duplicated(self, node: Node, label: str):
)
del self.nodes[node.label]

def remove(self, node: Node | str):
if isinstance(node, Node):
node.parent = None
node.disconnect()
del self.nodes[node.label]
def remove(self, node: Node | str) -> list[tuple[Channel, Channel]]:
"""
Remove a node from the `nodes` collection, disconnecting it and setting its
`parent` to None.

Args:
node (Node|str): The node (or its label) to remove.

Returns:
(list[tuple[Channel, Channel]]): Any connections that node had.
"""
node = self.nodes[node] if isinstance(node, str) else node
node.parent = None
disconnected = node.disconnect()
if node in self.starting_nodes:
self.starting_nodes.remove(node)
del self.nodes[node.label]
return disconnected

def replace(self, owned_node: Node | str, replacement: Node | type[Node]):
"""
Replaces a node currently owned with a new node instance.
The replacement must not belong to any other parent or have any connections.
The IO of the new node must be a perfect superset of the replaced node, i.e.
channel labels need to match precisely, but additional channels may be present.
After replacement, the new node will have the old node's connections, label,
and belong to this composite.

Args:
owned_node (Node|str): The node to replace or its label.
replacement (Node | type[Node]): The node or class to replace it with. (If
a class is passed, it has all the same requirements on IO compatibility
and simply gets instantiated.)

Returns:
(Node): The node that got removed
"""
if isinstance(owned_node, str):
owned_node = self.nodes[owned_node]

if owned_node.parent is not self:
raise ValueError(
f"The node being replaced should be a child of this composite, but "
f"another parent was found: {owned_node.parent}"
)

if isinstance(replacement, Node):
if replacement.parent is not None:
raise ValueError(
f"Replacement node must have no parent, but got "
f"{replacement.parent}"
)
if replacement.connected:
raise ValueError("Replacement node must not have any connections")
elif issubclass(replacement, Node):
replacement = replacement(label=owned_node.label)
else:
del self.nodes[node]
raise TypeError(
f"Expected replacement node to be a node instance or node subclass, but "
f"got {replacement}"
)

replacement.copy_connections(owned_node)
replacement.label = owned_node.label
is_starting_node = owned_node in self.starting_nodes
self.remove(owned_node)
self.add(replacement)
if is_starting_node:
self.starting_nodes.append(replacement)

def __setattr__(self, key: str, node: Node):
if isinstance(node, Node) and key != "parent":
self.add(node, label=key)
elif (
isinstance(node, type)
and issubclass(node, Node)
and key in self.nodes.keys()
):
# When a class is assigned to an existing node, try a replacement
self.replace(key, node)
else:
super().__setattr__(key, node)

Expand Down
36 changes: 35 additions & 1 deletion pyiron_workflow/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
if TYPE_CHECKING:
from bidict import bidict

from pyiron_workflow.node import Node


class Macro(Composite):
"""
Expand Down Expand Up @@ -114,7 +116,7 @@ class Macro(Composite):
... macro.b = macro.create.SingleValue(add_one, x=0)
... macro.c = macro.create.SingleValue(add_one, x=0)
>>>
>>> m = Macro(modified_start_macro)
>>> m = Macro(modified_flow_macro)
>>> m.outputs.to_value_dict()
>>> m(a__x=1, b__x=2, c__x=3)
{'a__result': 2, 'b__result': 3, 'c__result': 4}
Expand All @@ -133,6 +135,27 @@ class Macro(Composite):
Manually controlling execution flow is necessary for cyclic graphs (cf. the
while loop meta-node), but best to avoid when possible as it's easy to miss
intended connections in complex graphs.

We can also modify an existing macro at runtime by replacing nodes within it, as
long as the replacement has fully compatible IO. There are three syntacic ways
to do this. Let's explore these by going back to our `add_three_macro` and
replacing each of its children with a node that adds 2 instead of 1.
>>> @Macro.wrap_as.single_value_node()
... def add_two(x):
... result = x + 2
... return result
>>>
>>> adds_six_macro = Macro(add_three_macro)
>>> # With the replace method
>>> # (replacement target can be specified by label or instance,
>>> # the replacing node can be specified by instance or class)
>>> adds_six_macro.replace(adds_six_macro.one, add_two())
>>> # With the replace_with method
>>> adds_six_macro.two.replace_with(add_two())
>>> # And by assignment of a compatible class to an occupied node label
>>> adds_six_macro.three = add_two
>>> adds_six_macro(inp=1)
{'three__result': 7}
"""

def __init__(
Expand Down Expand Up @@ -169,6 +192,10 @@ def inputs(self) -> Inputs:
def outputs(self) -> Outputs:
return self._outputs

def _rebuild_data_io(self):
self._inputs = self._build_inputs()
self._outputs = self._build_outputs()

def _configure_graph_execution(self):
run_signals = self.disconnect_run()

Expand All @@ -195,6 +222,13 @@ def _reconnect_run(self, run_signal_pairs_to_restore):
for pairs in run_signal_pairs_to_restore:
pairs[0].connect(pairs[1])

def replace(self, owned_node: Node | str, replacement: Node | type[Node]):
super().replace(owned_node=owned_node, replacement=replacement)
# Make sure node-level IO is pointing to the new node
self._rebuild_data_io()
# This is brute-force overkill since only the replaced node needs to be updated
# but it's not particularly expensive

def to_workfow(self):
raise NotImplementedError

Expand Down
50 changes: 49 additions & 1 deletion pyiron_workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import graphviz

from pyiron_workflow.composite import Composite
from pyiron_workflow.io import Inputs, Outputs
from pyiron_workflow.io import IO, Inputs, Outputs


def manage_status(node_method):
Expand Down Expand Up @@ -447,3 +447,51 @@ def get_first_shared_parent(self, other: Node) -> Composite | None:
our = our.parent
their = other
return None

def copy_connections(self, other: Node) -> None:
"""
Copies all the connections in another node to this one.
Expects the channels available on this node to be commensurate to those on the
other, i.e. same label, compatible type hint for the connections that exist.
This node may freely have additional channels not present in the other node.

If an exception is encountered, any connections copied so far are disconnected

Args:
other (Node): the node whose connections should be copied.
"""
new_connections = []
try:
for my_panel, other_panel in [
(self.inputs, other.inputs),
(self.outputs, other.outputs),
(self.signals.input, other.signals.input),
(self.signals.output, other.signals.output),
]:
for key, channel in other_panel.items():
for target in channel.connections:
my_panel[key].connect(target)
new_connections.append((my_panel[key], target))
except Exception as e:
# If you run into trouble, unwind what you've done
for connection in new_connections:
connection[0].disconnect(connection[1])
raise e

def replace_with(self, other: Node | type[Node]):
"""
If this node has a parent, invokes `self.parent.replace(self, other)` to swap
out this node for the other node in the parent graph.

The replacement must have fully compatible IO, i.e. its IO must be a superset of
this node's IO with all the same labels and type hints (although the latter is
not strictly enforced and will only cause trouble if there is an incompatibility
that causes trouble in the process of copying over connections)

Args:
other (Node|type[Node]): The replacement.
"""
if self.parent is not None:
self.parent.replace(self, other)
else:
warnings.warn(f"Could not replace {self.label}, as it has no parent.")
28 changes: 28 additions & 0 deletions tests/unit/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,34 @@ def test_return_value(self):
node.run()
node.future.result() # Wait for the remote execution to finish

def test_copy_connections(self):
node = Function(plus_one)

upstream = Function(plus_one)
to_copy = Function(plus_one, x=upstream.outputs.y)
downstream = Function(plus_one, x=to_copy.outputs.y)
upstream > to_copy > downstream

wrong_io = Function(no_default, x=upstream.outputs.y)
downstream.inputs.x.connect(wrong_io.outputs["x + y + 1"])

with self.subTest("Ensure failed copies fail cleanly"):
with self.assertRaises(AttributeError):
node.copy_connections(wrong_io)
self.assertFalse(
node.connected,
msg="The x-input connection should have been copied, but should be "
"removed when the copy fails."
)
node.disconnect() # Make sure you've got a clean slate

with self.subTest("Successful copy"):
node.copy_connections(to_copy)
self.assertIn(upstream.outputs.y, node.inputs.x.connections)
self.assertIn(upstream.signals.output.ran, node.signals.input.run)
self.assertIn(downstream.inputs.x, node.outputs.y.connections)
self.assertIn(downstream.signals.input.run, node.signals.output.ran)


@unittest.skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestSingleValue(unittest.TestCase):
Expand Down
130 changes: 130 additions & 0 deletions tests/unit/test_macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,136 @@ def only_starting(macro):
with self.assertRaises(ValueError):
Macro(only_starting)

def test_replace_node(self):
macro = Macro(add_three_macro)

adds_three_node = Macro(
add_three_macro,
inputs_map={"one__x": "x"},
outputs_map={"three__result": "result"}
)
adds_one_node = macro.two

self.assertEqual(
macro(one__x=0).three__result,
3,
msg="Sanity check"
)

with self.subTest("Verify successful cases"):

macro.replace(adds_one_node, adds_three_node)
self.assertEqual(
macro(one__x=0).three__result,
5,
msg="Result should be bigger after replacing an add_one node with an "
"add_three macro"
)
self.assertFalse(
adds_one_node.connected,
msg="Replaced node should get disconnected"
)
self.assertIsNone(
adds_one_node.parent,
msg="Replaced node should get orphaned"
)

add_one_class = macro.wrap_as.single_value_node()(add_one)
self.assertTrue(issubclass(add_one_class, SingleValue), msg="Sanity check")
macro.replace(adds_three_node, add_one_class)
self.assertEqual(
macro(one__x=0).three__result,
3,
msg="Should be possible to replace with a class instead of an instance"
)

macro.replace("two", adds_three_node)
self.assertEqual(
macro(one__x=0).three__result,
5,
msg="Should be possible to replace by label"
)

macro.two.replace_with(adds_one_node)
self.assertEqual(
macro(one__x=0).three__result,
3,
msg="Nodes should have syntactic sugar for invoking replacement"
)

@Macro.wrap_as.function_node()
def add_two(x):
result = x + 2
return result
macro.two = add_two
self.assertEqual(
macro(one__x=0).three__result,
4,
msg="Composite should allow replacement when a class is assigned"
)

self.assertListEqual(
macro.starting_nodes,
[macro.one],
msg="Sanity check"
)
new_starter = add_two()
macro.one.replace_with(new_starter)
self.assertListEqual(
macro.starting_nodes,
[new_starter],
msg="Replacement should be reflected in the starting nodes"
)
self.assertIs(
macro.inputs.one__x,
new_starter.inputs.x,
msg="Replacement should be reflected in composite IO"
)

with self.subTest("Verify failure cases"):
another_macro = Macro(add_three_macro)
another_node = Macro(
add_three_macro,
inputs_map={"one__x": "x"},
outputs_map={"three__result": "result"},
)
another_macro.now_its_a_child = another_node

with self.assertRaises(
ValueError,
msg="Should fail when replacement has a parent"
):
macro.replace(macro.two, another_node)

another_macro.remove(another_node)
another_node.inputs.x = another_macro.outputs.three__result
with self.assertRaises(
ValueError,
msg="Should fail when replacement is connected"
):
macro.replace(macro.two, another_node)

another_node.disconnect()
an_ok_replacement = another_macro.two
another_macro.remove(an_ok_replacement)
with self.assertRaises(
ValueError,
msg="Should fail if the node being replaced isn't a child"
):
macro.replace(another_node, an_ok_replacement)

@Macro.wrap_as.function_node()
def add_two_incompatible_io(not_x):
result_is_not_my_name = not_x + 2
return result_is_not_my_name

with self.assertRaises(
AttributeError,
msg="Replacing via class assignment should fail if the class has "
"incompatible IO"
):
macro.two = add_two_incompatible_io


if __name__ == '__main__':
unittest.main()
Loading