@@ -199,6 +199,11 @@ def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type:
199199 self .__node_state_type__ = node_state_type
200200
def __str__(self):
    """
    String form of this node, which is simply the node's name.

    :return: The value stored in ``__node_name__``
    """
    node_name = self.__node_name__
    return node_name
203208
204209 def get_node_name (self ) -> str :
@@ -266,6 +271,17 @@ class EstimatorNode(Node):
266271
267272 This estimator node is typically an OR node, with ANY firing semantics, and IMMUTABLE state. For
268273 partial fit, we will have to define a different node type to keep semantics very clear.
274+
275+ .. code-block:: python
276+
277+ random_forest = RandomForestClassifier(n_estimators=200)
278+ node_rf = dm.EstimatorNode('randomforest', random_forest)
279+
280+ # get the estimator
281+ node_rf_estimator = node_rf.get_estimator()
282+
283+ # clone the node, clones the estimator as well
284+ node_rf_cloned = node_rf.clone()
269285 """
270286
271287 def __init__ (self , node_name : str , estimator : BaseEstimator ):
@@ -283,11 +299,16 @@ def get_estimator(self) -> BaseEstimator:
283299 """
284300 Return the estimator that this node was initialized with
285301
286- :return: Estimator
302+ :return: Estimator that was initialized
287303 """
288304 return self .__estimator__
289305
def clone(self):
    """
    Create a copy of this node: the estimator is cloned via ``base.clone`` and wrapped in a
    new EstimatorNode that carries the same node name.

    :return: A new EstimatorNode holding the cloned estimator
    """
    return EstimatorNode(self.__node_name__, base.clone(self.__estimator__))
293314
@@ -317,17 +338,31 @@ def clone(self):
317338
318339
319340class Edge :
320- __from_node__ = None
321- __to_node__ = None
341+ """
342+ An edge connects two nodes, it's an internal data structure for pipeline construction. An edge
343+ is a directed edge and has a "from_node" and a "to_node".
322344
345+ An edge also defines a hash function and an equality, where the equality is on the from and to
346+ node names being the same.
347+ """
323348 def __init__ (self , from_node : Node , to_node : Node ):
324349 self .__from_node__ = from_node
325350 self .__to_node__ = to_node
326351
327352 def get_from_node (self ) -> Node :
353+ """
354+ The from_node of this edge (originating node)
355+
356+ :return: The from_node of this edge
357+ """
328358 return self .__from_node__
329359
330360 def get_to_node (self ) -> Node :
361+ """
362+ The to_node of this edge (terminating node)
363+
364+ :return: The to_node of this edge
365+ """
331366 return self .__to_node__
332367
333368 def __str__ (self ):
@@ -361,7 +396,62 @@ def get_object_ref(self):
361396
362397class Pipeline :
363398 """
364- The pipeline class that defines the DAG structure composed of Node(s). The
399+ The pipeline class that defines the DAG structure composed of Node(s). This is the core data structure that
400+ defines the computation graph. A key note is that unlike SKLearn pipeline, CodeFlare pipelines are "abstract"
401+ graphs and get realized only when executed. Upon execution, there can potentially be multiple pathways in
402+ the pipeline, i.e. multiple "single" pipelines can be realized.
403+
404+ Examples
405+ --------
406+ Pipelines can be constructed quite simply using the builder paradigm with add_node and/or add_edge. In its
407+ simplest form, one can create nodes and then wire the DAG by adding edges. An example that does a simple
408+ pipeline is below:
409+
410+ .. code-block:: python
411+
412+ feature_union = FeatureUnion(transformer_list=[('PCA', PCA()),
413+ ('Nystroem', Nystroem()), ('SelectKBest', SelectKBest(k=3))])
414+ random_forest = RandomForestClassifier(n_estimators=200)
415+ node_fu = dm.EstimatorNode('feature_union', feature_union)
416+ node_rf = dm.EstimatorNode('randomforest', random_forest)
417+ pipeline.add_edge(node_fu, node_rf)
418+
419+ One can of course construct complex pipelines with multiple outgoing edges as well. An example of one that
420+ explores multiple models is shown below:
421+
422+ .. code-block:: python
423+
424+ preprocessor = ColumnTransformer(
425+ transformers=[
426+ ('num', numeric_transformer, numeric_features),
427+ ('cat', categorical_transformer, categorical_features)])
428+
429+ classifiers = [
430+ RandomForestClassifier(),
431+ GradientBoostingClassifier()
432+ ]
433+ pipeline = dm.Pipeline()
434+ node_pre = dm.EstimatorNode('preprocess', preprocessor)
435+ node_rf = dm.EstimatorNode('random_forest', classifiers[0])
436+ node_gb = dm.EstimatorNode('gradient_boost', classifiers[1])
437+
438+ pipeline.add_edge(node_pre, node_rf)
439+ pipeline.add_edge(node_pre, node_gb)
440+
441+ A pipeline can be saved and loaded, which in essence saves the "graph" and not the state of this pipeline.
442+ For saving the state of the pipeline, one can use the Runtime's save method! Save/load of pipeline uses
443+ Pickle protocol 5.
444+
445+ .. code-block:: python
446+
447+ fname = 'save_pipeline.cfp'
448+ fh = open(fname, 'wb')
449+ pipeline.save(fh)
450+ fh.close()
451+
452+ r_fh = open(fname, 'rb')
453+ saved_pipeline = dm.Pipeline.load(r_fh)
454+
365455 """
366456
367457 def __init__ (self ):
@@ -371,6 +461,12 @@ def __init__(self):
371461 self .__level_nodes__ = None
372462
373463 def add_node (self , node : Node ):
464+ """
465+ Adds a node to this pipeline
466+
467+ :param node: The node to add
468+ :return: None
469+ """
374470 self .__node_levels__ = None
375471 self .__level_nodes__ = None
376472 if node not in self .__pre_graph__ .keys ():
@@ -395,6 +491,13 @@ def get_str(nodes: list):
395491 return res
396492
397493 def add_edge (self , from_node : Node , to_node : Node ):
494+ """
495+ Adds an edge to this pipeline
496+
497+ :param from_node: The from node
498+ :param to_node: The to node
499+ :return: None
500+ """
398501 self .add_node (from_node )
399502 self .add_node (to_node )
400503
@@ -408,6 +511,14 @@ def get_postimage(self, node: Node):
408511 return self .__post_graph__ [node ]
409512
410513 def compute_node_level (self , node : Node , result : dict ):
514+ """
515+ Computes the node levels for a given node, an internal supporting function that is recursive, so it
516+ takes the result computed so far.
517+
518+ :param node: The node for which level needs to be computed
519+ :param result: The node levels that have already been computed
520+ :return: The level for this node
521+ """
411522 if node in result :
412523 return result [node ]
413524
@@ -426,6 +537,13 @@ def compute_node_level(self, node: Node, result: dict):
426537 return max_level + 1
427538
428539 def compute_node_levels (self ):
540+ """
541+ Computes node levels for all nodes. If a cache of node levels from previous calls exists, it will return
542+ the cache to avoid repeated computation.
543+
544+ :return: The mapping from node to its level as a dict
545+ """
546+ # TODO: This is incorrect when pipelines are mutable
429547 if self .__node_levels__ :
430548 return self .__node_levels__
431549
@@ -438,13 +556,24 @@ def compute_node_levels(self):
438556 return self .__node_levels__
439557
def compute_max_level(self):
    """
    Get the max depth of this pipeline graph.

    The depth is the largest level reported by ``compute_node_levels``; the result is never
    below 0, and an empty pipeline yields 0.

    :return: The max depth of pipeline
    """
    levels = self.compute_node_levels()
    # Only the level values matter; the outer max keeps the original floor of 0.
    return max(0, max(levels.values(), default=0))
446569
447570 def get_nodes_by_level (self ):
571+ """
572+ A mapping from level to a list of nodes, useful at pipeline execution time. Similar to compute_node_levels,
573+ this method will return a cache if it exists, else will compute the levels and cache it.
574+
575+ :return: The mapping from level to a list of nodes at that level
576+ """
448577 if self .__level_nodes__ :
449578 return self .__level_nodes__
450579
@@ -460,16 +589,19 @@ def get_nodes_by_level(self):
460589 self .__level_nodes__ = result
461590 return self .__level_nodes__
462591
463- ###
464- # Get downstream node
465- ###
466592 def get_post_nodes (self , node : Node ):
467593 return self .__post_graph__ [node ]
468594
469595 def get_pre_nodes (self , node : Node ):
470596 return self .__pre_graph__ [node ]
471597
472598 def get_pre_edges (self , node : Node ):
599+ """
600+ Get the incoming edges to a specific node.
601+
602+ :param node: Given node
603+ :return: Incoming edges for the node
604+ """
473605 pre_edges = []
474606 pre_nodes = self .__pre_graph__ [node ]
475607 # Empty pre
@@ -481,6 +613,12 @@ def get_pre_edges(self, node: Node):
481613 return pre_edges
482614
483615 def get_post_edges (self , node : Node ):
616+ """
617+ Get the outgoing edges for the given node
618+
619+ :param node: Given node
620+ :return: Outgoing edges for the node
621+ """
484622 post_edges = []
485623 post_nodes = self .__post_graph__ [node ]
486624 # Empty post
@@ -492,10 +630,21 @@ def get_post_edges(self, node: Node):
492630 return post_edges
493631
494632 def is_terminal (self , node : Node ):
633+ """
634+ Checks if the given node is a terminal node, i.e. has no outgoing edges
635+
636+ :param node: Node to check terminal condition on
637+ :return: True if terminal else False
638+ """
495639 post_nodes = self .__post_graph__ [node ]
496640 return not post_nodes
497641
498642 def get_terminal_nodes (self ):
643+ """
644+ Get all the terminal nodes for this pipeline
645+
646+ :return: List of all terminal nodes
647+ """
499648 # dict from level to nodes
500649 terminal_nodes = []
501650 for node in self .__pre_graph__ .keys ():
@@ -504,18 +653,42 @@ def get_terminal_nodes(self):
504653 return terminal_nodes
505654
def get_nodes(self):
    """
    Get the nodes in this pipeline, keyed by name.

    The nodes are taken from the keys of the pre graph. If two nodes report the same name,
    the one iterated last is kept.

    :return: Node name to node dict
    """
    return {node.get_node_name(): node for node in self.__pre_graph__}
511665
def get_pre_nodes(self, node):
    """
    Get the upstream nodes of the given node, i.e. the nodes whose edges point into it.

    :param node: Given node
    :return: The pre-graph entry for the node (nodes with edges incoming to it)
    """
    upstream = self.__pre_graph__[node]
    return upstream
514674
def get_post_nodes(self, node):
    """
    Get the downstream nodes of the given node, i.e. the nodes that its outgoing edges
    point to.

    :param node: Given node
    :return: The post-graph entry for the node (nodes reached by its outgoing edges)
    """
    downstream = self.__post_graph__[node]
    return downstream
517683
518684 def save (self , filehandle ):
685+ """
686+ Saves the pipeline graph (without state) to a file. A filehandle with write and binary mode
687+ is expected.
688+
689+ :param filehandle: Filehandle with wb mode
690+ :return: None
691+ """
519692 nodes = {}
520693 edges = []
521694
@@ -534,6 +707,12 @@ def save(self, filehandle):
534707
535708 @staticmethod
536709 def load (filehandle ):
710+ """
711+ Loads a previously saved pipeline from the given filehandle, which is expected to be opened in 'rb' mode.
712+
713+ :param filehandle: Filehandle to load pipeline from
714+ :return: The loaded pipeline
715+ """
537716 saved_pipeline = pickle .load (filehandle )
538717 if not isinstance (saved_pipeline , _SavedPipeline ):
539718 raise pe .PipelineException ("Filehandle is not a saved pipeline instance" )
@@ -551,14 +730,28 @@ def load(filehandle):
551730
552731
class _SavedPipeline:
    """
    Internal, picklable container for a pipeline's graph. It holds only the nodes and edges
    it is given, not any state of the pipeline.
    """

    def __init__(self, nodes, edges):
        """
        Store the graph pieces of a pipeline.

        :param nodes: Nodes of the pipeline being saved
        :param edges: Edges of the pipeline being saved
        """
        self.__nodes__, self.__edges__ = nodes, edges

    def get_nodes(self):
        """
        Nodes captured for the saved pipeline

        :return: The nodes object passed at construction
        """
        saved_nodes = self.__nodes__
        return saved_nodes

    def get_edges(self):
        """
        Edges captured for the saved pipeline

        :return: The edges object passed at construction
        """
        saved_edges = self.__edges__
        return saved_edges
563756
564757
0 commit comments