1- from sklearn .base import BaseEstimator
21from abc import ABC , abstractmethod
2+ from enum import Enum
3+
4+ import sklearn .base as base
5+ from sklearn .base import TransformerMixin
6+ from sklearn .base import BaseEstimator
37
8+ import ray
9+ import codeflare .pipelines .Exceptions as pe
410
511class Xy :
612 """
@@ -35,9 +41,12 @@ class XYRef:
3541 computed), these holders are essential to the pipeline constructs.
3642 """
3743
38- def __init__ (self , Xref , yref ):
44+ def __init__ (self , Xref , yref , prev_node_state_ref = None , curr_node_state_ref = None , prev_Xyrefs = None ):
3945 self .__Xref__ = Xref
4046 self .__yref__ = yref
47+ self .__prev_node_state_ref__ = prev_node_state_ref
48+ self .__curr_node_state_ref__ = curr_node_state_ref
49+ self .__prev_Xyrefs__ = prev_Xyrefs
4150
4251 def get_Xref (self ):
4352 """
@@ -51,6 +60,32 @@ def get_yref(self):
5160 """
5261 return self .__yref__
5362
63+ def get_prev_node_state_ref (self ):
64+ return self .__prev_node_state_ref__
65+
66+ def get_curr_node_state_ref (self ):
67+ return self .__curr_node_state_ref__
68+
69+ def get_prev_xyrefs (self ):
70+ return self .__prev_Xyrefs__
71+
72+
73+ class NodeInputType (Enum ):
74+ OR = 0 ,
75+ AND = 1
76+
77+
78+ class NodeFiringType (Enum ):
79+ ANY = 0 ,
80+ ALL = 1
81+
82+
83+ class NodeStateType (Enum ):
84+ STATELESS = 0 ,
85+ IMMUTABLE = 1 ,
86+ MUTABLE_SEQUENTIAL = 2 ,
87+ MUTABLE_AGGREGATE = 3
88+
5489
5590class Node (ABC ):
5691 """
@@ -59,12 +94,27 @@ class Node(ABC):
5994 node name and the type of the node match.
6095 """
6196
97+ def __init__ (self , node_name , node_input_type : NodeInputType , node_firing_type : NodeFiringType , node_state_type : NodeStateType ):
98+ self .__node_name__ = node_name
99+ self .__node_input_type__ = node_input_type
100+ self .__node_firing_type__ = node_firing_type
101+ self .__node_state_type__ = node_state_type
102+
62103 def __str__ (self ):
63104 return self .__node_name__
64105
106+ def get_node_input_type (self ):
107+ return self .__node_input_type__
108+
109+ def get_node_firing_type (self ):
110+ return self .__node_firing_type__
111+
112+ def get_node_state_type (self ):
113+ return self .__node_state_type__
114+
65115 @abstractmethod
66- def get_and_flag (self ):
67- raise NotImplementedError ("Please implement this method" )
116+ def clone (self ):
117+ raise NotImplementedError ("Please implement the clone method" )
68118
69119 def __hash__ (self ):
70120 """
@@ -88,12 +138,11 @@ def __eq__(self, other):
88138 )
89139
90140
91- class OrNode (Node ):
141+ class EstimatorNode (Node ):
92142 """
93143 Or node, which is the basic node that would be the equivalent of any SKlearn pipeline
94144 stage. This node is initialized with an estimator that needs to extend sklearn.BaseEstimator.
95145 """
96- __estimator__ = None
97146
98147 def __init__ (self , node_name : str , estimator : BaseEstimator ):
99148 """
@@ -102,7 +151,8 @@ def __init__(self, node_name: str, estimator: BaseEstimator):
102151 :param node_name: Name of the node
103152 :param estimator: The base estimator
104153 """
105- self .__node_name__ = node_name
154+
155+ super ().__init__ (node_name , NodeInputType .OR , NodeFiringType .ANY , NodeStateType .IMMUTABLE )
106156 self .__estimator__ = estimator
107157
108158 def get_estimator (self ) -> BaseEstimator :
@@ -113,37 +163,33 @@ def get_estimator(self) -> BaseEstimator:
113163 """
114164 return self .__estimator__
115165
116- def get_and_flag (self ):
117- """
118- A flag to check if node is AND or not. By definition, this is NOT
119- an AND node.
120- :return: False, always
121- """
122- return False
166+ def clone (self ):
167+ cloned_estimator = base .clone (self .__estimator__ )
168+ return EstimatorNode (self .__node_name__ , cloned_estimator )
123169
124170
125- class AndFunc ( ABC ):
126- """
127- Or nodes are init-ed from the
128- """
171+ class AndTransform ( TransformerMixin , BaseEstimator ):
172+ @ abstractmethod
173+ def transform ( self , xy_list : list ) -> Xy :
174+ raise NotImplementedError ( "Please implement this method" )
129175
176+
177+ class GeneralTransform (TransformerMixin , BaseEstimator ):
130178 @abstractmethod
131- def eval (self , xy_list : list ) -> Xy :
179+ def transform (self , xy : Xy ) -> Xy :
132180 raise NotImplementedError ("Please implement this method" )
133181
134182
135183class AndNode (Node ):
136- __andfunc__ = None
137-
138- def __init__ (self , node_name : str , and_func : AndFunc ):
139- self .__node_name__ = node_name
184+ def __init__ (self , node_name : str , and_func : AndTransform ):
185+ super ().__init__ (node_name , NodeInputType .AND , NodeFiringType .ANY , NodeStateType .STATELESS )
140186 self .__andfunc__ = and_func
141187
142- def get_and_func (self ) -> AndFunc :
188+ def get_and_func (self ) -> AndTransform :
143189 return self .__andfunc__
144190
145- def get_and_flag (self ):
146- return True
191+ def clone (self ):
192+ return AndNode ( self . __node_name__ , self . __andfunc__ )
147193
148194
149195class Edge :
@@ -322,5 +368,69 @@ def get_post_edges(self, node: Node):
322368 return post_edges
323369
324370 def is_terminal (self , node : Node ):
325- node_post_edges = self .get_post_edges (node )
326- return len (node_post_edges ) == 0
371+ post_nodes = self .__post_graph__ [node ]
372+ return not post_nodes
373+
374+ def get_terminal_nodes (self ):
375+ # dict from level to nodes
376+ terminal_nodes = []
377+ for node in self .__pre_graph__ .keys ():
378+ if self .is_terminal (node ):
379+ terminal_nodes .append (node )
380+ return terminal_nodes
381+
382+
383+ class PipelineOutput :
384+ """
385+ Pipeline output to keep reference counters so that pipelines can be materialized
386+ """
387+ def __init__ (self , out_args , edge_args ):
388+ self .__out_args__ = out_args
389+ self .__edge_args__ = edge_args
390+
391+ def get_xyrefs (self , node : Node ):
392+ if node in self .__out_args__ :
393+ xyrefs_ptr = self .__out_args__ [node ]
394+ elif node in self .__edge_args__ :
395+ xyrefs_ptr = self .__edge_args__ [node ]
396+ else :
397+ raise pe .PipelineNodeNotFoundException ("Node " + str (node ) + " not found" )
398+
399+ xyrefs = ray .get (xyrefs_ptr )
400+ return xyrefs
401+
402+ def get_edge_args (self ):
403+ return self .__edge_args__
404+
405+
406+ class PipelineInput :
407+ """
408+ in_args is a dict from a node -> [Xy]
409+ """
410+ def __init__ (self ):
411+ self .__in_args__ = {}
412+
413+ def add_xyref_ptr_arg (self , node : Node , xyref_ptr ):
414+ if node not in self .__in_args__ :
415+ self .__in_args__ [node ] = []
416+
417+ self .__in_args__ [node ].append (xyref_ptr )
418+
419+ def add_xyref_arg (self , node : Node , xyref : XYRef ):
420+ if node not in self .__in_args__ :
421+ self .__in_args__ [node ] = []
422+
423+ xyref_ptr = ray .put (xyref )
424+ self .__in_args__ [node ].append (xyref_ptr )
425+
426+ def add_xy_arg (self , node : Node , xy : Xy ):
427+ if node not in self .__in_args__ :
428+ self .__in_args__ [node ] = []
429+
430+ x_ref = ray .put (xy .get_x ())
431+ y_ref = ray .put (xy .get_y ())
432+ xyref = XYRef (x_ref , y_ref )
433+ self .add_xyref_arg (node , xyref )
434+
435+ def get_in_args (self ):
436+ return self .__in_args__
0 commit comments