11import ray
22
3- from com .ibm .research .ray .graph .Datamodel import OrNode
4- from com .ibm .research .ray .graph .Datamodel import AndNode
5- from com .ibm .research .ray .graph .Datamodel import Edge
6- from com .ibm .research .ray .graph .Datamodel import Pipeline
7- from com .ibm .research .ray .graph .Datamodel import XYRef
3+ from codeflare .pipelines .Datamodel import OrNode
4+ from codeflare .pipelines .Datamodel import AndNode
5+ from codeflare .pipelines .Datamodel import Edge
6+ from codeflare .pipelines .Datamodel import Pipeline
7+ from codeflare .pipelines .Datamodel import XYRef
8+ from codeflare .pipelines .Datamodel import Xy
89
910import sklearn .base as base
1011from enum import Enum
1112
1213
class ExecutionType(Enum):
    """Mode in which a pipeline and its nodes are executed.

    ``execute_or_node_inner`` branches on these members: FIT, PREDICT or
    SCORE.
    """

    # Plain integers on purpose: a trailing comma after the literal would
    # turn the member value into a one-element tuple such as ``(0,)``.
    FIT = 0
    PREDICT = 1
    SCORE = 2
1618
1719
1820@ray .remote
@@ -22,87 +24,84 @@ def execute_or_node_inner(node: OrNode, train_mode: ExecutionType, Xy: XYRef):
2224 X = ray .get (Xy .get_Xref ())
2325 y = ray .get (Xy .get_yref ())
2426
25- if train_mode == ExecutionType .TRAIN :
27+ if train_mode == ExecutionType .FIT :
2628 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
2729 # Always clone before fit, else fit is invalid
2830 cloned_estimator = base .clone (estimator )
2931 cloned_estimator .fit (X , y )
3032 # TODO: For now, make yref passthrough - this has to be fixed more comprehensively
3133 res_Xref = ray .put (cloned_estimator .predict (X ))
32- result = [ XYRef (res_Xref , Xy .get_yref ())]
34+ result = XYRef (res_Xref , Xy .get_yref ())
3335 return result
3436 else :
3537 # No need to clone as it is a transform pass through on the fitted estimator
36- res_Xref = ray .put (estimator .fit_transform (X ))
37- result = [ XYRef (res_Xref , Xy .get_yref ())]
38+ res_Xref = ray .put (estimator .fit_transform (X , y ))
39+ result = XYRef (res_Xref , Xy .get_yref ())
3840 return result
39- elif train_mode == ExecutionType .TEST :
41+ elif train_mode == ExecutionType .SCORE :
42+ if base .is_classifier (estimator ) or base .is_regressor (estimator ):
43+ cloned_estimator = base .clone (estimator )
44+ cloned_estimator .fit (X , y )
45+ res_Xref = ray .put (cloned_estimator .score (X , y ))
46+ result = XYRef (res_Xref , Xy .get_yref ())
47+ return result
48+ else :
49+ # No need to clone as it is a transform pass through on the fitted estimator
50+ res_Xref = ray .put (estimator .fit_transform (X , y ))
51+ result = XYRef (res_Xref , Xy .get_yref ())
52+ return result
53+ elif train_mode == ExecutionType .PREDICT :
4054 # Test mode does not clone as it is a simple predict or transform
4155 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
4256 res_Xref = estimator .predict (X )
43- result = [ XYRef (res_Xref , Xy .get_yref ())]
57+ result = XYRef (res_Xref , Xy .get_yref ())
4458 return result
4559 else :
4660 res_Xref = estimator .transform (X )
47- result = [ XYRef (res_Xref , Xy .get_yref ())]
61+ result = XYRef (res_Xref , Xy .get_yref ())
4862 return result
4963
5064
51- ###
52- # in_args is a dict from Node to list of XYRefs
53- ###
54- def execute_pipeline (pipeline : Pipeline , mode : ExecutionType , in_args : dict ):
55- nodes_by_level = pipeline .get_nodes_by_level ()
56-
57- # track args per edge
58- edge_args = {}
59- for node , node_in_args in in_args .items ():
60- pre_edges = pipeline .get_pre_edges (node )
61- for pre_edge in pre_edges :
62- edge_args [pre_edge ] = node_in_args
63-
64- for nodes in nodes_by_level :
65- for node in nodes :
66- pre_edges = pipeline .get_pre_edges (node )
67- post_edges = pipeline .get_post_edges (node )
68- if not node .get_and_flag ():
69- execute_or_node (node , pre_edges , edge_args , post_edges , mode )
70- else :
71- cross_product = execute_and_node (node , pre_edges , edge_args , post_edges )
72- for element in cross_product :
73- print (element )
74-
75- out_args = {}
76- last_level_nodes = nodes_by_level [pipeline .compute_max_level ()]
77- for last_level_node in last_level_nodes :
78- edge = Edge (last_level_node , None )
79- out_args [last_level_node ] = edge_args [edge ]
def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType):
    """Run an OR node on every input queued on its incoming edges.

    For each edge in ``pre_edges``, the entries stored under that edge in
    ``edge_args`` are resolved with ``ray.get`` and submitted one at a time
    to the remote task ``execute_or_node_inner``. The references returned by
    those submissions are appended to the list kept in ``edge_args`` for
    every edge in ``post_edges``.

    :param node: node forwarded unchanged to ``execute_or_node_inner``
    :param pre_edges: incoming edges whose queued inputs are consumed
    :param edge_args: dict from edge to list of references; updated in place
    :param post_edges: outgoing edges that receive the new references
    :param mode: execution mode forwarded to ``execute_or_node_inner``
    :return: None; results are published through ``edge_args``
    """
    for incoming in pre_edges:
        # Submit one remote task per queued input; the returned handles are
        # not waited on here.
        pending = [
            execute_or_node_inner.remote(node, mode, ray.get(ref))
            for ref in edge_args[incoming]
        ]

        # Every outgoing edge sees the same result handles.
        for outgoing in post_edges:
            edge_args.setdefault(outgoing, []).extend(pending)
8278
8379
@ray.remote
def and_node_eval(and_func, Xyref_list):
    """Remote task that applies an AND function to a group of inputs.

    Each element of ``Xyref_list`` is dereferenced (its X reference first,
    then its y reference) into an ``Xy`` holder. The holders are passed as a
    list to ``and_func.eval``, and the X and y of the result are put back
    into the object store.

    :param and_func: object exposing ``eval(list_of_Xy)``; the result must
        expose ``get_x()`` and ``get_y()``
    :param Xyref_list: iterable of objects exposing ``get_Xref()`` and
        ``get_yref()``
    :return: an ``XYRef`` wrapping the references to the combined X and y
    """
    inputs = [
        Xy(ray.get(ref.get_Xref()), ray.get(ref.get_yref()))
        for ref in Xyref_list
    ]

    combined = and_func.eval(inputs)
    return XYRef(ray.put(combined.get_x()), ray.put(combined.get_y()))
9092
9193
def execute_and_node_inner(node: AndNode, Xyref_ptrs):
    """Submit one AND evaluation for a single combination of inputs.

    Every entry of ``Xyref_ptrs`` is resolved with ``ray.get``. The resolved
    objects are handed together with the node's AND function to the remote
    task ``and_node_eval``.

    :param node: AND node providing the function via ``get_and_func()``
    :param Xyref_ptrs: iterable of references, one per incoming edge
    :return: single-element list holding the handle returned by
        ``and_node_eval.remote``
    """
    and_func = node.get_and_func()
    resolved = [ray.get(ptr) for ptr in Xyref_ptrs]
    return [and_node_eval.remote(and_func, resolved)]
107106
108107
@@ -116,24 +115,36 @@ def execute_and_node(node, pre_edges, edge_args, post_edges):
116115 cross_product = itertools .product (* edge_args_lists )
117116
118117 for element in cross_product :
119- exec_xyrefs = execute_and_node_inner (node , element )
118+ exec_xyref_ptrs = execute_and_node_inner (node , element )
120119 for post_edge in post_edges :
121120 if post_edge not in edge_args .keys ():
122121 edge_args [post_edge ] = []
123- edge_args [post_edge ].extend (exec_xyrefs )
122+ edge_args [post_edge ].extend (exec_xyref_ptrs )
124123
125124
126- def execute_or_node (node , pre_edges , edge_args , post_edges , mode : ExecutionType ):
127- for pre_edge in pre_edges :
128- Xyrefs = edge_args [pre_edge ]
129- exec_xyrefs = []
130- for xy_ref in Xyrefs :
131- xy_ref_list = ray .get (xy_ref )
132- for xy_ref in xy_ref_list :
133- inner_result = execute_or_node_inner .remote (node , mode , xy_ref )
134- exec_xyrefs .append (inner_result )
def execute_pipeline(pipeline: Pipeline, mode: ExecutionType, in_args: dict):
    """Execute every node of ``pipeline`` level by level.

    The inputs in ``in_args`` are first attached to each pre-edge of the
    node they belong to. Nodes are then visited in the order given by
    ``pipeline.get_nodes_by_level()``. Each node is dispatched to
    ``execute_or_node`` or ``execute_and_node``, which publish their results
    on the node's post-edges inside the shared ``edge_args`` dict.

    :param pipeline: pipeline providing the level ordering and the
        pre/post edges of each node
    :param mode: execution mode, forwarded to OR nodes only
    :param in_args: dict from node to that node's list of inputs. The
        OR-node path calls ``ray.get`` on each entry, so entries are
        presumably object references to ``XYRef`` (confirm against callers).
        The same list object is stored under every pre-edge of the node.
    :return: dict from each node on the last level to the list accumulated
        on ``Edge(node, None)``
    :raises KeyError: if nothing was published on a last-level node's
        ``Edge(node, None)``
    """
    nodes_by_level = pipeline.get_nodes_by_level()

    # Track args per edge, seeded with the caller-supplied inputs.
    edge_args = {}
    for node, node_in_args in in_args.items():
        for pre_edge in pipeline.get_pre_edges(node):
            edge_args[pre_edge] = node_in_args

    for nodes in nodes_by_level:
        for node in nodes:
            pre_edges = pipeline.get_pre_edges(node)
            post_edges = pipeline.get_post_edges(node)
            # The flag is a two-way choice, so query it once.
            if node.get_and_flag():
                execute_and_node(node, pre_edges, edge_args, post_edges)
            else:
                execute_or_node(node, pre_edges, edge_args, post_edges, mode)

    # Outputs are read from the edge leaving each last-level node towards
    # None.
    last_level_nodes = nodes_by_level[pipeline.compute_max_level()]
    return {
        last_level_node: edge_args[Edge(last_level_node, None)]
        for last_level_node in last_level_nodes
    }
0 commit comments