-
Notifications
You must be signed in to change notification settings - Fork 0
/
base.py
306 lines (237 loc) · 8.6 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
from __future__ import annotations
import contextlib
import dataclasses
import typing
from typing import Any
from uuid import UUID
from znflow import exceptions
@contextlib.contextmanager
def disable_graph(*args, **kwargs):
    """Temporarily set the active graph to ``empty``.

    While the context is active, ``get_graph()`` returns ``empty``. On exit
    the previously active graph is restored, including when the body raises.
    This is useful e.g. for ``get_attribute``, which reads attributes with
    the graph disabled.

    Any positional or keyword arguments are accepted and ignored.
    """
    previous_graph = get_graph()
    set_graph(empty)
    with contextlib.ExitStack() as restore:
        # Registered callback runs on normal exit and on exceptions alike.
        restore.callback(set_graph, previous_graph)
        yield
class Property:
    """Custom property whose accessors run with the graph disabled.

    Behaves like the builtin :class:`property`, except that the getter,
    setter and deleter are each executed inside ``disable_graph()``, i.e.
    with the active graph temporarily replaced by ``empty``.

    Parameters
    ----------
    fget, fset, fdel : callable or None
        Accessor functions. A missing accessor makes the corresponding
        operation raise ``AttributeError``.
    doc : str or None
        Docstring for the property; defaults to ``fget.__doc__``.

    References
    ----------
    Adapted from https://docs.python.org/3/howto/descriptor.html#properties
    """

    def __init__(self, fget=None, fset=None, fdel=None, doc=None):
        # Undecorated accessors, kept so getter()/setter()/deleter() can build
        # a new descriptor without wrapping an accessor a second time.
        self._raw_fget = fget
        self._raw_fset = fset
        self._raw_fdel = fdel
        # Missing accessors must stay ``None``: ``disable_graph()(None)`` would
        # return a callable wrapper, so the "has no getter/setter/deleter"
        # checks below could never fire and callers would get
        # "'NoneType' object is not callable" instead of AttributeError.
        self.fget = self._wrap(fget)
        self.fset = self._wrap(fset)
        self.fdel = self._wrap(fdel)
        if doc is None and fget is not None:
            doc = fget.__doc__
        self.__doc__ = doc
        self._name = ""

    @staticmethod
    def _wrap(func):
        """Return ``func`` wrapped in ``disable_graph()``, or ``None`` for ``None``."""
        if func is None:
            return None
        return disable_graph()(func)

    def _copy_with(self, fget, fset, fdel):
        """Build a new descriptor of the same type, carrying over doc and name."""
        prop = type(self)(fget, fset, fdel, self.__doc__)
        prop._name = self._name
        return prop

    def __set_name__(self, owner, name):
        # Remember the attribute name for the error messages below.
        self._name = name

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Accessed on the class itself: return the descriptor.
            return self
        if self.fget is None:
            raise AttributeError(f"property '{self._name}' has no getter")
        return self.fget(obj)

    def __set__(self, obj, value):
        if self.fset is None:
            raise AttributeError(f"property '{self._name}' has no setter")
        self.fset(obj, value)

    def __delete__(self, obj):
        if self.fdel is None:
            raise AttributeError(f"property '{self._name}' has no deleter")
        self.fdel(obj)

    def getter(self, fget):
        """Return a copy of this property that uses ``fget`` as its getter."""
        return self._copy_with(fget, self._raw_fset, self._raw_fdel)

    def setter(self, fset):
        """Return a copy of this property that uses ``fset`` as its setter."""
        return self._copy_with(self._raw_fget, fset, self._raw_fdel)

    def deleter(self, fdel):
        """Return a copy of this property that uses ``fdel`` as its deleter."""
        return self._copy_with(self._raw_fget, self._raw_fset, fdel)
# Sentinel meaning "no active graph"; compared by identity only.
empty = object()


class NodeBaseMixin:
    """A Parent for all Nodes.

    State stored on this class (notably ``_graph_``) is read and written
    through ``NodeBaseMixin`` itself by ``get_graph``/``set_graph``, so it
    is visible from every subclass that does not shadow it.

    Attributes
    ----------
    _graph_ : DiGraph
        The active graph, or ``empty`` when no graph is active.
    uuid : UUID
        Write-once identifier; ``None`` until assigned.
    """

    _graph_ = empty
    _external_ = False
    _uuid: UUID = None
    # TODO consider adding regex patterns
    _protected_ = ["_graph_", "uuid", "_uuid"]

    @property
    def uuid(self):
        """Identifier of this node, or ``None`` if none was assigned yet."""
        return self._uuid

    @uuid.setter
    def uuid(self, new_uuid):
        # Write-once: only the first assignment of a non-None uuid sticks.
        if self._uuid is None:
            self._uuid = new_uuid
            return
        raise ValueError("uuid is already set")

    def run(self):
        """Execute the node; subclasses must override this."""
        raise NotImplementedError
def get_graph():
    """Return the graph stored on ``NodeBaseMixin`` (``empty`` if none is active)."""
    active_graph = NodeBaseMixin._graph_
    return active_graph


def set_graph(value):
    """Store ``value`` as the graph on ``NodeBaseMixin``."""
    setattr(NodeBaseMixin, "_graph_", value)
# Private sentinel so that ``None`` can still be passed as an explicit default.
_get_attribute_none = object()


def get_attribute(obj, name, default=_get_attribute_none):
    """Get the real value of the attribute and not a znflow.Connection.

    The lookup runs inside ``disable_graph()``, i.e. with the active graph
    temporarily replaced by ``empty``.

    Parameters
    ----------
    obj : object
        Object to read the attribute from.
    name : str
        Name of the attribute.
    default : object, optional
        Value returned when the attribute is missing. When omitted, the
        lookup behaves like the two-argument form of :func:`getattr`.
    """
    getattr_args = (obj, name) if default is _get_attribute_none else (obj, name, default)
    with disable_graph():
        return getattr(*getattr_args)
@dataclasses.dataclass(frozen=True)
class Connection:
    """A Connector for Nodes.

    Attributes
    ----------
    instance : any
        Either a Node or a FunctionFuture (or another Connection when the
        connection was created through ``__getitem__``).
    attribute : any
        ``Node.attribute`` or ``FunctionFuture.result``, or ``None`` if the
        instance itself is passed and not an attribute. Names starting with
        ``"_"`` are rejected with ``ValueError``.
    item : any
        Optional key/index/slice applied to the resolved value. ``None``
        (the default) means "no indexing"; any other value, including
        falsy ones such as ``0`` or ``""``, is applied.
    """

    instance: any
    attribute: any
    item: any = None

    def __post_init__(self):
        if self.attribute is not None and self.attribute.startswith("_"):
            raise ValueError("Private attributes are not allowed.")

    def __getitem__(self, item):
        """Return a new Connection that indexes this connection's result.

        The current connection becomes the ``instance`` of the new one, so
        indexing composes: ``conn[a][b]`` resolves to ``conn.result[a][b]``.
        """
        return dataclasses.replace(self, instance=self, attribute=None, item=item)

    def __iter__(self):
        # Without this, Python would fall back to iterating via __getitem__,
        # which never raises IndexError here and would loop forever.
        raise TypeError(f"Can not iterate over {self}.")

    def __add__(
        self, other: typing.Union[Connection, FunctionFuture, CombinedConnections]
    ) -> CombinedConnections:
        """Combine this connection with ``other``.

        Raises
        ------
        TypeError
            If ``other`` is not a Connection, FunctionFuture or
            CombinedConnections.
        """
        if isinstance(other, (Connection, FunctionFuture, CombinedConnections)):
            return CombinedConnections(connections=[self, other])
        raise TypeError(f"Can not add {type(other)} to {type(self)}.")

    def __radd__(self, other):
        """Enable 'sum([a, b], [])'"""
        return self if other == [] else self.__add__(other)

    @property
    def uuid(self):
        """The ``uuid`` of the connected instance."""
        return self.instance.uuid

    @property
    def _external_(self):
        """The ``_external_`` flag of the connected instance."""
        return self.instance._external_

    @property
    def result(self):
        """Resolve the connection to a concrete value.

        Reads ``instance.<attribute>`` when an attribute is set. Otherwise it
        uses ``instance.result`` for FunctionFutures and nested Connections,
        or the instance itself for anything else. ``item`` is then applied
        unless it is ``None``.
        """
        if self.attribute:
            result = getattr(self.instance, self.attribute)
        elif isinstance(self.instance, (FunctionFuture, self.__class__)):
            result = self.instance.result
        else:
            result = self.instance
        # Compare against None explicitly: 0 and "" are valid keys/indices.
        return result if self.item is None else result[self.item]

    def __getattribute__(self, __name: str) -> Any:
        """Look up an attribute, re-raising AttributeError as ConnectionAttributeError."""
        try:
            return super().__getattribute__(__name)
        except AttributeError as e:
            raise exceptions.ConnectionAttributeError(
                "Connection does not support further attributes to its result."
            ) from e
@dataclasses.dataclass(frozen=True)
class CombinedConnections:
    """Combine multiple Connections into one.

    This class allows to 'add' Connections and/or FunctionFutures.
    This only works if the Connection or FunctionFuture points to a 'list'.
    Every addition returns a new CombinedConnections instance; existing
    instances are never modified.

    Examples
    --------
    >>> import znflow
    >>> @znflow.nodify
    ... def add(size) -> list:
    ...     return list(range(size))
    >>> with znflow.DiGraph() as graph:
    ...     outs = add(2) + add(3)
    >>> graph.run()
    >>> assert outs.result == [0, 1, 0, 1, 2]

    Attributes
    ----------
    connections : list[Connection|FunctionFuture|CombinedConnections]
        The List of items to be added. Each entry must expose a ``result``
        that can be used to extend a list.
    item : any
        Any slice/index to be applied to the combined result; ``None``
        means no indexing.
    """

    connections: typing.List[Connection]
    item: any = None

    def __add__(
        self, other: typing.Union[Connection, FunctionFuture, CombinedConnections]
    ) -> CombinedConnections:
        """Implement add for CombinedConnections.

        Raises
        ------
        ValueError
            If self.item is set, we can not add another item.
        TypeError
            If other is not a Connection, FunctionFuture or CombinedConnections.
        """
        if self.item is not None:
            raise ValueError("Can not combine multiple slices")
        if isinstance(other, CombinedConnections):
            if other.item is not None:
                # Flattening would silently drop other's slice; keep it as a
                # single nested entry so its slice is still applied.
                return dataclasses.replace(self, connections=[*self.connections, other])
            return dataclasses.replace(
                self, connections=[*self.connections, *other.connections]
            )
        if isinstance(other, (Connection, FunctionFuture)):
            return dataclasses.replace(self, connections=[*self.connections, other])
        raise TypeError(f"Can not add {type(other)} to {type(self)}.")

    def __radd__(self, other):
        """Enable 'sum([a, b], [])'"""
        return self if other == [] else self.__add__(other)

    def __getitem__(self, item):
        """Return a CombinedConnections whose result is ``self.result[item]``."""
        if self.item is None:
            return dataclasses.replace(self, item=item)
        # A slice is already applied: nest instead of overwriting it, so
        # ``combined[a][b]`` resolves to ``combined.result[a][b]``.
        return CombinedConnections(connections=[self], item=item)

    def __iter__(self):
        # Prevent implicit iteration through __getitem__.
        raise TypeError(f"Can not iterate over {self}.")

    @property
    def result(self):
        """Concatenate the results of all connections, then apply ``item``.

        Raises
        ------
        TypeError
            If a connection's result cannot be used to extend a list.
        """
        results = []
        for connection in self.connections:
            value = connection.result
            # Only the extend is guarded, so the message names the actual
            # offending connection and does not re-evaluate its result.
            try:
                results.extend(value)
            except TypeError as err:
                raise TypeError(
                    f"The value {value} is of type {type(value)}. The"
                    f" only supported type is list. Please change {connection}"
                ) from err
        # Compare against None explicitly: index 0 is a valid item.
        return results if self.item is None else results[self.item]
@dataclasses.dataclass
class FunctionFuture(NodeBaseMixin):
    """A deferred call of ``function(*args, **kwargs)``.

    Attributes
    ----------
    function : Callable
        The function to call in ``run``.
    args : tuple
        Positional arguments passed to ``function``.
    kwargs : dict
        Keyword arguments passed to ``function``.
    item : any
        Not used by the methods defined here (presumably consumed elsewhere).
    result : any
        ``None`` until ``run`` stores the function's return value.
    """

    function: typing.Callable
    args: typing.Tuple
    kwargs: typing.Dict
    item: any = None
    result: any = dataclasses.field(default=None, init=False, repr=True)

    _protected_ = NodeBaseMixin._protected_ + ["function", "args", "kwargs"]

    def run(self):
        """Call the wrapped function and store its return value in ``result``."""
        func, positional, keyword = self.function, self.args, self.kwargs
        self.result = func(*positional, **keyword)

    def __getitem__(self, item):
        """Return a Connection that indexes this future's result with ``item``."""
        return Connection(instance=self, attribute=None, item=item)

    def __iter__(self):
        # Prevent implicit iteration through __getitem__.
        raise TypeError(f"Can not iterate over {self}.")

    def __add__(
        self, other: typing.Union[Connection, FunctionFuture, CombinedConnections]
    ) -> CombinedConnections:
        """Combine this future with ``other``; raises TypeError for other types."""
        combinable = (Connection, FunctionFuture, CombinedConnections)
        if not isinstance(other, combinable):
            raise TypeError(f"Can not add {type(other)} to {type(self)}.")
        return CombinedConnections(connections=[self, other])

    def __radd__(self, other):
        """Enable 'sum([a, b], [])'"""
        if other == []:
            return self
        return self.__add__(other)