@@ -619,6 +619,12 @@ def compute_node_levels(self):
619619 return self .__node_levels__
620620
621621 def get_node_level (self , node : Node ):
622+ """
623+ Returns the node level for the given node, a number between 0 and max_level (depth of the DAG/Pipeline).
624+
625+ :param node: Given node
626+ :return: Level between 0 and max_level (depth of the pipeline)
627+ """
622628 self .compute_node_levels ()
623629 return self .__node_levels__ [node ]
624630
@@ -657,9 +663,23 @@ def get_nodes_by_level(self):
657663 return self .__level_nodes__
658664
659665 def get_post_nodes (self , node : Node ):
666+ """
667+ Returns the nodes which are "below" the given node, i.e., have incoming edges from the
668+ given node, empty if it is an output node
669+
670+ :param node: Given node
671+ :return: List of nodes that have incoming edges from the given node
672+ """
660673 return self .__post_graph__ [node ]
661674
662675 def get_pre_nodes (self , node : Node ):
676+ """
677+ Returns the nodes which are "above" the given node, i.e., have outgoing edges to the given
678+ node, empty if it is an input node
679+
680+ :param node: Given node
681+ :return: List of nodes that have outgoing edges to the given node
682+ """
663683 return self .__pre_graph__ [node ]
664684
665685 def get_pre_edges (self , node : Node ):
@@ -697,10 +717,21 @@ def get_post_edges(self, node: Node):
697717 return post_edges
698718
699719 def is_output (self , node : Node ):
720+ """
721+ Checks if the given node is an output node
722+
723+ :param node: Given node
724+ :return: True if output else False
725+ """
700726 post_nodes = self .get_post_nodes (node )
701727 return not post_nodes
702728
703729 def get_output_nodes (self ):
730+ """
731+ Gets all the output nodes for this pipeline
732+
733+ :return: List of output nodes
734+ """
704735 # dict from level to nodes
705736 terminal_nodes = []
706737 for node in self .__pre_graph__ .keys ():
@@ -709,31 +740,29 @@ def get_output_nodes(self):
709740 return terminal_nodes
710741
711742 def get_nodes (self ):
712- return self .__node_name_map__
713-
714- def get_pre_nodes (self , node ):
715743 """
716- Get the nodes that have edges incoming to the given node
744+ Returns all the nodes of this pipeline in a dict from node_name to the node
717745
718- :param node: Given node
719- :return: List of nodes with incoming edges to the provided node
746+ :return: Dict of node_name to node
720747 """
721- return self .__pre_graph__ [ node ]
748+ return self .__node_name_map__
722749
723- def get_post_nodes (self , node ):
750+ def is_input (self , node : Node ):
724751 """
725- Get the nodes that have edges outgoing to the given node
752+ Checks if the given node is an input node of this pipeline
726753
727754 :param node: Given node
728- :return: List of nodes with outgoing edges from the provided node
755+ :return: True if input node else False
729756 """
730- return self .__post_graph__ [node ]
731-
732- def is_input (self , node : Node ):
733757 pre_nodes = self .get_pre_nodes (node )
734758 return not pre_nodes
735759
736760 def get_input_nodes (self ):
761+ """
762+ Returns all the input nodes of this pipeline
763+
764+ :return: List of input nodes
765+ """
737766 input_nodes = []
738767 for node in self .__node_name_map__ .values ():
739768 if self .get_node_level (node ) == 0 :
@@ -742,9 +771,21 @@ def get_input_nodes(self):
742771 return input_nodes
743772
744773 def get_node (self , node_name : str ) -> Node :
774+ """
775+ Return the node given a node name
776+
777+ :param node_name: Node name
778+ :return: The node with this node name
779+ """
745780 return self .__node_name_map__ [node_name ]
746781
747782 def has_single_estimator (self ):
783+ """
783+ Checks if this pipeline has only a single OR estimator. This is useful to know when picking a specific
784+ pipeline
786+
787+ :return: True if only one OR estimator else False
788+ """
748789 if len (self .get_output_nodes ()) > 1 :
749790 return False
750791
@@ -781,6 +822,41 @@ def save(self, filehandle):
781822 pickle .dump (saved_pipeline , filehandle )
782823
783824 def get_parameterized_pipeline (self , pipeline_param ):
825+ """
826+ Parameterizes the current pipeline with the provided pipeline_param and returns the newly parameterized
827+ pipeline. The pipeline_param is explored for all the parameters associated with a given node, which is
828+ then expanded to multiple nodes with generated node names. The graph is created using the existing
829+ connections, i.e. if there is an edge between node A and node B and with parameterization node B became
830+ node B1, node B2, an edge is created between node A and node B1 as well as node A and node B2.
831+
832+ Depending on the strategy of searches, the appropriate pipeline_param can create the right expansion.
833+ For example, grid search can expand the cross product of parameters and the pipeline will get expanded.
834+
835+ Examples
836+ --------
837+ The below code shows an example of how a 2 step pipeline gets expanded to a 9 node pipeline for grid
838+ search.
839+
840+ .. code-block:: python
841+
842+ pipeline = dm.Pipeline()
843+ node_pca = dm.EstimatorNode('pca', pca)
844+ node_logistic = dm.EstimatorNode('logistic', logistic)
845+
846+ pipeline.add_edge(node_pca, node_logistic)
847+
848+ param_grid = {
849+ 'pca__n_components': [5, 15, 30, 45, 64],
850+ 'logistic__C': np.logspace(-4, 4, 4),
851+ }
852+
853+ pipeline_param = dm.PipelineParam.from_param_grid(param_grid)
854+
855+ param_grid_pipeline = pipeline.get_parameterized_pipeline(pipeline_param)
856+
857+ :param pipeline_param: The pipeline parameters
858+ :return: A parameterized pipeline
859+ """
784860 result = Pipeline ()
785861 pipeline_params = pipeline_param .get_all_params ()
786862 parameterized_nodes = {}
0 commit comments