11from abc import ABC , abstractmethod
2- import uuid
32from enum import Enum
43
5-
64import sklearn .base as base
75from sklearn .base import TransformerMixin
86from sklearn .base import BaseEstimator
97
8+ import ray
9+ import codeflare .pipelines .Exceptions as pe
10+
1011class Xy :
1112 """
1213 Holder class for Xy, where X is array-like and y is array-like. This is the base
@@ -40,11 +41,11 @@ class XYRef:
4041 computed), these holders are essential to the pipeline constructs.
4142 """
4243
43- def __init__ (self , Xref , yref , prev_noderef = None , curr_noderef = None , prev_Xyrefs = None ):
44+ def __init__ (self , Xref , yref , prev_node_state_ref = None , curr_node_state_ref = None , prev_Xyrefs = None ):
4445 self .__Xref__ = Xref
4546 self .__yref__ = yref
46- self .__prevnoderef__ = prev_noderef
47- self .__currnoderef__ = curr_noderef
47+ self .__prev_node_state_ref__ = prev_node_state_ref
48+ self .__curr_node_state_ref__ = curr_node_state_ref
4849 self .__prev_Xyrefs__ = prev_Xyrefs
4950
5051 def get_Xref (self ):
@@ -59,11 +60,11 @@ def get_yref(self):
5960 """
6061 return self .__yref__
6162
62- def get_prevnoderef (self ):
63- return self .__prevnoderef__
63+ def get_prev_node_state_ref (self ):
64+ return self .__prev_node_state_ref__
6465
65- def get_currnoderef (self ):
66- return self .__currnoderef__
66+ def get_curr_node_state_ref (self ):
67+ return self .__curr_node_state_ref__
6768
6869 def get_prev_xyrefs (self ):
6970 return self .__prev_Xyrefs__
@@ -98,14 +99,10 @@ def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type:
9899 self .__node_input_type__ = node_input_type
99100 self .__node_firing_type__ = node_firing_type
100101 self .__node_state_type__ = node_state_type
101- self .__id__ = uuid .uuid4 ()
102102
103103 def __str__ (self ):
104104 return self .__node_name__
105105
106- def get_id (self ):
107- return self .__id__
108-
109106 def get_node_input_type (self ):
110107 return self .__node_input_type__
111108
@@ -125,8 +122,7 @@ def __hash__(self):
125122
126123 :return: Hash code
127124 """
128-
129- return self .__id__ .__hash__ ()
125+ return self .__node_name__ .__hash__ ()
130126
131127 def __eq__ (self , other ):
132128 """
@@ -138,7 +134,6 @@ def __eq__(self, other):
138134 """
139135 return (
140136 self .__class__ == other .__class__ and
141- self .__id__ == other .__id__ and
142137 self .__node_name__ == other .__node_name__
143138 )
144139
@@ -373,5 +368,69 @@ def get_post_edges(self, node: Node):
373368 return post_edges
374369
375370 def is_terminal (self , node : Node ):
376- node_post_edges = self .get_post_edges (node )
377- return len (node_post_edges ) == 0
371+ post_nodes = self .__post_graph__ [node ]
372+ return not post_nodes
373+
374+ def get_terminal_nodes (self ):
375+ # dict from level to nodes
376+ terminal_nodes = []
377+ for node in self .__pre_graph__ .keys ():
378+ if self .is_terminal (node ):
379+ terminal_nodes .append (node )
380+ return terminal_nodes
381+
382+
383+ class PipelineOutput :
384+ """
385+ Pipeline output to keep reference counters so that pipelines can be materialized
386+ """
387+ def __init__ (self , out_args , edge_args ):
388+ self .__out_args__ = out_args
389+ self .__edge_args__ = edge_args
390+
391+ def get_xyrefs (self , node : Node ):
392+ if node in self .__out_args__ :
393+ xyrefs_ptr = self .__out_args__ [node ]
394+ elif node in self .__edge_args__ :
395+ xyrefs_ptr = self .__edge_args__ [node ]
396+ else :
397+ raise pe .PipelineNodeNotFoundException ("Node " + str (node ) + " not found" )
398+
399+ xyrefs = ray .get (xyrefs_ptr )
400+ return xyrefs
401+
402+ def get_edge_args (self ):
403+ return self .__edge_args__
404+
405+
406+ class PipelineInput :
407+ """
408+ in_args is a dict from a node -> [Xy]
409+ """
410+ def __init__ (self ):
411+ self .__in_args__ = {}
412+
413+ def add_xyref_ptr_arg (self , node : Node , xyref_ptr ):
414+ if node not in self .__in_args__ :
415+ self .__in_args__ [node ] = []
416+
417+ self .__in_args__ [node ].append (xyref_ptr )
418+
419+ def add_xyref_arg (self , node : Node , xyref : XYRef ):
420+ if node not in self .__in_args__ :
421+ self .__in_args__ [node ] = []
422+
423+ xyref_ptr = ray .put (xyref )
424+ self .__in_args__ [node ].append (xyref_ptr )
425+
426+ def add_xy_arg (self , node : Node , xy : Xy ):
427+ if node not in self .__in_args__ :
428+ self .__in_args__ [node ] = []
429+
430+ x_ref = ray .put (xy .get_x ())
431+ y_ref = ray .put (xy .get_y ())
432+ xyref = XYRef (x_ref , y_ref )
433+ self .add_xyref_arg (node , xyref )
434+
435+ def get_in_args (self ):
436+ return self .__in_args__
0 commit comments