Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 106 additions & 61 deletions pyiron_workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from pyiron_snippets.singleton import Singleton

from pyiron_workflow.compatibility import Self
from pyiron_workflow.mixin.display_state import HasStateDisplay
from pyiron_workflow.mixin.has_interface_mixins import HasChannel, HasLabel
from pyiron_workflow.type_hinting import (
Expand All @@ -25,24 +26,34 @@
from pyiron_workflow.io import HasIO


class ChannelConnectionError(Exception):
class ChannelError(Exception):
pass


class Channel(HasChannel, HasLabel, HasStateDisplay, ABC):
class ChannelConnectionError(ChannelError):
pass


ConnectionPartner = typing.TypeVar("ConnectionPartner", bound="Channel")


class Channel(
HasChannel,
HasLabel,
HasStateDisplay,
typing.Generic[ConnectionPartner],
ABC
):
"""
Channels facilitate the flow of information (data or control signals) into and
out of :class:`HasIO` objects (namely nodes).

They must have an identifier (`label: str`) and belong to an
`owner: pyiron_workflow.io.HasIO`.

Non-abstract channel classes should come in input/output pairs and specify the
a necessary ancestor for instances they can connect to
(`connection_partner_type: type[Channel]`).

Channels may form (:meth:`connect`/:meth:`disconnect`) and store
(:attr:`connections: list[Channel]`) connections with other channels.
(:attr:`connections`) connections with other channels.

This connection information is reflexive, and is duplicated to be stored on _both_
channels in the form of a reference to their counterpart in the connection.
Expand All @@ -51,10 +62,10 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC):
these (dis)connections is guaranteed to be handled, and new connections are
subjected to a validity test.

In this abstract class the only requirement is that the connecting channels form a
"conjugate pair" of classes, i.e. they are children of each other's partner class
(:attr:`connection_partner_type: type[Channel]`) -- input/output connects to
output/input.
In this abstract class the only requirements are that the connecting channels form
a "conjugate pair" of classes, i.e. they are children of each other's partner class
and thus have the same "flavor", but are an input/output pair; and that they define
a string representation.

Iterating over channels yields their connections.

Expand All @@ -80,7 +91,7 @@ def __init__(
"""
self._label = label
self.owner: HasIO = owner
self.connections: list[Channel] = []
self.connections: list[ConnectionPartner] = []

@property
def label(self) -> str:
Expand All @@ -90,12 +101,12 @@ def label(self) -> str:
def __str__(self):
pass

@property
@classmethod
@abstractmethod
def connection_partner_type(self) -> type[Channel]:
def connection_partner_type(cls) -> type[ConnectionPartner]:
"""
Input and output class pairs must specify a parent class for their valid
connection partners.
The class forming a conjugate pair with this channel class -- i.e. the same
"flavor" of channel, but opposite in I/O.
"""

@property
Expand All @@ -108,21 +119,18 @@ def full_label(self) -> str:
"""A label combining the channel's usual label and its owner's semantic path"""
return f"{self.owner.full_label}.{self.label}"

def _valid_connection(self, other: Channel) -> bool:
@abstractmethod
def _valid_connection(self, other: object) -> bool:
"""
Logic for determining if a connection is valid.

Connections only allowed to instances with the right parent type -- i.e.
connection pairs should be an input/output.
"""
return isinstance(other, self.connection_partner_type)

def connect(self, *others: Channel) -> None:
def connect(self, *others: ConnectionPartner) -> None:
"""
Form a connection between this and one or more other channels.
Connections are reflexive, and should only occur between input and output
channels, i.e. they are instances of each others
:attr:`connection_partner_type`.
:meth:`connection_partner_type()`.

New connections get _prepended_ to the connection lists, so they appear first
when searching over connections.
Expand All @@ -145,24 +153,28 @@ def connect(self, *others: Channel) -> None:
self.connections.insert(0, other)
other.connections.insert(0, self)
else:
if isinstance(other, self.connection_partner_type):
if isinstance(other, self.connection_partner_type()):
raise ChannelConnectionError(
f"The channel {other.full_label} ({other.__class__.__name__}"
f") has the correct type "
f"({self.connection_partner_type.__name__}) to connect with "
f"{self.full_label} ({self.__class__.__name__}), but is not "
f"a valid connection. Please check type hints, etc."
f"{other.full_label}.type_hint = {other.type_hint}; "
f"{self.full_label}.type_hint = {self.type_hint}"
self._connection_partner_failure_message(other)
) from None
else:
raise TypeError(
f"Can only connect to {self.connection_partner_type.__name__} "
f"objects, but {self.full_label} ({self.__class__.__name__}) "
f"Can only connect to {self.connection_partner_type()} "
f"objects, but {self.full_label} ({self.__class__}) "
f"got {other} ({type(other)})"
)

def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]:
def _connection_partner_failure_message(self, other: ConnectionPartner) -> str:
return (
f"The channel {other.full_label} ({other.__class__}) has the "
f"correct type ({self.connection_partner_type()}) to connect with "
f"{self.full_label} ({self.__class__}), but is not a valid "
f"connection."
)

def disconnect(
self, *others: ConnectionPartner
) -> list[tuple[Self, ConnectionPartner]]:
"""
If currently connected to any others, removes this and the other from eachothers
respective connections lists.
Expand All @@ -182,7 +194,9 @@ def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]:
destroyed_connections.append((self, other))
return destroyed_connections

def disconnect_all(self) -> list[tuple[Channel, Channel]]:
def disconnect_all(
self
) -> list[tuple[Self, ConnectionPartner]]:
"""
Disconnect from all other channels currently in the connections list.
"""
Expand Down Expand Up @@ -257,8 +271,9 @@ def __bool__(self):

NOT_DATA = NotData()

DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel")

class DataChannel(Channel, ABC):
class DataChannel(Channel[DataConnectionPartner], ABC):
"""
Data channels control the flow of data on the graph.

Expand Down Expand Up @@ -331,7 +346,7 @@ class DataChannel(Channel, ABC):
when this channel is a value receiver. This can potentially be expensive, so
consider deactivating strict hints everywhere for production runs. (Default
is True, raise exceptions when type hints get violated.)
value_receiver (pyiron_workflow.channel.DataChannel|None): Another channel of
value_receiver (pyiron_workflow.compatibility.Self|None): Another channel of
the same class whose value will always get updated when this channel's
value gets updated.
"""
Expand All @@ -343,7 +358,7 @@ def __init__(
default: typing.Any | None = NOT_DATA,
type_hint: typing.Any | None = None,
strict_hints: bool = True,
value_receiver: InputData | None = None,
value_receiver: Self | None = None,
):
super().__init__(label=label, owner=owner)
self._value = NOT_DATA
Expand All @@ -352,7 +367,7 @@ def __init__(
self.strict_hints = strict_hints
self.default = default
self.value = default # Implicitly type check your default by assignment
self.value_receiver = value_receiver
self.value_receiver: Self = value_receiver

@property
def value(self):
Expand All @@ -379,7 +394,7 @@ def _type_check_new_value(self, new_value):
)

@property
def value_receiver(self) -> InputData | OutputData | None:
def value_receiver(self) -> Self | None:
"""
Another data channel of the same type to whom new values are always pushed
(without type checking of any sort, not even when forming the couple!)
Expand All @@ -390,7 +405,7 @@ def value_receiver(self) -> InputData | OutputData | None:
return self._value_receiver

@value_receiver.setter
def value_receiver(self, new_partner: InputData | OutputData | None):
def value_receiver(self, new_partner: Self | None):
if new_partner is not None:
if not isinstance(new_partner, self.__class__):
raise TypeError(
Expand Down Expand Up @@ -445,8 +460,8 @@ def _value_is_data(self) -> bool:
def _has_hint(self) -> bool:
return self.type_hint is not None

def _valid_connection(self, other: DataChannel) -> bool:
if super()._valid_connection(other):
def _valid_connection(self, other: object) -> bool:
if isinstance(other, self.connection_partner_type()):
if self._both_typed(other):
out, inp = self._figure_out_who_is_who(other)
if not inp.strict_hints:
Expand All @@ -461,13 +476,32 @@ def _valid_connection(self, other: DataChannel) -> bool:
else:
return False

def _both_typed(self, other: DataChannel) -> bool:
def _connection_partner_failure_message(self, other: DataConnectionPartner) -> str:
msg = super()._connection_partner_failure_message(other)
msg += (
f"Please check type hints, etc. {other.full_label}.type_hint = "
f"{other.type_hint}; {self.full_label}.type_hint = {self.type_hint}"
)
return msg

def _both_typed(self, other: DataConnectionPartner | Self) -> bool:
return self._has_hint and other._has_hint

def _figure_out_who_is_who(
self, other: DataChannel
self, other: DataConnectionPartner
) -> tuple[OutputData, InputData]:
return (self, other) if isinstance(self, OutputData) else (other, self)
if isinstance(self, InputData) and isinstance(other, OutputData):
return other, self
elif isinstance(self, OutputData) and isinstance(other, InputData):
return self, other
else:
raise ChannelError(
f"This should be unreachable; data channel conjugate pairs should "
f"always be input/output, but got {type(self)} for {self.full_label} "
f"and {type(other)} for {other.full_label}. If you don't believe you "
f"are responsible for this error, please contact the maintainers via "
f"GitHub."
)

def __str__(self):
return str(self.value)
Expand All @@ -491,9 +525,10 @@ def display_state(self, state=None, ignore_private=True):
return super().display_state(state=state, ignore_private=ignore_private)


class InputData(DataChannel):
@property
def connection_partner_type(self):
class InputData(DataChannel["OutputData"]):

@classmethod
def connection_partner_type(cls) -> type[OutputData]:
return OutputData

def fetch(self) -> None:
Expand Down Expand Up @@ -530,13 +565,17 @@ def value(self, new_value):
self._value = new_value


class OutputData(DataChannel):
@property
def connection_partner_type(self):
class OutputData(DataChannel["InputData"]):
@classmethod
def connection_partner_type(cls) -> type[InputData]:
return InputData


class SignalChannel(Channel, ABC):
SignalConnectionPartner = typing.TypeVar(
"SignalConnectionPartner", bound="SignalChannel"
)

class SignalChannel(Channel[SignalConnectionPartner], ABC):
"""
Signal channels give the option control execution flow by triggering callback
functions when the channel is called.
Expand All @@ -555,15 +594,15 @@ class SignalChannel(Channel, ABC):
def __call__(self) -> None:
pass

def _valid_connection(self, other: object) -> bool:
return isinstance(other, self.connection_partner_type())


class BadCallbackError(ValueError):
pass


class InputSignal(SignalChannel):
@property
def connection_partner_type(self):
return OutputSignal
class InputSignal(SignalChannel["OutputSignal"]):

def __init__(
self,
Expand Down Expand Up @@ -591,6 +630,10 @@ def __init__(
f"all args are optional: {self._all_args_arg_optional(callback)} "
)

@classmethod
def connection_partner_type(cls) -> type[OutputSignal]:
return OutputSignal

def _is_method_on_owner(self, callback):
try:
return callback == getattr(self.owner, callback.__name__)
Expand Down Expand Up @@ -644,14 +687,15 @@ def __init__(
super().__init__(label=label, owner=owner, callback=callback)
self.received_signals: set[str] = set()

def __call__(self, other: OutputSignal) -> None:
def __call__(self, other: OutputSignal | None = None) -> None:
"""
Fire callback iff you have received at least one signal from each of your
current connections.

Resets the collection of received signals when firing.
"""
self.received_signals.update([other.scoped_label])
if isinstance(other, OutputSignal):
self.received_signals.update([other.scoped_label])
if (
len(
set(c.scoped_label for c in self.connections).difference(
Expand All @@ -675,9 +719,10 @@ def __lshift__(self, others):
other._connect_accumulating_input_signal(self)


class OutputSignal(SignalChannel):
@property
def connection_partner_type(self):
class OutputSignal(SignalChannel["InputSignal"]):

@classmethod
def connection_partner_type(cls) -> type[InputSignal]:
return InputSignal

def __call__(self) -> None:
Expand Down
Loading
Loading