Skip to content

Commit dfaffb4

Browse files
Adding lineage first commit, needs a documentation/ADR
1 parent 508f716 commit dfaffb4

File tree

5 files changed

+226
-203
lines changed

5 files changed

+226
-203
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from sklearn.base import BaseEstimator
22
from abc import ABC, abstractmethod
33

4+
import sklearn.base as base
5+
import uuid
6+
47

58
class Xy:
69
"""
@@ -35,9 +38,11 @@ class XYRef:
3538
computed), these holders are essential to the pipeline constructs.
3639
"""
3740

38-
def __init__(self, Xref, yref, node=None, prev_Xyrefs=None):
    """
    Create a holder for references to X and y, plus optional lineage data.

    :param Xref: Reference to X
    :param yref: Reference to y
    :param node: Optional reference to the node associated with this holder
        (the runtime passes an object-store reference to a node clone);
        defaults to None
    :param prev_Xyrefs: Optional list of the XYRef holders this one was
        derived from; defaults to None
    """
    self.__Xref__ = Xref
    self.__yref__ = yref
    # Lineage information: which node produced this holder and from which inputs.
    self.__noderef__ = node
    self.__prev_Xyrefs__ = prev_Xyrefs
4146

4247
def get_Xref(self):
4348
"""
@@ -51,6 +56,12 @@ def get_yref(self):
5156
"""
5257
return self.__yref__
5358

59+
def get_noderef(self):
    """
    Getter for the node reference held by this object

    :return: The node reference given at construction time, or None if none was given
    """
    return self.__noderef__
61+
62+
def get_prev_xyrefs(self):
    """
    Getter for the previous XYRef holders recorded on this object

    :return: The value given as ``prev_Xyrefs`` at construction time, or None if none was given
    """
    return self.__prev_Xyrefs__
64+
5465

5566
class Node(ABC):
5667
"""
@@ -62,17 +73,24 @@ class Node(ABC):
6273
def __str__(self):
    """
    String representation of the node

    :return: The node name
    """
    return self.__node_name__
6475

76+
def get_id(self):
    """
    Getter for the identifier of this node

    :return: The value stored in ``__id__`` for this node
    """
    return self.__id__
78+
6579
@abstractmethod
def get_and_flag(self):
    """
    Abstract method that concrete nodes must override.

    :raises NotImplementedError: Always, when the base implementation is invoked
    """
    raise NotImplementedError("Please implement this method")
6882

83+
@abstractmethod
def clone(self):
    """
    Abstract method that concrete nodes must override to produce a copy of themselves.

    :raises NotImplementedError: Always, when the base implementation is invoked
    """
    raise NotImplementedError("Please implement the clone method")
86+
6987
def __hash__(self):
    """
    Hash code, defined as the hash code of the node id (``__id__``),
    not of the node name.

    :return: Hash code
    """
    return hash(self.__id__)
7694

7795
def __eq__(self, other):
7896
"""
@@ -84,6 +102,7 @@ def __eq__(self, other):
84102
"""
85103
return (
86104
self.__class__ == other.__class__ and
105+
self.__id__ == other.__id__ and
87106
self.__node_name__ == other.__node_name__
88107
)
89108

@@ -93,7 +112,6 @@ class OrNode(Node):
93112
Or node, which is the basic node that would be the equivalent of any SKlearn pipeline
94113
stage. This node is initialized with an estimator that needs to extend sklearn.BaseEstimator.
95114
"""
96-
__estimator__ = None
97115

98116
def __init__(self, node_name: str, estimator: BaseEstimator):
99117
"""
@@ -104,6 +122,7 @@ def __init__(self, node_name: str, estimator: BaseEstimator):
104122
"""
105123
self.__node_name__ = node_name
106124
self.__estimator__ = estimator
125+
self.__id__ = uuid.uuid4()
107126

108127
def get_estimator(self) -> BaseEstimator:
109128
"""
@@ -121,6 +140,10 @@ def get_and_flag(self):
121140
"""
122141
return False
123142

143+
def clone(self):
    """
    Create a new OrNode with the same node name, wrapping a copy of this
    node's estimator obtained through ``sklearn.base.clone``.

    :return: A new OrNode instance
    """
    cloned_estimator = base.clone(self.__estimator__)
    return OrNode(self.__node_name__, cloned_estimator)
146+
124147

125148
class AndFunc(ABC):
126149
"""
@@ -133,18 +156,20 @@ def eval(self, xy_list: list) -> Xy:
133156

134157

135158
class AndNode(Node):
    """
    And node, a node that holds an AndFunc. Its and-flag is always True.
    """

    def __init__(self, node_name: str, and_func: AndFunc):
        """
        Initializes the node with a name and an AndFunc, and assigns it a
        freshly generated UUID as its id.

        :param node_name: Name of the node
        :param and_func: The AndFunc held by this node
        """
        self.__node_name__ = node_name
        self.__andfunc__ = and_func
        self.__id__ = uuid.uuid4()

    def get_and_func(self) -> AndFunc:
        """
        Getter for the AndFunc held by this node

        :return: The AndFunc given at construction time
        """
        return self.__andfunc__

    def get_and_flag(self):
        """
        And-flag of this node

        :return: True, always
        """
        return True

    def clone(self):
        """
        Create a new AndNode with the same node name.

        The AndFunc is passed to the new node as-is (the same instance, not
        a copy); the new node gets its own id from the constructor.

        :return: A new AndNode instance
        """
        return AndNode(self.__node_name__, self.__andfunc__)
172+
148173

149174
class Edge:
150175
__from_node__ = None

codeflare/pipelines/Runtime.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,47 +18,53 @@ class ExecutionType(Enum):
1818

1919

2020
@ray.remote
def execute_or_node_inner(node: OrNode, train_mode: ExecutionType, xy_ref: XYRef):
    """
    Execute a single OrNode on one XYRef input.

    FIT and SCORE work on a clone of the node, so the estimator held by
    ``node`` is never fitted in place. The clone is placed in the Ray object
    store only *after* fitting: ``ray.put`` stores a snapshot of the object,
    so putting it before ``fit`` would record the unfitted estimator in the
    lineage reference.

    :param node: The OrNode whose estimator is executed
    :param train_mode: ExecutionType.FIT, ExecutionType.SCORE or ExecutionType.PREDICT
    :param xy_ref: Holder of the object references to the input X and y
    :return: An XYRef with the result; for FIT/SCORE it also carries a
        reference to the fitted node clone and ``[xy_ref]`` as its previous
        holders. None is returned for any other ``train_mode`` value.
    """
    estimator = node.get_estimator()
    # Blocking operation -- not avoidable
    X = ray.get(xy_ref.get_Xref())
    y = ray.get(xy_ref.get_yref())

    is_predictor = base.is_classifier(estimator) or base.is_regressor(estimator)

    if train_mode == ExecutionType.PREDICT:
        # Predict mode does not clone as it is a simple predict or transform
        # NOTE(review): unlike FIT/SCORE, the result is stored in the XYRef
        # directly rather than through ray.put, and no lineage is attached --
        # confirm whether callers rely on this before changing it.
        if is_predictor:
            res_Xref = estimator.predict(X)
        else:
            res_Xref = estimator.transform(X)
        return XYRef(res_Xref, xy_ref.get_yref())

    if train_mode == ExecutionType.FIT or train_mode == ExecutionType.SCORE:
        # Always clone before fit, else fit is invalid
        cloned_node = node.clone()
        cloned_estimator = cloned_node.get_estimator()

        if is_predictor:
            cloned_estimator.fit(X, y)
            if train_mode == ExecutionType.FIT:
                result_value = cloned_estimator.predict(X)
            else:
                result_value = cloned_estimator.score(X, y)
        else:
            result_value = cloned_estimator.fit_transform(X, y)

        res_Xref = ray.put(result_value)
        # Store the clone after fitting so the lineage reference holds the
        # fitted estimator rather than a pre-fit snapshot.
        node_ptr = ray.put(cloned_node)
        # TODO: For now, make yref passthrough - this has to be fixed more comprehensively
        return XYRef(res_Xref, xy_ref.get_yref(), node_ptr, [xy_ref])

    # Unknown execution type: no result is produced.
    return None
6369

6470

@@ -78,29 +84,32 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
7884

7985

8086
@ray.remote
def and_node_eval(node: AndNode, Xyref_list):
    """
    Evaluate an AndNode over a list of XYRef inputs.

    Each input's X and y are fetched from the object store, the AndFunc of a
    clone of ``node`` is evaluated on the resulting Xy list, and the outputs
    are put back into the object store.

    :param node: The AndNode to evaluate
    :param Xyref_list: List of XYRef holders to combine
    :return: An XYRef holding references to the combined X and y, a
        reference to the node clone, and ``Xyref_list`` as its previous holders
    """
    # Blocking gets: the AndFunc needs the materialized inputs.
    xy_list = [
        Xy(ray.get(xyref.get_Xref()), ray.get(xyref.get_yref()))
        for xyref in Xyref_list
    ]

    cloned_node = node.clone()
    node_ptr = ray.put(cloned_node)

    res_Xy = cloned_node.get_and_func().eval(xy_list)
    res_Xref = ray.put(res_Xy.get_x())
    res_yref = ray.put(res_Xy.get_y())
    return XYRef(res_Xref, res_yref, node_ptr, Xyref_list)
92102

93103

94104
def execute_and_node_inner(node: AndNode, Xyref_ptrs):
    """
    Resolve the incoming XYRef pointers and submit the remote evaluation of
    an AndNode over them.

    :param node: The AndNode to evaluate
    :param Xyref_ptrs: Iterable of object references, each resolving to an XYRef
    :return: A single-element list with the object reference returned by
        ``and_node_eval.remote``
    """
    # Blocking gets: the XYRef holders themselves are needed to build the call.
    Xyref_list = [ray.get(Xyref_ptr) for Xyref_ptr in Xyref_ptrs]
    return [and_node_eval.remote(node, Xyref_list)]
106115

codeflare_pipelines.egg-info/SOURCES.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
README.md
12
setup.py
23
codeflare/__init__.py
34
codeflare/pipelines/Datamodel.py

0 commit comments

Comments
 (0)