Skip to content

Commit 648b557

Browse files
Adding node semantics: input type, firing type, state type; Adding lineage: previous and current state as well as objects
1 parent dfaffb4 commit 648b557

File tree

4 files changed

+164
-147
lines changed

4 files changed

+164
-147
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from sklearn.base import BaseEstimator
21
from abc import ABC, abstractmethod
2+
import uuid
3+
from enum import Enum
4+
35

46
import sklearn.base as base
5-
import uuid
7+
from sklearn.base import TransformerMixin
8+
from sklearn.base import BaseEstimator
69

710

811
class Xy:
@@ -38,10 +41,11 @@ class XYRef:
3841
computed), these holders are essential to the pipeline constructs.
3942
"""
4043

41-
def __init__(self, Xref, yref, node=None, prev_Xyrefs = None):
44+
def __init__(self, Xref, yref, prev_noderef=None, curr_noderef=None, prev_Xyrefs = None):
4245
self.__Xref__ = Xref
4346
self.__yref__ = yref
44-
self.__noderef__ = node
47+
self.__prevnoderef__ = prev_noderef
48+
self.__currnoderef__ = curr_noderef
4549
self.__prev_Xyrefs__ = prev_Xyrefs
4650

4751
def get_Xref(self):
@@ -56,29 +60,60 @@ def get_yref(self):
5660
"""
5761
return self.__yref__
5862

59-
def get_noderef(self):
60-
return self.__noderef__
63+
def get_prevnoderef(self):
64+
return self.__prevnoderef__
65+
66+
def get_currnoderef(self):
67+
return self.__currnoderef__
6168

6269
def get_prev_xyrefs(self):
6370
return self.__prev_Xyrefs__
6471

6572

73+
class NodeInputType(Enum):
    """
    Input semantics of a node, used by the runtime to choose how a node is executed.

    Nodes created as estimator nodes use OR and nodes created as AND nodes use AND.
    Presumably OR means each incoming object is processed on its own while AND
    means the incoming objects are combined first -- TODO confirm against the runtime.
    """
    # NOTE: values are plain ints. A trailing comma after the literal (``OR = 0,``)
    # would silently turn the member value into the tuple ``(0,)``, making the
    # values inconsistent and breaking lookups by value such as NodeInputType(0).
    OR = 0
    AND = 1
76+
77+
78+
class NodeFiringType(Enum):
    """
    Firing semantics of a node.

    Presumably ANY fires the node as soon as any input is available while ALL
    waits for every input -- TODO confirm; the value is stored on the node but
    its consumer is not visible here.
    """
    # NOTE: values are plain ints. A trailing comma after the literal (``ANY = 0,``)
    # would silently turn the member value into the tuple ``(0,)``, making the
    # values inconsistent and breaking lookups by value such as NodeFiringType(0).
    ANY = 0
    ALL = 1
81+
82+
83+
class NodeStateType(Enum):
    """
    State semantics of a node.

    The member names suggest how a node's internal state may change across
    executions (none, fixed, updated sequentially, updated by aggregation) --
    NOTE(review): the exact contract of each member is not visible here; confirm
    before relying on it.
    """
    # NOTE: values are plain ints. Trailing commas after the literals
    # (``STATELESS = 0,``) would silently turn those member values into 1-tuples
    # such as ``(0,)``, leaving the enum with mixed tuple/int values and breaking
    # lookups by value such as NodeStateType(0).
    STATELESS = 0
    IMMUTABLE = 1
    MUTABLE_SEQUENTIAL = 2
    MUTABLE_AGGREGATE = 3
88+
89+
6690
class Node(ABC):
6791
"""
6892
An abstract node class that captures the basic information about a node.
6993
The hash code of this node is the name of the node and equality is defined if the
7094
node name and the type of the node match.
7195
"""
96+
def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type: NodeFiringType, node_state_type: NodeStateType):
97+
self.__node_name__ = node_name
98+
self.__node_input_type__ = node_input_type
99+
self.__node_firing_type__ = node_firing_type
100+
self.__node_state_type__ = node_state_type
101+
self.__id__ = uuid.uuid4()
72102

73103
def __str__(self):
74104
return self.__node_name__
75105

76106
def get_id(self):
77107
return self.__id__
78108

79-
@abstractmethod
80-
def get_and_flag(self):
81-
raise NotImplementedError("Please implement this method")
109+
def get_node_input_type(self):
110+
return self.__node_input_type__
111+
112+
def get_node_firing_type(self):
113+
return self.__node_firing_type__
114+
115+
def get_node_state_type(self):
116+
return self.__node_state_type__
82117

83118
@abstractmethod
84119
def clone(self):
@@ -107,7 +142,7 @@ def __eq__(self, other):
107142
)
108143

109144

110-
class OrNode(Node):
145+
class EstimatorNode(Node):
111146
"""
112147
Estimator node (a node with OR input type), which is the basic node that would be the equivalent of any sklearn pipeline
113148
stage. This node is initialized with an estimator that needs to extend sklearn.BaseEstimator.
@@ -120,9 +155,8 @@ def __init__(self, node_name: str, estimator: BaseEstimator):
120155
:param node_name: Name of the node
121156
:param estimator: The base estimator
122157
"""
123-
self.__node_name__ = node_name
158+
super().__init__(node_name, NodeInputType.OR, NodeFiringType.ANY, NodeStateType.IMMUTABLE)
124159
self.__estimator__ = estimator
125-
self.__id__ = uuid.uuid4()
126160

127161
def get_estimator(self) -> BaseEstimator:
128162
"""
@@ -132,41 +166,31 @@ def get_estimator(self) -> BaseEstimator:
132166
"""
133167
return self.__estimator__
134168

135-
def get_and_flag(self):
136-
"""
137-
A flag to check if node is AND or not. By definition, this is NOT
138-
an AND node.
139-
:return: False, always
140-
"""
141-
return False
142-
143169
def clone(self):
144170
cloned_estimator = base.clone(self.__estimator__)
145-
return OrNode(self.__node_name__, cloned_estimator)
171+
return EstimatorNode(self.__node_name__, cloned_estimator)
146172

147173

148-
class AndFunc(ABC):
149-
"""
150-
Or nodes are init-ed from the
151-
"""
174+
class AndTransform(TransformerMixin, BaseEstimator):
    """
    Base class for transforms that take a list of Xy objects and produce a
    single Xy. Subclasses are expected to override ``transform``.
    """

    # NOTE(review): @abstractmethod is only enforced at instantiation time when
    # the class uses the ABCMeta metaclass; confirm whether the sklearn bases
    # provide it, otherwise the decorator acts as documentation only and the
    # NotImplementedError below is the effective guard.
    @abstractmethod
    def transform(self, xy_list: list) -> Xy:
        """
        Combine the given inputs into one result.

        :param xy_list: List of Xy objects to combine
        :return: The resulting Xy
        :raises NotImplementedError: Always, unless overridden by a subclass
        """
        raise NotImplementedError("Please implement this method")
178+
152179

180+
class GeneralTransform(TransformerMixin, BaseEstimator):
    """
    Base class for transforms that map a single Xy object to another Xy.
    Subclasses are expected to override ``transform``.
    """

    # NOTE(review): @abstractmethod is only enforced at instantiation time when
    # the class uses the ABCMeta metaclass; confirm whether the sklearn bases
    # provide it, otherwise the decorator acts as documentation only and the
    # NotImplementedError below is the effective guard.
    @abstractmethod
    def transform(self, xy: Xy) -> Xy:
        """
        Transform the given input.

        :param xy: The Xy object to transform
        :return: The resulting Xy
        :raises NotImplementedError: Always, unless overridden by a subclass
        """
        raise NotImplementedError("Please implement this method")
156184

157185

158186
class AndNode(Node):
    """
    A node with AND input type that wraps an AndTransform. It is registered with
    firing type ANY and state type STATELESS.
    """

    def __init__(self, node_name: str, and_func: AndTransform):
        """
        Initializes the node with its name and the transform it applies.

        :param node_name: Name of the node
        :param and_func: The transform that is applied to the list of inputs
        """
        super().__init__(
            node_name,
            node_input_type=NodeInputType.AND,
            node_firing_type=NodeFiringType.ANY,
            node_state_type=NodeStateType.STATELESS,
        )
        self.__andfunc__ = and_func

    def get_and_func(self) -> AndTransform:
        """
        Return the transform held by this node.

        :return: The AndTransform given at construction
        """
        return self.__andfunc__

    def clone(self):
        """
        Create a new AndNode with the same name. The transform object itself is
        not copied; the new node refers to the same instance as this one.

        :return: A new AndNode
        """
        return AndNode(self.__node_name__, self.__andfunc__)
172196

codeflare/pipelines/Runtime.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import ray
22

3-
from codeflare.pipelines.Datamodel import OrNode
3+
from codeflare.pipelines.Datamodel import EstimatorNode
44
from codeflare.pipelines.Datamodel import AndNode
55
from codeflare.pipelines.Datamodel import Edge
66
from codeflare.pipelines.Datamodel import Pipeline
77
from codeflare.pipelines.Datamodel import XYRef
88
from codeflare.pipelines.Datamodel import Xy
9+
from codeflare.pipelines.Datamodel import NodeInputType
10+
from codeflare.pipelines.Datamodel import NodeStateType
11+
from codeflare.pipelines.Datamodel import NodeFiringType
912

1013
import sklearn.base as base
1114
from enum import Enum
@@ -18,43 +21,49 @@ class ExecutionType(Enum):
1821

1922

2023
@ray.remote
21-
def execute_or_node_inner(node: OrNode, train_mode: ExecutionType, xy_ref: XYRef):
24+
def execute_or_node_remote(node: EstimatorNode, train_mode: ExecutionType, xy_ref: XYRef):
2225
estimator = node.get_estimator()
2326
# Blocking operation -- not avoidable
2427
X = ray.get(xy_ref.get_Xref())
2528
y = ray.get(xy_ref.get_yref())
2629

30+
# TODO: Can optimize the node pointers without replicating them
2731
if train_mode == ExecutionType.FIT:
2832
cloned_node = node.clone()
29-
node_ptr = ray.put(cloned_node)
33+
prev_node_ptr = ray.put(node)
3034

3135
if base.is_classifier(estimator) or base.is_regressor(estimator):
3236
# Always clone before fit, else fit is invalid
3337
cloned_estimator = cloned_node.get_estimator()
3438
cloned_estimator.fit(X, y)
39+
40+
curr_node_ptr = ray.put(cloned_node)
3541
# TODO: For now, make yref passthrough - this has to be fixed more comprehensively
3642
res_Xref = ray.put(cloned_estimator.predict(X))
37-
result = XYRef(res_Xref, xy_ref.get_yref(), node_ptr, [xy_ref])
43+
result = XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, curr_node_ptr, [xy_ref])
3844
return result
3945
else:
4046
cloned_estimator = cloned_node.get_estimator()
4147
res_Xref = ray.put(cloned_estimator.fit_transform(X, y))
42-
result = XYRef(res_Xref, xy_ref.get_yref(), node_ptr, [xy_ref])
48+
curr_node_ptr = ray.put(cloned_node)
49+
result = XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, curr_node_ptr, [xy_ref])
4350
return result
4451
elif train_mode == ExecutionType.SCORE:
4552
cloned_node = node.clone()
46-
node_ptr = ray.put(cloned_node)
53+
prev_node_ptr = ray.put(node)
4754

4855
if base.is_classifier(estimator) or base.is_regressor(estimator):
4956
cloned_estimator = cloned_node.get_estimator()
5057
cloned_estimator.fit(X, y)
58+
curr_node_ptr = ray.put(cloned_node)
5159
res_Xref = ray.put(cloned_estimator.score(X, y))
52-
result = XYRef(res_Xref, xy_ref.get_yref(), node_ptr, [xy_ref])
60+
result = XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, curr_node_ptr, [xy_ref])
5361
return result
5462
else:
5563
cloned_estimator = cloned_node.get_estimator()
5664
res_Xref = ray.put(cloned_estimator.fit_transform(X, y))
57-
result = XYRef(res_Xref, xy_ref.get_yref(), node_ptr, [xy_ref])
65+
curr_node_ptr = ray.put(cloned_node)
66+
result = XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, curr_node_ptr, [xy_ref])
5867
return result
5968
elif train_mode == ExecutionType.PREDICT:
6069
# Test mode does not clone as it is a simple predict or transform
@@ -74,7 +83,7 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
7483
exec_xyrefs = []
7584
for xy_ref_ptr in Xyref_ptrs:
7685
xy_ref = ray.get(xy_ref_ptr)
77-
inner_result = execute_or_node_inner.remote(node, mode, xy_ref)
86+
inner_result = execute_or_node_remote.remote(node, mode, xy_ref)
7887
exec_xyrefs.append(inner_result)
7988

8089
for post_edge in post_edges:
@@ -84,21 +93,22 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
8493

8594

8695
@ray.remote
def execute_and_node_remote(node: AndNode, Xyref_list):
    """
    Ray remote function that applies the transform of an AND node to all of the
    given inputs at once.

    :param node: The AND node to execute
    :param Xyref_list: List of XYRef objects holding the references to the inputs
    :return: An XYRef holding references to the transformed X and y, references
        to the node as passed in (previous) and to its clone (current), and the
        input XYRef list as its predecessors
    """
    # Lineage: keep a reference to the node exactly as it was handed in
    prev_node_ptr = ray.put(node)

    # Blocking operation -- the actual data is needed to apply the transform
    xy_list = [
        Xy(ray.get(in_ref.get_Xref()), ray.get(in_ref.get_yref()))
        for in_ref in Xyref_list
    ]

    # The transform is taken from a clone of the node, which is stored as the
    # current node reference
    cloned_node = node.clone()
    curr_node_ptr = ray.put(cloned_node)

    combined = cloned_node.get_and_func().transform(xy_list)
    return XYRef(
        ray.put(combined.get_x()),
        ray.put(combined.get_y()),
        prev_node_ptr,
        curr_node_ptr,
        Xyref_list,
    )
102112

103113

104114
def execute_and_node_inner(node: AndNode, Xyref_ptrs):
@@ -109,7 +119,7 @@ def execute_and_node_inner(node: AndNode, Xyref_ptrs):
109119
Xyref = ray.get(Xyref_ptr)
110120
Xyref_list.append(Xyref)
111121

112-
Xyref_ptr = and_node_eval.remote(node, Xyref_list)
122+
Xyref_ptr = execute_and_node_remote.remote(node, Xyref_list)
113123
result.append(Xyref_ptr)
114124
return result
115125

@@ -145,9 +155,9 @@ def execute_pipeline(pipeline: Pipeline, mode: ExecutionType, in_args: dict):
145155
for node in nodes:
146156
pre_edges = pipeline.get_pre_edges(node)
147157
post_edges = pipeline.get_post_edges(node)
148-
if not node.get_and_flag():
158+
if node.get_node_input_type() == NodeInputType.OR:
149159
execute_or_node(node, pre_edges, edge_args, post_edges, mode)
150-
elif node.get_and_flag():
160+
elif node.get_node_input_type() == NodeInputType.AND:
151161
execute_and_node(node, pre_edges, edge_args, post_edges)
152162

153163
out_args = {}

0 commit comments

Comments
 (0)