88from enum import Enum
99
1010from queue import SimpleQueue
11+ import pandas as pd
1112
1213
1314class ExecutionType (Enum ):
@@ -219,8 +220,15 @@ def split(cross_validator: BaseCrossValidator, xy_ref):
219220 xy_test_refs = []
220221
221222 for train_index , test_index in cross_validator .split (x , y ):
222- x_train , x_test = x [train_index ], x [test_index ]
223- y_train , y_test = y [train_index ], y [test_index ]
223+ if isinstance (x , pd .DataFrame ) or isinstance (x , pd .Series ):
224+ x_train , x_test = x .iloc [train_index ], x .iloc [test_index ]
225+ else :
226+ x_train , x_test = x [train_index ], x [test_index ]
227+
228+ if isinstance (y , pd .DataFrame ) or isinstance (y , pd .Series ):
229+ y_train , y_test = y .iloc [train_index ], y .iloc [test_index ]
230+ else :
231+ y_train , y_test = y [train_index ], y [test_index ]
224232
225233 x_train_ref = ray .put (x_train )
226234 y_train_ref = ray .put (y_train )
@@ -236,64 +244,22 @@ def split(cross_validator: BaseCrossValidator, xy_ref):
236244
237245
def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput):
    """
    Cross validate a pipeline by delegating the work to grid_search_cv.

    The pipeline must report a single estimator via has_single_estimator();
    otherwise the request is rejected and the caller is pointed to
    grid_search_cv.

    :param cross_validator: Cross validator, passed through unchanged to grid_search_cv
    :param pipeline: Pipeline to cross validate
    :param pipeline_input: Pipeline input, passed through unchanged to grid_search_cv
    :return: The first value of the mapping returned by grid_search_cv, or None if that mapping is empty
    :raises pe.PipelineException: If pipeline.has_single_estimator() is falsy
    """
    if not pipeline.has_single_estimator():
        raise pe.PipelineException("Cross validation can only be done on pipelines with single estimator, "
                                   "use grid_search_cv instead")

    scores_by_pipeline = grid_search_cv(cross_validator, pipeline, pipeline_input)
    # Only one entry is expected here, so the first value (if any) is the result
    return next(iter(scores_by_pipeline.values()), None)
294260
295261
296- def grid_search (cross_validator : BaseCrossValidator , pipeline : dm .Pipeline , pipeline_input : dm .PipelineInput ):
262+ def grid_search_cv (cross_validator : BaseCrossValidator , pipeline : dm .Pipeline , pipeline_input : dm .PipelineInput ):
297263 pipeline_input_train = dm .PipelineInput ()
298264
299265 pipeline_input_test = []
@@ -303,18 +269,24 @@ def grid_search(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipe
303269 pipeline_input_test .append (dm .PipelineInput ())
304270
305271 in_args = pipeline_input .get_in_args ()
272+ # Keep a map from the pointer of train to test
273+ train_test_mapper = {}
274+
306275 for node , xyref_ptrs in in_args .items ():
307276 # NOTE: The assumption is that this node has only one input!
308277 xyref_ptr = xyref_ptrs [0 ]
309278 if len (xyref_ptrs ) > 1 :
310- raise pe .PipelineException ("Input to grid search is multiple objects, re-run with only single object" )
279+ raise pe .PipelineException ("Grid search supports single object input only, multiple provided, number is " + str ( len ( xyref_ptrs )) )
311280
312281 xy_train_refs_ptr , xy_test_refs_ptr = split .remote (cross_validator , xyref_ptr )
313282 xy_train_refs = ray .get (xy_train_refs_ptr )
314283 xy_test_refs = ray .get (xy_test_refs_ptr )
315284
316- for xy_train_ref in xy_train_refs :
285+ for i in range (len (xy_train_refs )):
286+ xy_train_ref = xy_train_refs [i ]
287+ xy_test_ref = xy_test_refs [i ]
317288 pipeline_input_train .add_xyref_arg (node , xy_train_ref )
289+ train_test_mapper [xy_train_ref ] = xy_test_ref
318290
319291 # for testing, add only to the specific input
320292 for i in range (k ):
@@ -324,9 +296,42 @@ def grid_search(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipe
324296 # because of the underlying pipeline graph and multiple input objects
325297 pipeline_output_train = execute_pipeline (pipeline , ExecutionType .FIT , pipeline_input_train )
326298
327- # For grid search, we will have multiple output nodes that need to be iterated on and select the pipeline
328- # that is "best"
299+ # For grid search, we will have multiple output nodes that need to be iterated on
300+ selected_pipeline_test_outputs = {}
329301 out_nodes = pipeline .get_output_nodes ()
302+ for out_node in out_nodes :
303+ out_node_xyrefs = pipeline_output_train .get_xyrefs (out_node )
304+ for out_node_xyref in out_node_xyrefs :
305+ selected_pipeline = select_pipeline (pipeline_output_train , out_node_xyref )
306+ selected_pipeline_input = get_pipeline_input (pipeline , pipeline_output_train , out_node_xyref )
307+ selected_pipeline_inargs = selected_pipeline_input .get_in_args ()
308+ test_pipeline_input = dm .PipelineInput ()
309+ for node , train_xyref_ptr in selected_pipeline_inargs .items ():
310+ # xyrefs is a singleton by construction
311+ train_xyrefs = ray .get (train_xyref_ptr )
312+ test_xyref = train_test_mapper [train_xyrefs [0 ]]
313+ test_pipeline_input .add_xyref_arg (node , test_xyref )
314+ selected_pipeline_test_output = execute_pipeline (selected_pipeline , ExecutionType .SCORE , test_pipeline_input )
315+ if selected_pipeline not in selected_pipeline_test_outputs .keys ():
316+ selected_pipeline_test_outputs [selected_pipeline ] = []
317+ selected_pipeline_test_outputs [selected_pipeline ].append (selected_pipeline_test_output )
318+
319+ # now, test outputs can be materialized
320+ result_scores = {}
321+ for selected_pipeline , selected_pipeline_test_output_list in selected_pipeline_test_outputs .items ():
322+ output_nodes = selected_pipeline .get_output_nodes ()
323+ # by design, output_nodes will only have one node
324+ output_node = output_nodes [0 ]
325+ for selected_pipeline_test_output in selected_pipeline_test_output_list :
326+ pipeline_out_xyrefs = selected_pipeline_test_output .get_xyrefs (output_node )
327+ # again, only single xyref to be gotten out
328+ pipeline_out_xyref = pipeline_out_xyrefs [0 ]
329+ out_x = ray .get (pipeline_out_xyref .get_Xref ())
330+ if selected_pipeline not in result_scores .keys ():
331+ result_scores [selected_pipeline ] = []
332+ result_scores [selected_pipeline ].append (out_x )
333+
334+ return result_scores
330335
331336
332337def save (pipeline_output : dm .PipelineOutput , xy_ref : dm .XYRef , filehandle ):
0 commit comments