diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 7c0d3e90f..a95f3806e 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -14,6 +14,7 @@ from pyiron_snippets.singleton import Singleton +from pyiron_workflow.compatibility import Self from pyiron_workflow.mixin.display_state import HasStateDisplay from pyiron_workflow.mixin.has_interface_mixins import HasChannel, HasLabel from pyiron_workflow.type_hinting import ( @@ -25,11 +26,24 @@ from pyiron_workflow.io import HasIO -class ChannelConnectionError(Exception): +class ChannelError(Exception): pass -class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): +class ChannelConnectionError(ChannelError): + pass + + +ConnectionPartner = typing.TypeVar("ConnectionPartner", bound="Channel") + + +class Channel( + HasChannel, + HasLabel, + HasStateDisplay, + typing.Generic[ConnectionPartner], + ABC +): """ Channels facilitate the flow of information (data or control signals) into and out of :class:`HasIO` objects (namely nodes). @@ -37,12 +51,9 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): They must have an identifier (`label: str`) and belong to an `owner: pyiron_workflow.io.HasIO`. - Non-abstract channel classes should come in input/output pairs and specify the - a necessary ancestor for instances they can connect to - (`connection_partner_type: type[Channel]`). Channels may form (:meth:`connect`/:meth:`disconnect`) and store - (:attr:`connections: list[Channel]`) connections with other channels. + (:attr:`connections`) connections with other channels. This connection information is reflexive, and is duplicated to be stored on _both_ channels in the form of a reference to their counterpart in the connection. @@ -51,10 +62,10 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): these (dis)connections is guaranteed to be handled, and new connections are subjected to a validity test. - In this abstract class the only requirement is that the connecting channels form a - "conjugate pair" of classes, i.e. they are children of each other's partner class - (:attr:`connection_partner_type: type[Channel]`) -- input/output connects to - output/input. + In this abstract class the only requirements are that the connecting channels form + a "conjugate pair" of classes, i.e. they are children of each other's partner class + and thus have the same "flavor", but are an input/output pair; and that they define + a string representation. Iterating over channels yields their connections. @@ -80,7 +91,7 @@ def __init__( """ self._label = label self.owner: HasIO = owner - self.connections: list[Channel] = [] + self.connections: list[ConnectionPartner] = [] @property def label(self) -> str: @@ -90,12 +101,12 @@ def label(self) -> str: def __str__(self): pass - @property + @classmethod @abstractmethod - def connection_partner_type(self) -> type[Channel]: + def connection_partner_type(cls) -> type[ConnectionPartner]: """ - Input and output class pairs must specify a parent class for their valid - connection partners. + The class forming a conjugate pair with this channel class -- i.e. the same + "flavor" of channel, but opposite in I/O. """ @property @@ -108,21 +119,18 @@ def full_label(self) -> str: """A label combining the channel's usual label and its owner's semantic path""" return f"{self.owner.full_label}.{self.label}" - def _valid_connection(self, other: Channel) -> bool: + @abstractmethod + def _valid_connection(self, other: object) -> bool: """ Logic for determining if a connection is valid. - - Connections only allowed to instances with the right parent type -- i.e. - connection pairs should be an input/output. """ - return isinstance(other, self.connection_partner_type) - def connect(self, *others: Channel) -> None: + def connect(self, *others: ConnectionPartner) -> None: """ Form a connection between this and one or more other channels. Connections are reflexive, and should only occur between input and output channels, i.e. they are instances of each others - :attr:`connection_partner_type`. + :meth:`connection_partner_type()`. New connections get _prepended_ to the connection lists, so they appear first when searching over connections. @@ -145,24 +153,28 @@ def connect(self, *others: Channel) -> None: self.connections.insert(0, other) other.connections.insert(0, self) else: - if isinstance(other, self.connection_partner_type): + if isinstance(other, self.connection_partner_type()): raise ChannelConnectionError( - f"The channel {other.full_label} ({other.__class__.__name__}" - f") has the correct type " - f"({self.connection_partner_type.__name__}) to connect with " - f"{self.full_label} ({self.__class__.__name__}), but is not " - f"a valid connection. Please check type hints, etc." - f"{other.full_label}.type_hint = {other.type_hint}; " - f"{self.full_label}.type_hint = {self.type_hint}" + self._connection_partner_failure_message(other) ) from None else: raise TypeError( - f"Can only connect to {self.connection_partner_type.__name__} " - f"objects, but {self.full_label} ({self.__class__.__name__}) " + f"Can only connect to {self.connection_partner_type()} " + f"objects, but {self.full_label} ({self.__class__}) " f"got {other} ({type(other)})" ) - def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: + def _connection_partner_failure_message(self, other: ConnectionPartner) -> str: + return ( + f"The channel {other.full_label} ({other.__class__}) has the " + f"correct type ({self.connection_partner_type()}) to connect with " + f"{self.full_label} ({self.__class__}), but is not a valid " + f"connection." + ) + + def disconnect( + self, *others: ConnectionPartner + ) -> list[tuple[Self, ConnectionPartner]]: """ If currently connected to any others, removes this and the other from eachothers respective connections lists. @@ -182,7 +194,9 @@ def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all(self) -> list[tuple[Channel, Channel]]: + def disconnect_all( + self + ) -> list[tuple[Self, ConnectionPartner]]: """ Disconnect from all other channels currently in the connections list. """ @@ -257,8 +271,9 @@ def __bool__(self): NOT_DATA = NotData() +DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel") -class DataChannel(Channel, ABC): +class DataChannel(Channel[DataConnectionPartner], ABC): """ Data channels control the flow of data on the graph. @@ -331,7 +346,7 @@ class DataChannel(Channel, ABC): when this channel is a value receiver. This can potentially be expensive, so consider deactivating strict hints everywhere for production runs. (Default is True, raise exceptions when type hints get violated.) - value_receiver (pyiron_workflow.channel.DataChannel|None): Another channel of + value_receiver (pyiron_workflow.compatibility.Self|None): Another channel of the same class whose value will always get updated when this channel's value gets updated. """ @@ -343,7 +358,7 @@ def __init__( default: typing.Any | None = NOT_DATA, type_hint: typing.Any | None = None, strict_hints: bool = True, - value_receiver: InputData | None = None, + value_receiver: Self | None = None, ): super().__init__(label=label, owner=owner) self._value = NOT_DATA @@ -352,7 +367,7 @@ def __init__( self.strict_hints = strict_hints self.default = default self.value = default # Implicitly type check your default by assignment - self.value_receiver = value_receiver + self.value_receiver: Self = value_receiver @property def value(self): @@ -379,7 +394,7 @@ def _type_check_new_value(self, new_value): ) @property - def value_receiver(self) -> InputData | OutputData | None: + def value_receiver(self) -> Self | None: """ Another data channel of the same type to whom new values are always pushed (without type checking of any sort, not even when forming the couple!) @@ -390,7 +405,7 @@ def value_receiver(self) -> InputData | OutputData | None: return self._value_receiver @value_receiver.setter - def value_receiver(self, new_partner: InputData | OutputData | None): + def value_receiver(self, new_partner: Self | None): if new_partner is not None: if not isinstance(new_partner, self.__class__): raise TypeError( @@ -445,8 +460,8 @@ def _value_is_data(self) -> bool: def _has_hint(self) -> bool: return self.type_hint is not None - def _valid_connection(self, other: DataChannel) -> bool: - if super()._valid_connection(other): + def _valid_connection(self, other: object) -> bool: + if isinstance(other, self.connection_partner_type()): if self._both_typed(other): out, inp = self._figure_out_who_is_who(other) if not inp.strict_hints: @@ -461,13 +476,32 @@ def _valid_connection(self, other: DataChannel) -> bool: else: return False - def _both_typed(self, other: DataChannel) -> bool: + def _connection_partner_failure_message(self, other: DataConnectionPartner) -> str: + msg = super()._connection_partner_failure_message(other) + msg += ( + f"Please check type hints, etc. {other.full_label}.type_hint = " + f"{other.type_hint}; {self.full_label}.type_hint = {self.type_hint}" + ) + return msg + + def _both_typed(self, other: DataConnectionPartner | Self) -> bool: return self._has_hint and other._has_hint def _figure_out_who_is_who( - self, other: DataChannel + self, other: DataConnectionPartner ) -> tuple[OutputData, InputData]: - return (self, other) if isinstance(self, OutputData) else (other, self) + if isinstance(self, InputData) and isinstance(other, OutputData): + return other, self + elif isinstance(self, OutputData) and isinstance(other, InputData): + return self, other + else: + raise ChannelError( + f"This should be unreachable; data channel conjugate pairs should " + f"always be input/output, but got {type(self)} for {self.full_label} " + f"and {type(other)} for {other.full_label}. If you don't believe you " + f"are responsible for this error, please contact the maintainers via " + f"GitHub." + ) def __str__(self): return str(self.value) @@ -491,9 +525,10 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class InputData(DataChannel): - @property - def connection_partner_type(self): +class InputData(DataChannel["OutputData"]): + + @classmethod + def connection_partner_type(cls) -> type[OutputData]: return OutputData def fetch(self) -> None: @@ -530,13 +565,17 @@ def value(self, new_value): self._value = new_value -class OutputData(DataChannel): - @property - def connection_partner_type(self): +class OutputData(DataChannel["InputData"]): + @classmethod + def connection_partner_type(cls) -> type[InputData]: return InputData -class SignalChannel(Channel, ABC): +SignalConnectionPartner = typing.TypeVar( + "SignalConnectionPartner", bound="SignalChannel" +) + +class SignalChannel(Channel[SignalConnectionPartner], ABC): """ Signal channels give the option control execution flow by triggering callback functions when the channel is called. @@ -555,15 +594,15 @@ class SignalChannel(Channel, ABC): def __call__(self) -> None: pass + def _valid_connection(self, other: object) -> bool: + return isinstance(other, self.connection_partner_type()) + class BadCallbackError(ValueError): pass -class InputSignal(SignalChannel): - @property - def connection_partner_type(self): - return OutputSignal +class InputSignal(SignalChannel["OutputSignal"]): def __init__( self, @@ -591,6 +630,10 @@ def __init__( f"all args are optional: {self._all_args_arg_optional(callback)} " ) + @classmethod + def connection_partner_type(cls) -> type[OutputSignal]: + return OutputSignal + def _is_method_on_owner(self, callback): try: return callback == getattr(self.owner, callback.__name__) @@ -644,14 +687,15 @@ def __init__( super().__init__(label=label, owner=owner, callback=callback) self.received_signals: set[str] = set() - def __call__(self, other: OutputSignal) -> None: + def __call__(self, other: OutputSignal | None = None) -> None: """ Fire callback iff you have received at least one signal from each of your current connections. Resets the collection of received signals when firing. """ - self.received_signals.update([other.scoped_label]) + if isinstance(other, OutputSignal): + self.received_signals.update([other.scoped_label]) if ( len( set(c.scoped_label for c in self.connections).difference( @@ -675,9 +719,10 @@ def __lshift__(self, others): other._connect_accumulating_input_signal(self) -class OutputSignal(SignalChannel): - @property - def connection_partner_type(self): +class OutputSignal(SignalChannel["InputSignal"]): + + @classmethod + def connection_partner_type(cls) -> type[InputSignal]: return InputSignal def __call__(self) -> None: diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index eaeb4a855..151a97faa 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from pyiron_workflow.channels import ( @@ -6,6 +8,7 @@ BadCallbackError, Channel, ChannelConnectionError, + ConnectionPartner, InputData, InputSignal, OutputData, @@ -30,25 +33,24 @@ def data_input_locked(self): return self.locked -class InputChannel(Channel): +class DummyChannel(Channel[ConnectionPartner]): """Just to de-abstract the base class""" - def __str__(self): return "non-abstract input" - @property - def connection_partner_type(self) -> type[Channel]: - return OutputChannel + def _valid_connection(self, other: object) -> bool: + return isinstance(other, self.connection_partner_type()) -class OutputChannel(Channel): - """Just to de-abstract the base class""" +class InputChannel(DummyChannel["OutputChannel"]): + @classmethod + def connection_partner_type(cls) -> type[OutputChannel]: + return OutputChannel - def __str__(self): - return "non-abstract output" - @property - def connection_partner_type(self) -> type[Channel]: +class OutputChannel(DummyChannel["InputChannel"]): + @classmethod + def connection_partner_type(cls) -> type[InputChannel]: return InputChannel @@ -389,26 +391,44 @@ def test_aggregating_call(self): owner = DummyOwner() agg = AccumulatingInputSignal(label="agg", owner=owner, callback=owner.update) - with self.assertRaises( - TypeError, - msg="For an aggregating input signal, it _matters_ who called it, so " - "receiving an output signal is not optional", - ): - agg() - out2 = OutputSignal(label="out2", owner=DummyOwner()) agg.connect(self.out, out2) + out_unrelated = OutputSignal(label="out_unrelated", owner=DummyOwner()) + + signals_sent = 0 self.assertEqual( 2, len(agg.connections), msg="Sanity check on initial conditions" ) self.assertEqual( - 0, len(agg.received_signals), msg="Sanity check on initial conditions" + signals_sent, + len(agg.received_signals), + msg="Sanity check on initial conditions" ) self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions") + agg() + signals_sent += 0 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection" + ) + agg(out_unrelated) + signals_sent += 1 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection" + ) + self.out() - self.assertEqual(1, len(agg.received_signals), msg="Signal should be received") + signals_sent += 1 + self.assertEqual( + signals_sent, + len(agg.received_signals), + msg="Signals from other channels should be received" + ) self.assertListEqual( [0], owner.foo, @@ -416,8 +436,9 @@ def test_aggregating_call(self): ) self.out() + signals_sent += 0 self.assertEqual( - 1, + signals_sent, len(agg.received_signals), msg="Repeatedly receiving the same signal should have no effect", )