11import ray
22
3- from codeflare .pipelines .Datamodel import OrNode
3+
4+ from codeflare .pipelines .Datamodel import EstimatorNode
45from codeflare .pipelines .Datamodel import AndNode
56from codeflare .pipelines .Datamodel import Edge
67from codeflare .pipelines .Datamodel import Pipeline
78from codeflare .pipelines .Datamodel import XYRef
89from codeflare .pipelines .Datamodel import Xy
10+ from codeflare .pipelines .Datamodel import NodeInputType
11+ from codeflare .pipelines .Datamodel import NodeStateType
12+ from codeflare .pipelines .Datamodel import NodeFiringType
913
1014import sklearn .base as base
1115from enum import Enum
@@ -18,47 +22,60 @@ class ExecutionType(Enum):
1822
1923
@ray.remote
def execute_or_node_remote(node: EstimatorNode, train_mode: ExecutionType, xy_ref: XYRef):
    """Run one estimator node against a single XYRef as a Ray remote task.

    The X and y objects referenced by ``xy_ref`` are fetched, the node's
    estimator is applied according to ``train_mode``, and the output is
    stored back into the Ray object store and wrapped in a new ``XYRef``.
    The incoming y reference is always passed through unchanged.

    :param node: Estimator node whose estimator is executed.
    :param train_mode: ``ExecutionType.FIT``, ``SCORE`` or ``PREDICT``.
    :param xy_ref: Holder of the object references for the input X and y.
    :return: An ``XYRef`` whose X reference points at the node's output.
        For FIT and SCORE it also carries pointers to the original node,
        the cloned node that was fitted, and ``[xy_ref]`` as its input.
        Any other ``train_mode`` falls through and returns ``None``.
    """
    estimator = node.get_estimator()
    # Blocking operation -- not avoidable
    X = ray.get(xy_ref.get_Xref())
    y = ray.get(xy_ref.get_yref())

    # Classifiers/regressors are fit and then predict/score; everything else
    # is treated as a transformer.
    is_predictor = base.is_classifier(estimator) or base.is_regressor(estimator)

    # TODO: Can optimize the node pointers without replicating them
    if train_mode == ExecutionType.FIT:
        # Always work on a clone before fit, else fit is invalid
        cloned_node = node.clone()
        prev_node_ptr = ray.put(node)
        cloned_estimator = cloned_node.get_estimator()

        if is_predictor:
            cloned_estimator.fit(X, y)
            # Store the node only after fitting so the stored copy is fitted.
            curr_node_ptr = ray.put(cloned_node)
            # TODO: For now, make yref passthrough - this has to be fixed more comprehensively
            res_Xref = ray.put(cloned_estimator.predict(X))
        else:
            res_Xref = ray.put(cloned_estimator.fit_transform(X, y))
            curr_node_ptr = ray.put(cloned_node)

        return XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, curr_node_ptr, [xy_ref])
    elif train_mode == ExecutionType.SCORE:
        cloned_node = node.clone()
        prev_node_ptr = ray.put(node)
        cloned_estimator = cloned_node.get_estimator()

        if is_predictor:
            cloned_estimator.fit(X, y)
            curr_node_ptr = ray.put(cloned_node)
            res_Xref = ray.put(cloned_estimator.score(X, y))
        else:
            res_Xref = ray.put(cloned_estimator.fit_transform(X, y))
            curr_node_ptr = ray.put(cloned_node)

        return XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, curr_node_ptr, [xy_ref])
    elif train_mode == ExecutionType.PREDICT:
        # Test mode does not clone as it is a simple predict or transform.
        # The output must be placed in the object store like in the other
        # modes: consumers call ray.get() on the X reference, which rejects
        # a raw array.
        if is_predictor:
            res_Xref = ray.put(estimator.predict(X))
        else:
            res_Xref = ray.put(estimator.transform(X))

        return XYRef(res_Xref, xy_ref.get_yref())
6380
6481
@@ -68,7 +85,7 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
6885 exec_xyrefs = []
6986 for xy_ref_ptr in Xyref_ptrs :
7087 xy_ref = ray .get (xy_ref_ptr )
71- inner_result = execute_or_node_inner .remote (node , mode , xy_ref )
88+ inner_result = execute_or_node_remote .remote (node , mode , xy_ref )
7289 exec_xyrefs .append (inner_result )
7390
7491 for post_edge in post_edges :
@@ -78,29 +95,33 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
7895
7996
@ray.remote
def execute_and_node_remote(node: AndNode, Xyref_list):
    """Combine several XYRefs through an AND node as a Ray remote task.

    Each entry of ``Xyref_list`` is materialised into an ``Xy`` holding the
    fetched X and y objects. The list is handed to the ``transform`` method
    of a cloned node's and-function, and the combined X and y are stored
    back into the Ray object store.

    :param node: AND node providing the combining function.
    :param Xyref_list: XYRef objects to be combined.
    :return: ``XYRef`` referencing the combined X and y, together with
        pointers to the original node, the cloned node, and the inputs.
    """
    prev_node_ptr = ray.put(node)

    # Blocking fetches: X first, then y, for every input reference.
    materialised = [
        Xy(ray.get(ref.get_Xref()), ray.get(ref.get_yref()))
        for ref in Xyref_list
    ]

    cloned_node = node.clone()
    curr_node_ptr = ray.put(cloned_node)

    combined = cloned_node.get_and_func().transform(materialised)
    out_Xref = ray.put(combined.get_x())
    out_yref = ray.put(combined.get_y())
    return XYRef(out_Xref, out_yref, prev_node_ptr, curr_node_ptr, Xyref_list)
92114
93115
def execute_and_node_inner(node: AndNode, Xyref_ptrs):
    """Launch the remote AND-node evaluation for a group of XYRef pointers.

    :param node: AND node to execute.
    :param Xyref_ptrs: Object references, each resolving to an XYRef.
    :return: Single-element list holding the object reference returned by
        the ``execute_and_node_remote`` task.
    """
    # Resolve every pointer to its XYRef up front (blocking).
    resolved_xyrefs = [ray.get(ptr) for ptr in Xyref_ptrs]
    return [execute_and_node_remote.remote(node, resolved_xyrefs)]
106127
@@ -136,9 +157,9 @@ def execute_pipeline(pipeline: Pipeline, mode: ExecutionType, in_args: dict):
136157 for node in nodes :
137158 pre_edges = pipeline .get_pre_edges (node )
138159 post_edges = pipeline .get_post_edges (node )
139- if not node .get_and_flag () :
160+ if node .get_node_input_type () == NodeInputType . OR :
140161 execute_or_node (node , pre_edges , edge_args , post_edges , mode )
141- elif node .get_and_flag () :
162+ elif node .get_node_input_type () == NodeInputType . AND :
142163 execute_and_node (node , pre_edges , edge_args , post_edges )
143164
144165 out_args = {}
0 commit comments