@@ -23,11 +23,11 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
2323 # Blocking operation -- not avoidable
2424 X = ray .get (xy_ref .get_Xref ())
2525 y = ray .get (xy_ref .get_yref ())
26+ prev_node_ptr = ray .put (node )
2627
2728 # TODO: Can optimize the node pointers without replicating them
2829 if mode == ExecutionType .FIT :
2930 cloned_node = node .clone ()
30- prev_node_ptr = ray .put (node )
3131
3232 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
3333 # Always clone before fit, else fit is invalid
@@ -49,22 +49,22 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
4949 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
5050 estimator = node .get_estimator ()
5151 res_Xref = ray .put (estimator .score (X , y ))
52- result = dm .XYRef (res_Xref , xy_ref .get_yref ())
52+ result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , prev_node_ptr , [ xy_ref ] )
5353 return result
5454 else :
5555 res_Xref = ray .put (estimator .transform (X ))
56- result = dm .XYRef (res_Xref , xy_ref .get_yref ())
56+ result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , prev_node_ptr , [ xy_ref ] )
5757
5858 return result
5959 elif mode == ExecutionType .PREDICT :
6060 # Test mode does not clone as it is a simple predict or transform
6161 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
6262 res_Xref = ray .put (estimator .predict (X ))
63- result = dm .XYRef (res_Xref , xy_ref .get_yref ())
63+ result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , prev_node_ptr , [ xy_ref ] )
6464 return result
6565 else :
6666 res_Xref = ray .put (estimator .transform (X ))
67- result = dm .XYRef (res_Xref , xy_ref .get_yref ())
67+ result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , prev_node_ptr , [ xy_ref ] )
6868 return result
6969
7070
@@ -84,38 +84,88 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
8484
8585
@ray.remote
def execute_and_node_remote(node: dm.AndNode, mode: ExecutionType, Xyref_list):
    """Run an AND node over all of its inputs at once, as a Ray task.

    Every entry of ``Xyref_list`` is resolved into a ``dm.Xy`` and the whole
    list is handed to the node's estimator in a single call.

    :param node: AND node whose estimator is executed
    :param mode: FIT, SCORE or PREDICT; selects which estimator method runs
    :param Xyref_list: XYRef objects holding the input X/y object references
    :return: a ``dm.XYRef`` carrying the output X/y references, a reference to
        the incoming node, a reference to the node that produced the output
        (the clone in FIT mode, the incoming node otherwise) and the input
        list; ``None`` when ``mode`` matches none of the three handled modes
    """
    def _store(res_xy):
        # Tuple elements are evaluated left to right: X is stored before y.
        return ray.put(res_xy.get_x()), ray.put(res_xy.get_y())

    input_node_ptr = ray.put(node)

    # Blocking fetches: each input has to be materialised before the call.
    xy_list = [
        dm.Xy(ray.get(in_ref.get_Xref()), ray.get(in_ref.get_yref()))
        for in_ref in Xyref_list
    ]

    estimator = node.get_estimator()

    # TODO: Can optimize the node pointers without replicating them
    if mode == ExecutionType.FIT:
        # Always clone before fit, else fit is invalid
        fit_node = node.clone()
        is_predictor = base.is_classifier(estimator) or base.is_regressor(estimator)
        fit_estimator = fit_node.get_estimator()

        if is_predictor:
            fit_estimator.fit(xy_list)
            # The clone is stored after fit() and before predict().
            fit_node_ptr = ray.put(fit_node)
            out_xref, out_yref = _store(fit_estimator.predict(xy_list))
        else:
            out_xref, out_yref = _store(fit_estimator.fit_transform(xy_list))
            fit_node_ptr = ray.put(fit_node)

        return dm.XYRef(out_xref, out_yref, input_node_ptr, fit_node_ptr, Xyref_list)

    if mode == ExecutionType.SCORE:
        if base.is_classifier(estimator) or base.is_regressor(estimator):
            # The estimator is fetched from the node again before scoring.
            scorer = node.get_estimator()
            scored = scorer.score(xy_list)
        else:
            scored = estimator.transform(xy_list)

        out_xref, out_yref = _store(scored)
        return dm.XYRef(out_xref, out_yref, input_node_ptr, input_node_ptr, Xyref_list)

    if mode == ExecutionType.PREDICT:
        # Test mode does not clone as it is a simple predict or transform
        if base.is_classifier(estimator) or base.is_regressor(estimator):
            predicted = estimator.predict(xy_list)
        else:
            predicted = estimator.transform(xy_list)

        out_xref, out_yref = _store(predicted)
        return dm.XYRef(out_xref, out_yref, input_node_ptr, input_node_ptr, Xyref_list)

    # Any other mode falls through without producing a result.
    return None
153+
154+
def execute_and_node_inner(node: dm.AndNode, mode: ExecutionType, Xyref_ptrs):
    """Resolve one combination of input pointers and launch the AND node on it.

    :param node: AND node to execute
    :param mode: execution mode forwarded unchanged to the remote task
    :param Xyref_ptrs: iterable of object references, each resolving to an XYRef
    :return: single-element list holding the reference returned by the remote call
    """
    # Blocking: the pointers are dereferenced here so the remote task
    # receives the XYRef objects themselves.
    resolved_xyrefs = [ray.get(ptr) for ptr in Xyref_ptrs]

    return [execute_and_node_remote.remote(node, mode, resolved_xyrefs)]
116166
117167
118- def execute_and_node (node , pre_edges , edge_args , post_edges ):
168+ def execute_and_node (node , pre_edges , edge_args , post_edges , mode : ExecutionType ):
119169 edge_args_lists = list ()
120170 for pre_edge in pre_edges :
121171 edge_args_lists .append (edge_args [pre_edge ])
@@ -125,7 +175,7 @@ def execute_and_node(node, pre_edges, edge_args, post_edges):
125175 cross_product = itertools .product (* edge_args_lists )
126176
127177 for element in cross_product :
128- exec_xyref_ptrs = execute_and_node_inner (node , element )
178+ exec_xyref_ptrs = execute_and_node_inner (node , mode , element )
129179 for post_edge in post_edges :
130180 if post_edge not in edge_args .keys ():
131181 edge_args [post_edge ] = []
@@ -151,7 +201,7 @@ def execute_pipeline(pipeline: dm.Pipeline, mode: ExecutionType, pipeline_input:
151201 if node .get_node_input_type () == dm .NodeInputType .OR :
152202 execute_or_node (node , pre_edges , edge_args , post_edges , mode )
153203 elif node .get_node_input_type () == dm .NodeInputType .AND :
154- execute_and_node (node , pre_edges , edge_args , post_edges )
204+ execute_and_node (node , pre_edges , edge_args , post_edges , mode )
155205
156206 out_args = {}
157207 terminal_nodes = pipeline .get_output_nodes ()
@@ -249,7 +299,7 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
249299 raise pe .PipelineException ("Cross validation can only be done on pipelines with single estimator, "
250300 "use grid_search_cv instead" )
251301
252- result_grid_search_cv = grid_search_cv (cross_validator , pipeline , pipeline_input )
302+ result_grid_search_cv = _grid_search_cv (cross_validator , pipeline , pipeline_input )
253303 # only one output here
254304 result_scores = None
255305 for scores in result_grid_search_cv .values ():
@@ -259,7 +309,13 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
259309 return result_scores
260310
261311
262- def grid_search_cv (cross_validator : BaseCrossValidator , pipeline : dm .Pipeline , pipeline_input : dm .PipelineInput ):
def grid_search_cv(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput,
                   pipeline_params: dm.PipelineParam = None):
    """Run grid-search cross validation, optionally over a parameterized pipeline.

    :param cross_validator: cross validator handed through to ``_grid_search_cv``
    :param pipeline: pipeline to evaluate
    :param pipeline_input: input bound to ``pipeline``
    :param pipeline_params: parameters used to expand ``pipeline``; when omitted
        (``None``) the pipeline and its input are evaluated as given, which keeps
        the three-argument call form working
    :return: whatever ``_grid_search_cv`` returns for the evaluated pipeline
    """
    if pipeline_params is None:
        # No parameter grid supplied: evaluate the pipeline unchanged.
        return _grid_search_cv(cross_validator, pipeline, pipeline_input)

    parameterized_pipeline = pipeline.get_parameterized_pipeline(pipeline_params)
    # The input is re-derived from the original and the parameterized pipeline.
    parameterized_pipeline_input = pipeline_input.get_parameterized_input(pipeline, parameterized_pipeline)
    return _grid_search_cv(cross_validator, parameterized_pipeline, parameterized_pipeline_input)
316+
317+
318+ def _grid_search_cv (cross_validator : BaseCrossValidator , pipeline : dm .Pipeline , pipeline_input : dm .PipelineInput ):
263319 pipeline_input_train = dm .PipelineInput ()
264320
265321 pipeline_input_test = []
0 commit comments