11import ray
22
3- from com .ibm .research .ray .graph .Datamodel import OrNode
4- from com .ibm .research .ray .graph .Datamodel import AndNode
5- from com .ibm .research .ray .graph .Datamodel import Edge
6- from com .ibm .research .ray .graph .Datamodel import Pipeline
7- from com .ibm .research .ray .graph .Datamodel import XYRef
3+ from codeflare .pipelines .Datamodel import OrNode
4+ from codeflare .pipelines .Datamodel import AndNode
5+ from codeflare .pipelines .Datamodel import Edge
6+ from codeflare .pipelines .Datamodel import Pipeline
7+ from codeflare .pipelines .Datamodel import XYRef
8+ from codeflare .pipelines .Datamodel import Xy
89
910import sklearn .base as base
1011from enum import Enum
@@ -30,92 +31,77 @@ def execute_or_node_inner(node: OrNode, train_mode: ExecutionType, Xy: XYRef):
3031 cloned_estimator .fit (X , y )
3132 # TODO: For now, make yref passthrough - this has to be fixed more comprehensively
3233 res_Xref = ray .put (cloned_estimator .predict (X ))
33- result = [ XYRef (res_Xref , Xy .get_yref ())]
34+ result = XYRef (res_Xref , Xy .get_yref ())
3435 return result
3536 else :
3637 # No need to clone as it is a transform pass through on the fitted estimator
3738 res_Xref = ray .put (estimator .fit_transform (X , y ))
38- result = [ XYRef (res_Xref , Xy .get_yref ())]
39+ result = XYRef (res_Xref , Xy .get_yref ())
3940 return result
4041 elif train_mode == ExecutionType .SCORE :
4142 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
4243 cloned_estimator = base .clone (estimator )
4344 cloned_estimator .fit (X , y )
4445 res_Xref = ray .put (cloned_estimator .score (X , y ))
45- result = [ XYRef (res_Xref , Xy .get_yref ())]
46+ result = XYRef (res_Xref , Xy .get_yref ())
4647 return result
4748 else :
4849 # No need to clone as it is a transform pass through on the fitted estimator
4950 res_Xref = ray .put (estimator .fit_transform (X , y ))
50- result = [ XYRef (res_Xref , Xy .get_yref ())]
51+ result = XYRef (res_Xref , Xy .get_yref ())
5152 return result
5253 elif train_mode == ExecutionType .PREDICT :
5354 # Test mode does not clone as it is a simple predict or transform
5455 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
5556 res_Xref = estimator .predict (X )
56- result = [ XYRef (res_Xref , Xy .get_yref ())]
57+ result = XYRef (res_Xref , Xy .get_yref ())
5758 return result
5859 else :
5960 res_Xref = estimator .transform (X )
60- result = [ XYRef (res_Xref , Xy .get_yref ())]
61+ result = XYRef (res_Xref , Xy .get_yref ())
6162 return result
6263
6364
def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType):
    """Run an OR node once for each input arriving on each incoming edge.

    For every edge in ``pre_edges``, the references stored under that edge in
    ``edge_args`` are resolved with ``ray.get`` and handed, one at a time, to
    ``execute_or_node_inner.remote``. The handles returned by those remote
    calls are then appended to ``edge_args`` under every edge in
    ``post_edges``.

    :param node: node forwarded to ``execute_or_node_inner``
    :param pre_edges: incoming edges; each must already be a key of ``edge_args``
    :param edge_args: dict from edge to list of references, updated in place
    :param post_edges: outgoing edges that receive the newly created handles
    :param mode: execution type forwarded to ``execute_or_node_inner``
    """
    for incoming in pre_edges:
        # Each stored pointer is dereferenced here, on the caller side, before
        # the remote task for it is launched.
        pending = [
            execute_or_node_inner.remote(node, mode, ray.get(ptr))
            for ptr in edge_args[incoming]
        ]

        for outgoing in post_edges:
            edge_args.setdefault(outgoing, []).extend(pending)
9578
9679
@ray.remote
def and_node_eval(and_func, Xyref_list):
    """Evaluate an AND function over the materialised inputs of one combination.

    Each element of ``Xyref_list`` is asked for its X and y references, both
    are fetched with ``ray.get`` and wrapped in an ``Xy``. The resulting list
    is passed to ``and_func.eval``; the X and y of its result are stored with
    ``ray.put`` and returned as a new ``XYRef``.

    :param and_func: object exposing ``eval(list_of_Xy)``
    :param Xyref_list: iterable of objects exposing ``get_Xref()`` / ``get_yref()``
    :return: ``XYRef`` holding references to the combined X and y
    """
    materialised = [
        Xy(ray.get(ref.get_Xref()), ray.get(ref.get_yref()))
        for ref in Xyref_list
    ]

    combined = and_func.eval(materialised)
    out_Xref = ray.put(combined.get_x())
    out_yref = ray.put(combined.get_y())
    return XYRef(out_Xref, out_yref)
10392
10493
def execute_and_node_inner(node: AndNode, Xyref_ptrs):
    """Launch the AND function of ``node`` on one combination of inputs.

    Every pointer in ``Xyref_ptrs`` is resolved with ``ray.get``; the resolved
    objects are passed together with the node's AND function to
    ``and_node_eval.remote``.

    :param node: AND node whose ``get_and_func()`` supplies the function to run
    :param Xyref_ptrs: iterable of references, one per input of the combination
    :return: single-element list holding the handle returned by the remote call
    """
    and_func = node.get_and_func()

    # Dereference the pointers here so the remote task receives the objects
    resolved = [ray.get(ptr) for ptr in Xyref_ptrs]

    return [and_node_eval.remote(and_func, resolved)]
120106
121107
@@ -129,24 +115,36 @@ def execute_and_node(node, pre_edges, edge_args, post_edges):
129115 cross_product = itertools .product (* edge_args_lists )
130116
131117 for element in cross_product :
132- exec_xyrefs = execute_and_node_inner (node , element )
118+ exec_xyref_ptrs = execute_and_node_inner (node , element )
133119 for post_edge in post_edges :
134120 if post_edge not in edge_args .keys ():
135121 edge_args [post_edge ] = []
136- edge_args [post_edge ].extend (exec_xyrefs )
122+ edge_args [post_edge ].extend (exec_xyref_ptrs )
137123
138124
def execute_pipeline(pipeline: Pipeline, mode: ExecutionType, in_args: dict):
    """Execute every node of ``pipeline`` level by level and collect the outputs.

    The inputs in ``in_args`` are first attached to the incoming edges of the
    nodes they belong to. The nodes are then visited in the order given by
    ``pipeline.get_nodes_by_level()``; each one is dispatched to
    ``execute_and_node`` or ``execute_or_node`` depending on its AND flag, and
    those calls add their results to the per-edge bookkeeping. Finally, the
    entries recorded on the edge ``Edge(node, None)`` of every node in the
    last level are returned.

    :param pipeline: pipeline providing the nodes, their levels and their edges
    :param mode: execution type forwarded to ``execute_or_node``
    :param in_args: dict from node to the list of inputs for that node; the
        elements are later resolved with ``ray.get`` by the node executors
    :return: dict from each last-level node to the list recorded on its
        outgoing ``Edge(node, None)``
    :raises KeyError: if a last-level node has no entry recorded on that edge
    """
    nodes_by_level = pipeline.get_nodes_by_level()

    # track args per edge
    edge_args = {}
    for node, node_in_args in in_args.items():
        for pre_edge in pipeline.get_pre_edges(node):
            edge_args[pre_edge] = node_in_args

    for nodes in nodes_by_level:
        for node in nodes:
            pre_edges = pipeline.get_pre_edges(node)
            post_edges = pipeline.get_post_edges(node)
            # The flag is read once: a node is handled either as an AND node
            # or as an OR node, never as neither.
            if node.get_and_flag():
                execute_and_node(node, pre_edges, edge_args, post_edges)
            else:
                execute_or_node(node, pre_edges, edge_args, post_edges, mode)

    last_level_nodes = nodes_by_level[pipeline.compute_max_level()]
    return {
        last_level_node: edge_args[Edge(last_level_node, None)]
        for last_level_node in last_level_nodes
    }
0 commit comments