Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 47 additions & 44 deletions pyiron_workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import inspect
from warnings import warn

from pyiron_workflow.has_interface_mixins import HasChannel, UsesState
from pyiron_workflow.has_interface_mixins import HasChannel, HasLabel, UsesState
from pyiron_workflow.has_to_dict import HasToDict
from pyiron_workflow.snippets.singleton import Singleton
from pyiron_workflow.type_hinting import (
Expand All @@ -29,13 +29,13 @@ class ChannelConnectionError(Exception):
pass


class Channel(UsesState, HasChannel, HasToDict, ABC):
class Channel(UsesState, HasChannel, HasLabel, HasToDict, ABC):
"""
Channels facilitate the flow of information (data or control signals) into and
out of nodes.

They must have an identifier (`label: str`) and belong to a parent node
(`node: pyiron_workflow.node.Node`).
(`owner: pyiron_workflow.node.Node`).

Non-abstract channel classes should come in input/output pairs and specify the
a necessary ancestor for instances they can connect to
Expand All @@ -62,27 +62,30 @@ class Channel(UsesState, HasChannel, HasToDict, ABC):

Attributes:
label (str): The name of the channel.
node (pyiron_workflow.node.Node): The node to which the channel
belongs.
owner (pyiron_workflow.node.Node): The channel's owner.
connections (list[Channel]): Other channels to which this channel is connected.
"""

def __init__(
self,
label: str,
node: Node,
owner: Node,
):
"""
Make a new channel.

Args:
label (str): A name for the channel.
node (pyiron_workflow.node.Node): The node to which the channel belongs.
owner (pyiron_workflow.node.Node): The channel's owner.
"""
self.label: str = label
self.node: Node = node
self._label = label
self.owner: Node = owner
self.connections: list[Channel] = []

@property
def label(self) -> str:
return self._label

@abstractmethod
def __str__(self):
pass
Expand All @@ -97,8 +100,8 @@ def connection_partner_type(self) -> type[Channel]:

@property
def scoped_label(self) -> str:
"""A label combining the channel's usual label and its node's label"""
return f"{self.node.label}__{self.label}"
"""A label combining the channel's usual label and its owner's label"""
return f"{self.owner.label}__{self.label}"

def _valid_connection(self, other: Channel) -> bool:
"""
Expand Down Expand Up @@ -213,17 +216,17 @@ def to_dict(self) -> dict:
return {
"label": self.label,
"connected": self.connected,
"connections": [f"{c.node.label}.{c.label}" for c in self.connections],
"connections": [f"{c.owner.label}.{c.label}" for c in self.connections],
}

def __getstate__(self):
state = super().__getstate__()
# To avoid cyclic storage and avoid storing complex objects, purge some
# properties from the state
state["node"] = None
# It is the responsibility of the owning node to restore the node property
state["owner"] = None
# It is the responsibility of the owner to restore the owner property
state["connections"] = []
# It is the responsibility of the owning node's parent to store and restore
# It is the responsibility of the owner's parent to store and restore
# connections (if any)
return state

Expand Down Expand Up @@ -274,7 +277,7 @@ class DataChannel(Channel, ABC):
Channels with such partners pass any data updates they receive directly to this
partner (via the :attr:`value` setter).
(This is helpful for passing data between scopes, where we want input at one scope
to be passed to the input of nodes at a deeper scope, i.e. macro input passing to
to be passed to the input of owners at a deeper scope, i.e. macro input passing to
child node input, or vice versa for output.)

All these type hint tests can be disabled on the input/receiving channel
Expand Down Expand Up @@ -325,30 +328,31 @@ class DataChannel(Channel, ABC):
yet included in our test suite and behaviour is not guaranteed.

Attributes:
value: The actual data value held by the node.
value: The actual data value held by the channel.
label (str): The label for the channel.
node (pyiron_workflow.node.Node): The node to which this channel belongs.
owner (pyiron_workflow.node.Node): The channel's owner.
default (typing.Any|None): The default value to initialize to.
(Default is the singleton `NOT_DATA`.)
type_hint (typing.Any|None): A type hint for values. (Default is None.)
strict_hints (bool): Whether to check new values, connections, and partners
when this node is a value receiver. This can potentially be expensive, so
when this channel is a value receiver. This can potentially be expensive, so
consider deactivating strict hints everywhere for production runs. (Default
is True, raise exceptions when type hints get violated.)
value_receiver (pyiron_workflow.node.Node|None): Another node of the same class
whose value will always get updated when this node's value gets updated.
value_receiver (pyiron_workflow.channel.DataChannel|None): Another channel of
the same class whose value will always get updated when this channel's
value gets updated.
"""

def __init__(
self,
label: str,
node: Node,
owner: Node,
default: typing.Optional[typing.Any] = NOT_DATA,
type_hint: typing.Optional[typing.Any] = None,
strict_hints: bool = True,
value_receiver: typing.Optional[InputData] = None,
):
super().__init__(label=label, node=node)
super().__init__(label=label, owner=owner)
self._value = NOT_DATA
self._value_receiver = None
self.type_hint = type_hint
Expand Down Expand Up @@ -386,7 +390,7 @@ def value_receiver(self) -> InputData | OutputData | None:
Another data channel of the same type to whom new values are always pushed
(without type checking of any sort, not even when forming the couple!)

Useful for macros, so that the IO of owned nodes and IO at the macro level can
Useful for macros, so that the IO of children and IO at the macro level can
be kept synchronized.
"""
return self._value_receiver
Expand Down Expand Up @@ -520,7 +524,7 @@ def fetch(self) -> None:
0th connection; build graphs accordingly.

Raises:
RuntimeError: If the parent node is :attr:`running`.
RuntimeError: If the owner is :attr:`running`.
"""
for out in self.connections:
if out.value is not NOT_DATA:
Expand All @@ -533,9 +537,9 @@ def value(self):

@value.setter
def value(self, new_value):
if self.node.running:
if self.owner.running:
raise RuntimeError(
f"Parent node {self.node.label} of {self.label} is running, so value "
f"Owner {self.owner.label} of {self.label} is running, so value "
f"cannot be updated."
)
self._type_check_new_value(new_value)
Expand Down Expand Up @@ -582,12 +586,12 @@ def _node_injection(self, injection_class, *args, inject_self=True):
label = self._get_injection_label(injection_class, *args)
try:
# First check if the node already exists
return self.node.parent.children[label]
return self.owner.parent.children[label]
except (AttributeError, KeyError):
# Fall back on creating a new node in case parent is None or node nexists
node_args = (self, *args) if inject_self else args
return injection_class(
*node_args, parent=self.node.parent, label=label, run_after_init=True
*node_args, parent=self.owner.parent, label=label, run_after_init=True
)

# We don't wrap __all__ the operators, because you might really want the string or
Expand Down Expand Up @@ -767,7 +771,7 @@ class SignalChannel(Channel, ABC):
"""
Signal channels give the option control execution flow by triggering callback
functions when the channel is called.
Callbacks must be methods on the parent node that require no positional arguments.
Callbacks must be methods on the owner that require no positional arguments.
Inputs optionally accept an output signal on call, which output signals always
send when they call their input connections.

Expand Down Expand Up @@ -795,33 +799,32 @@ def connection_partner_type(self):
def __init__(
self,
label: str,
node: Node,
owner: Node,
callback: callable,
):
"""
Make a new input signal channel.

Args:
label (str): A name for the channel.
node (pyiron_workflow.node.Node): The node to which the
channel belongs.
owner (pyiron_workflow.node.Node): The channel's owner.
callback (callable): An argument-free callback to invoke when calling this
object.
object. Must be a method on the owner.
"""
super().__init__(label=label, node=node)
if self._is_node_method(callback) and self._takes_zero_arguments(callback):
super().__init__(label=label, owner=owner)
if self._is_method_on_owner(callback) and self._takes_zero_arguments(callback):
self._callback: str = callback.__name__
else:
raise BadCallbackError(
f"The channel {self.label} on {self.node.label} got an unexpected "
f"The channel {self.label} on {self.owner.label} got an unexpected "
f"callback: {callback}. "
f"Lives on node: {self._is_node_method(callback)}; "
f"Lives on owner: {self._is_method_on_owner(callback)}; "
f"take no args: {self._takes_zero_arguments(callback)} "
)

def _is_node_method(self, callback):
def _is_method_on_owner(self, callback):
try:
return callback == getattr(self.node, callback.__name__)
return callback == getattr(self.owner, callback.__name__)
except AttributeError:
return False

Expand All @@ -840,7 +843,7 @@ def _no_positional_args(func):

@property
def callback(self) -> callable:
return getattr(self.node, self._callback)
return getattr(self.owner, self._callback)

def __call__(self, other: typing.Optional[OutputSignal] = None) -> None:
self.callback()
Expand All @@ -866,10 +869,10 @@ class AccumulatingInputSignal(InputSignal):
def __init__(
self,
label: str,
node: Node,
owner: Node,
callback: callable,
):
super().__init__(label=label, node=node, callback=callback)
super().__init__(label=label, owner=owner, callback=callback)
self.received_signals: set[str] = set()

def __call__(self, other: OutputSignal) -> None:
Expand Down Expand Up @@ -915,7 +918,7 @@ def __call__(self) -> None:
def __str__(self):
return (
f"{self.label} activates "
f"{[f'{c.node.label}.{c.label}' for c in self.connections]}"
f"{[f'{c.owner.label}.{c.label}' for c in self.connections]}"
)

def __rshift__(self, other: InputSignal | Node):
Expand Down
2 changes: 1 addition & 1 deletion pyiron_workflow/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def _get_connections_as_strings(
the name is protected.
"""
return [
((inp.node.label, inp.label), (out.node.label, out.label))
((inp.owner.label, inp.label), (out.owner.label, out.label))
for child in self
for inp in panel_getter(child)
for out in inp.connections
Expand Down
4 changes: 2 additions & 2 deletions pyiron_workflow/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def _build_input_channels(self):
channels.append(
InputData(
label=label,
node=self,
owner=self,
default=default,
type_hint=type_hint,
)
Expand Down Expand Up @@ -497,7 +497,7 @@ def _build_output_channels(self, *return_labels: str):
channels.append(
OutputData(
label=label,
node=self,
owner=self,
type_hint=hint,
)
)
Expand Down
6 changes: 3 additions & 3 deletions pyiron_workflow/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def _get_linking_channel(
"""
composite_channel = child_reference_channel.__class__(
label=composite_io_key,
node=self,
owner=self,
default=child_reference_channel.default,
type_hint=child_reference_channel.type_hint,
)
Expand Down Expand Up @@ -565,7 +565,7 @@ def _input_value_links(self):
the name is protected.
"""
return [
(c.label, (c.value_receiver.node.label, c.value_receiver.label))
(c.label, (c.value_receiver.owner.label, c.value_receiver.label))
for c in self.inputs
]

Expand All @@ -579,7 +579,7 @@ def _output_value_links(self):
the name is protected.
"""
return [
((c.node.label, c.label), c.value_receiver.label)
((c.owner.label, c.label), c.value_receiver.label)
for child in self
for c in child.outputs
if c.value_receiver is not None
Expand Down
2 changes: 1 addition & 1 deletion pyiron_workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ def __setstate__(self, state):
# Channels don't store their own node in their state, so repopulate it
for io_panel in self._owned_io_panels:
for channel in io_panel:
channel.node = self
channel.owner = self

@property
def _owned_io_panels(self) -> list[IO]:
Expand Down
16 changes: 8 additions & 8 deletions pyiron_workflow/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,23 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]:
locally_scoped_dependencies = []
for upstream in channel.connections:
try:
upstream_node = nodes[upstream.node.label]
upstream_node = nodes[upstream.owner.label]
except KeyError as e:
raise KeyError(
f"The {channel.label} channel of {node.label} has a connection "
f"to {upstream.label} channel of {upstream.node.label}, but "
f"{upstream.node.label} was not found among nodes. All nodes "
f"to {upstream.label} channel of {upstream.owner.label}, but "
f"{upstream.owner.label} was not found among nodes. All nodes "
f"in the data flow dependency tree must be included."
)
if upstream_node is not upstream.node:
if upstream_node is not upstream.owner:
raise ValueError(
f"The {channel.label} channel of {node.label} has a connection "
f"to {upstream.label} channel of {upstream.node.label}, but "
f"to {upstream.label} channel of {upstream.owner.label}, but "
f"that channel's node is not the same as the nodes passed "
f"here. All nodes in the data flow dependency tree must be "
f"included."
)
locally_scoped_dependencies.append(upstream.node.label)
locally_scoped_dependencies.append(upstream.owner.label)
node_dependencies.extend(locally_scoped_dependencies)
node_dependencies = set(node_dependencies)
if node.label in node_dependencies:
Expand Down Expand Up @@ -183,7 +183,7 @@ def _set_run_connections_according_to_dag(nodes: dict[str, Node]) -> list[Node]:

for node in nodes.values():
upstream_connections = [con for inp in node.inputs for con in inp.connections]
upstream_nodes = set([c.node for c in upstream_connections])
upstream_nodes = set([c.owner for c in upstream_connections])
upstream_rans = [n.signals.output.ran for n in upstream_nodes]
node.signals.input.accumulate_and_run.connect(*upstream_rans)
# Note: We can be super fast-and-loose here because the `nodes_to_data_digraph` call
Expand Down Expand Up @@ -227,7 +227,7 @@ def get_nodes_in_data_tree(node: Node) -> set[Node]:
nodes = set([node])
for channel in node.inputs:
for connection in channel.connections:
nodes = nodes.union(get_nodes_in_data_tree(connection.node))
nodes = nodes.union(get_nodes_in_data_tree(connection.owner))
return nodes
except RecursionError:
raise CircularDataFlowError(
Expand Down
4 changes: 2 additions & 2 deletions pyiron_workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def _data_connections(self) -> list[tuple[tuple[str, str], tuple[str, str]]]:
for inp_label, inp in node.inputs.items():
for conn in inp.connections:
data_connections.append(
((node.label, inp_label), (conn.node.label, conn.label))
((node.label, inp_label), (conn.owner.label, conn.label))
)
return data_connections

Expand All @@ -305,7 +305,7 @@ def _signal_connections(self) -> list[tuple[tuple[str, str], tuple[str, str]]]:
for inp_label, inp in node.signals.input.items():
for conn in inp.connections:
signal_connections.append(
((node.label, inp_label), (conn.node.label, conn.label))
((node.label, inp_label), (conn.owner.label, conn.label))
)
return signal_connections

Expand Down
Loading