This repository has been archived by the owner on Jul 3, 2023. It is now read-only.
/
node.py
147 lines (117 loc) · 4.94 KB
/
node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import copy
import inspect
import logging
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
# Primitive components of the function graph. They get a file of their own
# to avoid circular dependencies between modules.
logger = logging.getLogger(__name__)
class NodeSource(Enum):
    """Where a node's value originates.

    Executors can consult this to flexibly decide how to execute a function graph.

    * ``STANDARD``  -- the value comes from the node's standard dependencies.
    * ``EXTERNAL``  -- the value should be taken from cache rather than computed
      (``Node.user_defined`` is true for these nodes).
    * ``PRIOR_RUN`` -- the value should be taken from a prior run. A standard
      function graph does not use this; it is handy when the same graph is
      run repeatedly.
    """

    STANDARD = 1
    EXTERNAL = 2
    PRIOR_RUN = 3
class DependencyType(Enum):
    """Whether a function parameter must be supplied or may fall back to its default."""

    REQUIRED = 1  # the parameter declares no default value
    OPTIONAL = 2  # the parameter declares a default value

    @staticmethod
    def from_parameter(param: inspect.Parameter) -> 'DependencyType':
        """Classify a parameter by whether it declares a default value.

        :param param: a parameter, e.g. from ``inspect.signature(fn).parameters``.
        :return: ``REQUIRED`` if the parameter has no default, otherwise ``OPTIONAL``.
        """
        # Identity check against the sentinel. ``==`` would dispatch to the default
        # value's own ``__eq__``, which may return a truthy non-bool (e.g. a numpy
        # array, whose truth value is ambiguous) or claim equality with anything.
        if param.default is inspect.Parameter.empty:
            return DependencyType.REQUIRED
        return DependencyType.OPTIONAL
class Node(object):
    """Object representing a node of computation."""

    def __init__(self,
                 name: str,
                 typ: Type,
                 doc_string: str = '',
                 callabl: Optional[Callable] = None,
                 node_source: NodeSource = NodeSource.STANDARD,
                 input_types: Optional[Dict[str, Type]] = None):
        """Constructor for our Node object.

        :param name: the name of the function.
        :param typ: the output type of the function.
        :param doc_string: the doc string for the function. Optional.
        :param callabl: the actual function callable. For a STANDARD node without
            ``input_types`` its signature supplies the inputs, so it must be given then.
        :param node_source: where this node's value originates (see ``NodeSource``).
        :param input_types: the input parameters and their types. Only consulted for
            STANDARD nodes; every input given this way is treated as required.
        :raises ValueError: if ``typ`` is missing, or if a parameter of ``callabl``
            has no type annotation.
        """
        # Reject a missing output type before doing any other work.
        if typ is None or typ is inspect.Signature.empty:
            raise ValueError(f'Missing type for hint for function {name}. Please add one to fix.')
        self._name = name
        self._type = typ
        self._callable = callabl
        self._doc = doc_string
        self._node_source = node_source
        self._dependencies = []
        self._depended_on_by = []
        # Always defined so ``input_types`` also works for non-STANDARD nodes,
        # which have no inputs of their own.
        self._input_types = {}
        if self._node_source == NodeSource.STANDARD:
            if input_types is not None:
                # Explicit types carry no default-value information, so all are REQUIRED.
                self._input_types = {key: (value, DependencyType.REQUIRED)
                                     for key, value in input_types.items()}
            else:
                signature = inspect.signature(callabl)
                for key, value in signature.parameters.items():
                    # Identity check: annotation objects may define a custom ``__eq__``.
                    if value.annotation is inspect.Parameter.empty:
                        raise ValueError(f'Missing type hint for {key} in function {name}. Please add one to fix.')
                    self._input_types[key] = (value.annotation, DependencyType.from_parameter(value))

    @property
    def documentation(self) -> str:
        """The doc string this node was created with."""
        return self._doc

    @property
    def input_types(self) -> Dict[str, Tuple[Type, DependencyType]]:
        """Maps each input parameter name to a ``(type, DependencyType)`` pair."""
        return self._input_types

    @property
    def name(self) -> str:
        """The node's name."""
        return self._name

    @property
    def type(self) -> Any:
        """The node's output type."""
        return self._type

    @property
    def callable(self):
        """The function that computes this node, or None if none was given."""
        return self._callable

    # TODO - deprecate in favor of the node sources above
    @property
    def user_defined(self):
        """True when this node's source is ``NodeSource.EXTERNAL``."""
        return self._node_source == NodeSource.EXTERNAL

    @property
    def node_source(self):
        """Where this node's value originates."""
        return self._node_source

    @property
    def dependencies(self) -> List['Node']:
        """Nodes this node depends on (mutable list owned by this node)."""
        return self._dependencies

    @property
    def depended_on_by(self) -> List['Node']:
        """Nodes that depend on this node (mutable list owned by this node)."""
        return self._depended_on_by

    def __hash__(self):
        # Hash on the name only; nodes that compare equal always share a name.
        return hash(self._name)

    def __repr__(self):
        return f'<{self._name}>'

    def __eq__(self, other: 'Node'):
        """Want to deeply compare nodes in a custom way.

        Current user is just unit tests. But you never know :)

        Note: we only compare names of dependencies because we don't want infinite recursion.
        """
        return (isinstance(other, Node) and
                self._name == other.name and
                self._type == other.type and
                self._doc == other.documentation and
                self.user_defined == other.user_defined and
                [n.name for n in self.dependencies] == [o.name for o in other.dependencies] and
                [n.name for n in self.depended_on_by] == [o.name for o in other.depended_on_by] and
                self.node_source == other.node_source)

    def __ne__(self, other: 'Node'):
        return not self.__eq__(other)

    @staticmethod
    def from_fn(fn: Callable, name: Optional[str] = None) -> 'Node':
        """Generates a node from a function. Optionally overrides the name.

        :param fn: Function to generate the node from; its return annotation becomes
            the node's type and its parameters become the node's inputs.
        :param name: Name to use for the node. Defaults to ``fn.__name__``.
        :return: The node we generated.
        :raises ValueError: if ``fn`` lacks a return annotation or a parameter annotation.
        """
        if name is None:
            name = fn.__name__
        sig = inspect.signature(fn)
        return Node(name, sig.return_annotation, fn.__doc__ or '', callabl=fn)