@@ -18,47 +18,53 @@ class ExecutionType(Enum):
1818
1919
2020@ray .remote
21- def execute_or_node_inner (node : OrNode , train_mode : ExecutionType , Xy : XYRef ):
21+ def execute_or_node_inner (node : OrNode , train_mode : ExecutionType , xy_ref : XYRef ):
2222 estimator = node .get_estimator ()
2323 # Blocking operation -- not avoidable
24- X = ray .get (Xy .get_Xref ())
25- y = ray .get (Xy .get_yref ())
24+ X = ray .get (xy_ref .get_Xref ())
25+ y = ray .get (xy_ref .get_yref ())
2626
2727 if train_mode == ExecutionType .FIT :
28+ cloned_node = node .clone ()
29+ node_ptr = ray .put (cloned_node )
30+
2831 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
2932 # Always clone before fit, else fit is invalid
30- cloned_estimator = base . clone ( estimator )
33+ cloned_estimator = cloned_node . get_estimator ( )
3134 cloned_estimator .fit (X , y )
3235 # TODO: For now, make yref passthrough - this has to be fixed more comprehensively
3336 res_Xref = ray .put (cloned_estimator .predict (X ))
34- result = XYRef (res_Xref , Xy .get_yref ())
37+ result = XYRef (res_Xref , xy_ref .get_yref (), node_ptr , [ xy_ref ] )
3538 return result
3639 else :
37- # No need to clone as it is a transform pass through on the fitted estimator
38- res_Xref = ray .put (estimator .fit_transform (X , y ))
39- result = XYRef (res_Xref , Xy .get_yref ())
40+ cloned_estimator = cloned_node . get_estimator ()
41+ res_Xref = ray .put (cloned_estimator .fit_transform (X , y ))
42+ result = XYRef (res_Xref , xy_ref .get_yref (), node_ptr , [ xy_ref ] )
4043 return result
4144 elif train_mode == ExecutionType .SCORE :
45+ cloned_node = node .clone ()
46+ node_ptr = ray .put (cloned_node )
47+
4248 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
43- cloned_estimator = base . clone ( estimator )
49+ cloned_estimator = cloned_node . get_estimator ( )
4450 cloned_estimator .fit (X , y )
4551 res_Xref = ray .put (cloned_estimator .score (X , y ))
46- result = XYRef (res_Xref , Xy .get_yref ())
52+ result = XYRef (res_Xref , xy_ref .get_yref (), node_ptr , [ xy_ref ] )
4753 return result
4854 else :
49- # No need to clone as it is a transform pass through on the fitted estimator
50- res_Xref = ray .put (estimator .fit_transform (X , y ))
51- result = XYRef (res_Xref , Xy .get_yref ())
55+ cloned_estimator = cloned_node . get_estimator ()
56+ res_Xref = ray .put (cloned_estimator .fit_transform (X , y ))
57+ result = XYRef (res_Xref , xy_ref .get_yref (), node_ptr , [ xy_ref ] )
5258 return result
5359 elif train_mode == ExecutionType .PREDICT :
5460 # Test mode does not clone as it is a simple predict or transform
5561 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
5662 res_Xref = estimator .predict (X )
57- result = XYRef (res_Xref , Xy .get_yref ())
63+ result = XYRef (res_Xref , xy_ref .get_yref ())
5864 return result
5965 else :
6066 res_Xref = estimator .transform (X )
61- result = XYRef (res_Xref , Xy .get_yref ())
67+ result = XYRef (res_Xref , xy_ref .get_yref ())
6268 return result
6369
6470
@@ -78,29 +84,32 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
7884
7985
8086@ray .remote
81- def and_node_eval (and_func , Xyref_list ):
87+ def and_node_eval (node : AndNode , Xyref_list ):
8288 xy_list = []
8389 for Xyref in Xyref_list :
8490 X = ray .get (Xyref .get_Xref ())
8591 y = ray .get (Xyref .get_yref ())
8692 xy_list .append (Xy (X , y ))
8793
88- res_Xy = and_func .eval (xy_list )
94+ cloned_node = node .clone ()
95+ node_ptr = ray .put (cloned_node )
96+
97+ cloned_and_func = cloned_node .get_and_func ()
98+ res_Xy = cloned_and_func .eval (xy_list )
8999 res_Xref = ray .put (res_Xy .get_x ())
90100 res_yref = ray .put (res_Xy .get_y ())
91- return XYRef (res_Xref , res_yref )
101+ return XYRef (res_Xref , res_yref , node_ptr , Xyref_list )
92102
93103
94104def execute_and_node_inner (node : AndNode , Xyref_ptrs ):
95- and_func = node .get_and_func ()
96105 result = []
97106
98107 Xyref_list = []
99108 for Xyref_ptr in Xyref_ptrs :
100109 Xyref = ray .get (Xyref_ptr )
101110 Xyref_list .append (Xyref )
102111
103- Xyref_ptr = and_node_eval .remote (and_func , Xyref_list )
112+ Xyref_ptr = and_node_eval .remote (node , Xyref_list )
104113 result .append (Xyref_ptr )
105114 return result
106115
0 commit comments