11import ray
22
3- from codeflare .pipelines .Datamodel import OrNode
3+ from codeflare .pipelines .Datamodel import EstimatorNode
44from codeflare .pipelines .Datamodel import AndNode
55from codeflare .pipelines .Datamodel import Edge
66from codeflare .pipelines .Datamodel import Pipeline
77from codeflare .pipelines .Datamodel import XYRef
88from codeflare .pipelines .Datamodel import Xy
9+ from codeflare .pipelines .Datamodel import NodeInputType
10+ from codeflare .pipelines .Datamodel import NodeStateType
11+ from codeflare .pipelines .Datamodel import NodeFiringType
912
1013import sklearn .base as base
1114from enum import Enum
@@ -18,43 +21,49 @@ class ExecutionType(Enum):
1821
1922
2023@ray .remote
21- def execute_or_node_inner (node : OrNode , train_mode : ExecutionType , xy_ref : XYRef ):
24+ def execute_or_node_remote (node : EstimatorNode , train_mode : ExecutionType , xy_ref : XYRef ):
2225 estimator = node .get_estimator ()
2326 # Blocking operation -- not avoidable
2427 X = ray .get (xy_ref .get_Xref ())
2528 y = ray .get (xy_ref .get_yref ())
2629
30+ # TODO: Can optimize the node pointers without replicating them
2731 if train_mode == ExecutionType .FIT :
2832 cloned_node = node .clone ()
29- node_ptr = ray .put (cloned_node )
33+ prev_node_ptr = ray .put (node )
3034
3135 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
3236 # Always clone before fit, else fit is invalid
3337 cloned_estimator = cloned_node .get_estimator ()
3438 cloned_estimator .fit (X , y )
39+
40+ curr_node_ptr = ray .put (cloned_node )
3541 # TODO: For now, make yref passthrough - this has to be fixed more comprehensively
3642 res_Xref = ray .put (cloned_estimator .predict (X ))
37- result = XYRef (res_Xref , xy_ref .get_yref (), node_ptr , [xy_ref ])
43+ result = XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , curr_node_ptr , [xy_ref ])
3844 return result
3945 else :
4046 cloned_estimator = cloned_node .get_estimator ()
4147 res_Xref = ray .put (cloned_estimator .fit_transform (X , y ))
42- result = XYRef (res_Xref , xy_ref .get_yref (), node_ptr , [xy_ref ])
48+ curr_node_ptr = ray .put (cloned_node )
49+ result = XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , curr_node_ptr , [xy_ref ])
4350 return result
4451 elif train_mode == ExecutionType .SCORE :
4552 cloned_node = node .clone ()
46- node_ptr = ray .put (cloned_node )
53+ prev_node_ptr = ray .put (node )
4754
4855 if base .is_classifier (estimator ) or base .is_regressor (estimator ):
4956 cloned_estimator = cloned_node .get_estimator ()
5057 cloned_estimator .fit (X , y )
58+ curr_node_ptr = ray .put (cloned_node )
5159 res_Xref = ray .put (cloned_estimator .score (X , y ))
52- result = XYRef (res_Xref , xy_ref .get_yref (), node_ptr , [xy_ref ])
60+ result = XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , curr_node_ptr , [xy_ref ])
5361 return result
5462 else :
5563 cloned_estimator = cloned_node .get_estimator ()
5664 res_Xref = ray .put (cloned_estimator .fit_transform (X , y ))
57- result = XYRef (res_Xref , xy_ref .get_yref (), node_ptr , [xy_ref ])
65+ curr_node_ptr = ray .put (cloned_node )
66+ result = XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , curr_node_ptr , [xy_ref ])
5867 return result
5968 elif train_mode == ExecutionType .PREDICT :
6069 # Test mode does not clone as it is a simple predict or transform
@@ -74,7 +83,7 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
7483 exec_xyrefs = []
7584 for xy_ref_ptr in Xyref_ptrs :
7685 xy_ref = ray .get (xy_ref_ptr )
77- inner_result = execute_or_node_inner .remote (node , mode , xy_ref )
86+ inner_result = execute_or_node_remote .remote (node , mode , xy_ref )
7887 exec_xyrefs .append (inner_result )
7988
8089 for post_edge in post_edges :
@@ -84,21 +93,22 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
8493
8594
@ray.remote
def execute_and_node_remote(node: AndNode, Xyref_list):
    """Evaluate an AND node on one group of inputs as a Ray remote task.

    Each XYRef in ``Xyref_list`` is materialised with ``ray.get``, the node
    is cloned, and the clone's AND function is applied to the resulting list
    of ``Xy`` objects via ``transform``.

    :param node: AND node to evaluate. It is placed in the object store
        unchanged and recorded as the "previous" node of the result.
    :param Xyref_list: iterable of XYRef objects whose X and y references
        are fetched and combined.
    :return: an XYRef built from references to the combined X and y, the
        reference to the original node, the reference to the cloned node,
        and ``Xyref_list`` itself (passed as the last constructor argument).
    """
    xy_list = []
    # Reference to the node exactly as it was handed to this task.
    prev_node_ptr = ray.put(node)

    # Blocking fetches: the AND function needs the concrete X/y values.
    for Xyref in Xyref_list:
        X = ray.get(Xyref.get_Xref())
        y = ray.get(Xyref.get_yref())
        xy_list.append(Xy(X, y))

    # Work on a clone so the node passed in is not the one being evaluated.
    cloned_node = node.clone()
    # NOTE(review): the clone is stored *before* transform() runs. If the AND
    # function mutates its own state during transform, this reference will
    # not reflect it (the OR-node path stores its clone after fit) -- confirm
    # whether that is intended.
    curr_node_ptr = ray.put(cloned_node)

    cloned_and_func = cloned_node.get_and_func()
    res_Xy = cloned_and_func.transform(xy_list)
    res_Xref = ray.put(res_Xy.get_x())
    res_yref = ray.put(res_Xy.get_y())
    return XYRef(res_Xref, res_yref, prev_node_ptr, curr_node_ptr, Xyref_list)
102112
103113
104114def execute_and_node_inner (node : AndNode , Xyref_ptrs ):
@@ -109,7 +119,7 @@ def execute_and_node_inner(node: AndNode, Xyref_ptrs):
109119 Xyref = ray .get (Xyref_ptr )
110120 Xyref_list .append (Xyref )
111121
112- Xyref_ptr = and_node_eval .remote (node , Xyref_list )
122+ Xyref_ptr = execute_and_node_remote .remote (node , Xyref_list )
113123 result .append (Xyref_ptr )
114124 return result
115125
@@ -145,9 +155,9 @@ def execute_pipeline(pipeline: Pipeline, mode: ExecutionType, in_args: dict):
145155 for node in nodes :
146156 pre_edges = pipeline .get_pre_edges (node )
147157 post_edges = pipeline .get_post_edges (node )
148- if not node .get_and_flag () :
158+ if node .get_node_input_type () == NodeInputType . OR :
149159 execute_or_node (node , pre_edges , edge_args , post_edges , mode )
150- elif node .get_and_flag () :
160+ elif node .get_node_input_type () == NodeInputType . AND :
151161 execute_and_node (node , pre_edges , edge_args , post_edges )
152162
153163 out_args = {}
0 commit comments