44import codeflare .pipelines .Exceptions as pe
55
66import sklearn .base as base
7+ from sklearn .model_selection import BaseCrossValidator
78from enum import Enum
89
910from queue import SimpleQueue
@@ -16,14 +17,14 @@ class ExecutionType(Enum):
1617
1718
1819@ray .remote
19- def execute_or_node_remote (node : dm .EstimatorNode , train_mode : ExecutionType , xy_ref : dm .XYRef ):
20+ def execute_or_node_remote (node : dm .EstimatorNode , mode : ExecutionType , xy_ref : dm .XYRef ):
2021 estimator = node .get_estimator ()
2122 # Blocking operation -- not avoidable
2223 X = ray .get (xy_ref .get_Xref ())
2324 y = ray .get (xy_ref .get_yref ())
2425
2526 # TODO: Can optimize the node pointers without replicating them
26- if train_mode == ExecutionType .FIT :
27+ if mode == ExecutionType .FIT :
2728 cloned_node = node .clone ()
2829 prev_node_ptr = ray .put (node )
2930
@@ -43,24 +44,17 @@ def execute_or_node_remote(node: dm.EstimatorNode, train_mode: ExecutionType, xy
4344 curr_node_ptr = ray .put (cloned_node )
4445 result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , curr_node_ptr , [xy_ref ])
4546 return result
46- elif train_mode == ExecutionType .SCORE :
47- cloned_node = node .clone ()
48- prev_node_ptr = ray .put (node )
49-
47+ elif mode == ExecutionType .SCORE :
5048 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
51- cloned_estimator = cloned_node .get_estimator ()
52- cloned_estimator .fit (X , y )
53- curr_node_ptr = ray .put (cloned_node )
54- res_Xref = ray .put (cloned_estimator .score (X , y ))
55- result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , curr_node_ptr , [xy_ref ])
49+ estimator = node .get_estimator ()
50+ res_Xref = ray .put (estimator .score (X , y ))
51+ result = dm .XYRef (res_Xref , xy_ref .get_yref ())
5652 return result
5753 else :
58- cloned_estimator = cloned_node .get_estimator ()
59- res_Xref = ray .put (cloned_estimator .fit_transform (X , y ))
60- curr_node_ptr = ray .put (cloned_node )
61- result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , curr_node_ptr , [xy_ref ])
54+ res_Xref = ray .put (estimator .transform (X ))
55+ result = dm .XYRef (res_Xref , xy_ref .get_yref ())
6256 return result
63- elif train_mode == ExecutionType .PREDICT :
57+ elif mode == ExecutionType .PREDICT :
6458 # Test mode does not clone as it is a simple predict or transform
6559 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
6660 res_Xref = estimator .predict (X )
@@ -136,11 +130,12 @@ def execute_and_node(node, pre_edges, edge_args, post_edges):
136130 edge_args [post_edge ].extend (exec_xyref_ptrs )
137131
138132
139- def execute_pipeline (pipeline : dm .Pipeline , mode : ExecutionType , in_args : dict ) :
133+ def execute_pipeline (pipeline : dm .Pipeline , mode : ExecutionType , pipeline_input : dm . PipelineInput ) -> dm . PipelineOutput :
140134 nodes_by_level = pipeline .get_nodes_by_level ()
141135
142136 # track args per edge
143137 edge_args = {}
138+ in_args = pipeline_input .get_in_args ()
144139 for node , node_in_args in in_args .items ():
145140 pre_edges = pipeline .get_pre_edges (node )
146141 for pre_edge in pre_edges :
@@ -161,10 +156,10 @@ def execute_pipeline(pipeline: dm.Pipeline, mode: ExecutionType, in_args: dict):
161156 edge = dm .Edge (last_level_node , None )
162157 out_args [last_level_node ] = edge_args [edge ]
163158
164- return out_args
159+ return dm . PipelineOutput ( out_args , edge_args )
165160
166161
167- def select_pipeline (chosen_xyref : dm .XYRef ):
162+ def select_pipeline (pipeline_output : dm . PipelineOutput , chosen_xyref : dm .XYRef ):
168163 pipeline = dm .Pipeline ()
169164 xyref_queue = SimpleQueue ()
170165
@@ -185,3 +180,86 @@ def select_pipeline(chosen_xyref: dm.XYRef):
185180 xyref_queue .put (prev_xyref )
186181
187182 return pipeline
183+
184+
def _take_rows(data, index):
    """Return the rows of ``data`` selected by the positional ``index`` array.

    Objects exposing ``.iloc`` (pandas DataFrame/Series) are indexed through it,
    because plain ``data[index]`` on a DataFrame selects columns, not rows.
    ``None`` is passed through unchanged so a missing target can be split too.
    """
    if data is None:
        return None
    if hasattr(data, "iloc"):
        return data.iloc[index]
    return data[index]


@ray.remote(num_returns=2)
def split(cross_validator: BaseCrossValidator, xy_ref):
    """Split the data behind ``xy_ref`` into per-fold train and test XYRefs.

    Runs as a Ray remote task with two return values.

    :param cross_validator: Object whose ``split(x, y)`` yields
        ``(train_index, test_index)`` pairs.
    :param xy_ref: Object exposing ``get_Xref()`` and ``get_yref()``, each
        returning a Ray object reference to the data to be split.
    :return: Two lists of equal length, ``(xy_train_refs, xy_test_refs)``,
        holding one ``dm.XYRef`` per fold in the order produced by the
        cross validator.
    """
    # Blocking fetches: the full X and y are needed to compute the folds.
    x = ray.get(xy_ref.get_Xref())
    y = ray.get(xy_ref.get_yref())

    xy_train_refs = []
    xy_test_refs = []

    for train_index, test_index in cross_validator.split(x, y):
        # Each slice is stored in the object store and only references are kept.
        x_train_ref = ray.put(_take_rows(x, train_index))
        y_train_ref = ray.put(_take_rows(y, train_index))
        xy_train_refs.append(dm.XYRef(x_train_ref, y_train_ref))

        x_test_ref = ray.put(_take_rows(x, test_index))
        y_test_ref = ray.put(_take_rows(y, test_index))
        xy_test_refs.append(dm.XYRef(x_test_ref, y_test_ref))

    return xy_train_refs, xy_test_refs
208+
209+
def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput):
    """Cross validate ``pipeline`` on ``pipeline_input`` and return one score per fold.

    The input data of every input node is split into folds, the pipeline is
    executed in FIT mode on all training folds, and for each fold the pipeline
    selected from the corresponding fit output is executed in SCORE mode on
    the matching test fold.

    :param cross_validator: Cross validator supplying ``get_n_splits()`` and
        the ``split`` used by the remote ``split`` task.
    :param pipeline: Pipeline to evaluate; it must have exactly one terminal node.
    :param pipeline_input: Input whose ``get_in_args()`` maps each input node
        to a list of XYRefs.
    :return: List of the values fetched from the X reference of each fold's
        SCORE output, in fold order.
    :raises pe.PipelineException: If the pipeline does not have exactly one
        terminal node, if the number of test folds differs from
        ``get_n_splits()``, or if the number of FIT outputs differs from it.
    """
    # Validate the pipeline shape first so that no work is submitted for a
    # pipeline that cannot be cross validated.
    out_nodes = pipeline.get_terminal_nodes()
    if len(out_nodes) != 1:
        raise pe.PipelineException("Cannot cross validate as output is not a single node")
    out_node = out_nodes[0]

    k = cross_validator.get_n_splits()

    pipeline_input_train = dm.PipelineInput()
    # One separate pipeline input per fold for testing
    pipeline_input_test = [dm.PipelineInput() for _ in range(k)]

    for node, xyref_ptrs in pipeline_input.get_in_args().items():
        # NOTE(review): only the first XYRef of each input node is split; the
        # original code assumed a single input per node -- confirm with callers.
        xyref_ptr = xyref_ptrs[0]
        xy_train_refs_ptr, xy_test_refs_ptr = split.remote(cross_validator, xyref_ptr)
        xy_train_refs = ray.get(xy_train_refs_ptr)
        xy_test_refs = ray.get(xy_test_refs_ptr)

        if len(xy_test_refs) != k:
            raise pe.PipelineException("Number of test folds is not equal to the folds from cross validator")

        # All training folds go into the single training input
        for xy_train_ref in xy_train_refs:
            pipeline_input_train.add_xyref_arg(node, xy_train_ref)

        # for testing, add only to the specific input
        for fold_input, xy_test_ref in zip(pipeline_input_test, xy_test_refs):
            fold_input.add_xyref_arg(node, xy_test_ref)

    # Ready for execution now that data has been prepared! All training folds
    # are handed to a single execute_pipeline call.
    pipeline_output_train = execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input_train)

    out_xyref_ptrs = pipeline_output_train.get_xyrefs(out_node)
    if len(out_xyref_ptrs) != k:
        raise pe.PipelineException("Number of outputs from pipeline fit is not equal to the folds from cross validator")

    # Below, jobs get submitted and then we can collect the results in the next loop
    pipeline_score_outputs = []
    for out_xyref_ptr, fold_input in zip(out_xyref_ptrs, pipeline_input_test):
        selected_pipeline = select_pipeline(pipeline_output_train, out_xyref_ptr)
        selected_pipeline_output = execute_pipeline(selected_pipeline, ExecutionType.SCORE, fold_input)
        pipeline_score_outputs.append(selected_pipeline_output)

    result_scores = []
    for pipeline_score_output in pipeline_score_outputs:
        # again, only single xyref to be gotten out
        pipeline_out_xyref = pipeline_score_output.get_xyrefs(out_node)[0]
        result_scores.append(ray.get(pipeline_out_xyref.get_Xref()))

    return result_scores
0 commit comments