Skip to content

Commit 5e3462f

Browse files
Adding more docs, removing redundant methods in pipeline
1 parent 4033454 commit 5e3462f

File tree

1 file changed

+89
-13
lines changed

1 file changed

+89
-13
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 89 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,12 @@ def compute_node_levels(self):
619619
return self.__node_levels__
620620

621621
def get_node_level(self, node: Node):
622+
"""
623+
Returns the node level for the given node, a number between 0 and max_level (depth of the DAG/Pipeline).
624+
625+
:param node: Given node
626+
:return: Level between 0 and max_depth of pipeline
627+
"""
622628
self.compute_node_levels()
623629
return self.__node_levels__[node]
624630

@@ -657,9 +663,23 @@ def get_nodes_by_level(self):
657663
return self.__level_nodes__
658664

659665
def get_post_nodes(self, node: Node):
666+
"""
667+
Returns the nodes which are "below" the given node, i.e., have incoming edges from the
668+
given node, empty if it is an output node
669+
670+
:param node: Given node
671+
:return: List of nodes that have incoming edges to the given node
672+
"""
660673
return self.__post_graph__[node]
661674

662675
def get_pre_nodes(self, node: Node):
676+
"""
677+
Returns the nodes which are "above" the given node, i.e., have outgoing edges from the given
678+
node, empty if it is an input node
679+
680+
:param node: Given node
681+
:return: List of nodes that have outgoing edges to the given node
682+
"""
663683
return self.__pre_graph__[node]
664684

665685
def get_pre_edges(self, node: Node):
@@ -697,10 +717,21 @@ def get_post_edges(self, node: Node):
697717
return post_edges
698718

699719
def is_output(self, node: Node):
720+
"""
721+
Checks if the given node is an output node
722+
723+
:param node: Given node
724+
:return: True if output else False
725+
"""
700726
post_nodes = self.get_post_nodes(node)
701727
return not post_nodes
702728

703729
def get_output_nodes(self):
730+
"""
731+
Gets all the output nodes for this pipeline
732+
733+
:return: List of output nodes
734+
"""
704735
# dict from level to nodes
705736
terminal_nodes = []
706737
for node in self.__pre_graph__.keys():
@@ -709,31 +740,29 @@ def get_output_nodes(self):
709740
return terminal_nodes
710741

711742
def get_nodes(self):
712-
return self.__node_name_map__
713-
714-
def get_pre_nodes(self, node):
715743
"""
716-
Get the nodes that have edges incoming to the given node
744+
Returns all the nodes of this pipeline in a dict from node_name to the node
717745
718-
:param node: Given node
719-
:return: List of nodes with incoming edges to the provided node
746+
:return: Dict of node_name to node
720747
"""
721-
return self.__pre_graph__[node]
748+
return self.__node_name_map__
722749

723-
def get_post_nodes(self, node):
750+
def is_input(self, node: Node):
724751
"""
725-
Get the nodes that have edges outgoing to the given node
752+
Checks if the given node is an input node of this pipeline
726753
727754
:param node: Given node
728-
:return: List of nodes with outgoing edges from the provided node
755+
:return: True if input node else False
729756
"""
730-
return self.__post_graph__[node]
731-
732-
def is_input(self, node: Node):
733757
pre_nodes = self.get_pre_nodes(node)
734758
return not pre_nodes
735759

736760
def get_input_nodes(self):
761+
"""
762+
Returns all the input nodes of this pipeline
763+
764+
:return: List of input nodes
765+
"""
737766
input_nodes = []
738767
for node in self.__node_name_map__.values():
739768
if self.get_node_level(node) == 0:
@@ -742,9 +771,21 @@ def get_input_nodes(self):
742771
return input_nodes
743772

744773
def get_node(self, node_name: str) -> Node:
774+
"""
775+
Return the node given a node name
776+
777+
:param node_name: Node name
778+
:return: The node with this node name
779+
"""
745780
return self.__node_name_map__[node_name]
746781

747782
def has_single_estimator(self):
783+
"""
784+
Checks if this pipeline has only a single OR estimator, this is useful to know when picking a specific
785+
pipeline
786+
787+
:return: True if only one OR estimator else False
788+
"""
748789
if len(self.get_output_nodes()) > 1:
749790
return False
750791

@@ -781,6 +822,41 @@ def save(self, filehandle):
781822
pickle.dump(saved_pipeline, filehandle)
782823

783824
def get_parameterized_pipeline(self, pipeline_param):
825+
"""
826+
Parameterizes the current pipeline with the provided pipeline_param and returns the newly parameterized
827+
pipeline. The pipeline_param is explored for all the parameters associated with a given node, which is
828+
then expanded to multiple nodes with generated node names. The graph is created using the existing
829+
connections, i.e. if there is an edge between node A and node B and with parameterization node B became
830+
node B1, node B2, an edge is created between node A and node B1 as well as node A and node B2.
831+
832+
Depending on the strategy of searches, the appropriate pipeline_param can create the right expansion.
833+
For example, grid search can expand the cross product of parameters and the pipeline will get expanded.
834+
835+
Examples
836+
--------
837+
The below code shows an example of how a 2 step pipeline gets expanded to a 9 node pipeline for grid
838+
search.
839+
840+
.. code-block:: python
841+
842+
pipeline = dm.Pipeline()
843+
node_pca = dm.EstimatorNode('pca', pca)
844+
node_logistic = dm.EstimatorNode('logistic', logistic)
845+
846+
pipeline.add_edge(node_pca, node_logistic)
847+
848+
param_grid = {
849+
'pca__n_components': [5, 15, 30, 45, 64],
850+
'logistic__C': np.logspace(-4, 4, 4),
851+
}
852+
853+
pipeline_param = dm.PipelineParam.from_param_grid(param_grid)
854+
855+
param_grid_pipeline = pipeline.get_parameterized_pipeline(pipeline_param)
856+
857+
:param pipeline_param: The pipeline parameters
858+
:return: A parameterized pipeline
859+
"""
784860
result = Pipeline()
785861
pipeline_params = pipeline_param.get_all_params()
786862
parameterized_nodes = {}

0 commit comments

Comments
 (0)