From 14c83a69c85849476cf498c4f662fca07673f08c Mon Sep 17 00:00:00 2001 From: Robert McArthur Date: Wed, 21 Feb 2024 16:55:15 +1100 Subject: [PATCH] complete docstrings in numpy style --- src/spectral_cluster_supertree/scs/scs.py | 321 ++++++++++++++++------ 1 file changed, 234 insertions(+), 87 deletions(-) diff --git a/src/spectral_cluster_supertree/scs/scs.py b/src/spectral_cluster_supertree/scs/scs.py index 7cc10dd..db94163 100644 --- a/src/spectral_cluster_supertree/scs/scs.py +++ b/src/spectral_cluster_supertree/scs/scs.py @@ -29,22 +29,42 @@ def spectral_cluster_supertree( weights: Sequence[float] | None = None, random_state: np.random.RandomState = np.random.RandomState(), ) -> TreeNode: - """ - Spectral Cluster Supertree (SCS). + """Spectral Cluster Supertree (SCS). Constructs a supertree from a collection of input trees. The supertree - method is inspired by Min-Cut Supertree (Semple & Steel, 2000), using + method is inspired by Min-Cut Supertree [1]_, using spectral clustering instead of min-cut to improve efficiency. The set of input trees must overlap, the optional weights parameter allows the biasing of some trees over others. - Args: - trees (Sequence[TreeNode]): Overlapping subtrees. - weights (Optional[Sequence[float]]): Optional weights for the trees. - - Returns: - TreeNode: The supertree containing all taxa in the input trees. + Parameters + ---------- + trees : Sequence[TreeNode] + The trees to find the supertree of. + pcg_weighting : Literal["one", "branch", "depth"], optional + The weighting strategy to use, by default "one". + normalise_pcg_weights : bool, optional + Whether to normalise the weights globally, by default False. + depth_normalisation : bool, optional + Whether to normalise the weights per tree, by default False. + contract_edges : bool, optional + Whether to contract the edges of the proper cluster graph, by default True. + weights : Sequence[float] | None, optional + The weights of the given trees, by default None.
+ random_state : np.random.RandomState, optional + Random number generation to use, by default np.random.RandomState(). + + Returns + ------- + TreeNode + The generated supertree. + + References + ---------- + .. [1] Semple, C., & Steel, M. (2000). + A supertree method for rooted trees. + Discrete Applied Mathematics, 105(1-3), 147-158. """ assert len(trees) >= 1, "there must be at least one tree" @@ -142,11 +162,30 @@ def spectral_cluster_supertree( def _denamify(tree: TreeNode): + """Remove all non-tip names in the tree. + + Parameters + ---------- + tree : TreeNode + The tree to remove internal node names of. + """ for node in tree.iter_nontips(include_self=True): node.name = None def _component_to_names_set(component: PcgVertexSet) -> set[Taxa]: + """Convert the vertex representation to a set of names of taxa. + + Parameters + ---------- + component : PcgVertexSet + A component of the proper cluster graph. + + Returns + ------- + set[Taxa] + A set of names of taxa in the component. + """ names_set: set[Taxa] = set() for c in component: names_set.update(c) @@ -158,16 +197,24 @@ def spectral_cluster_graph( edge_weights: dict[EdgeTuple, float], random_state: np.random.RandomState, ) -> list[PcgVertexSet]: - """ + """Partition the taxa through spectral clustering. + Given the proper cluster graph, perform Spectral Clustering to find the best partition of the vertices. - Args: - vertices (set): The set of vertices - edge_weights (dict[Frozenset, float]): The weights of the edges - - Returns: - list[set]: A bipartition of the vertices + Parameters + ---------- + vertices : PcgVertexSet + The vertices of the proper cluster graph. + edge_weights : dict[EdgeTuple, float] + The weights of the edges of the proper cluster graph. + random_state : np.random.RandomState + Random number generation for spectral clustering. + + Returns + ------- + list[PcgVertexSet] + The bipartition of taxa of the proper cluster graph.
""" sc = SpectralClustering( 2, @@ -202,25 +249,34 @@ def _contract_proper_cluster_graph( taxa_occurrences: dict[PcgVertex, int], taxa_co_occurrences: dict[EdgeTuple, int], ) -> None: - """ + """Contracts the proper cluster graph. + This method operates in-place. - Given the proper cluster graph, contract every edge of maximal - weight (sum of the weights for the input trees). + Given the proper cluster graph, contract every edge where + two taxa always appear together, i.e. the number of co-occurrences + as a proper cluster is equal to the maximum number of times either + taxon appears in any of the source trees. - The vertices for the contracted edges is a frozenset containing - the old vertices as elements. + The vertex for the contracted edges is a tuple containing + the taxa in the old vertices as elements (sorted). The weights for the parallel classes of edges formed through - contraction are calculated by the sum of the weights of the trees - that support at least one of those edges. - - Args: - vertices (set): The set of vertices - edges (dict): A mapping of vertices to other vertices they connect - edge_weights (dict[Frozenset, float]): The weight for each edge between two vertices - trees (Sequence[TreeNode]): The input trees - weights (Sequence[float]): The weights of the input trees + contraction are calculated by the maximum of the weights of the + trees that support at least one of those edges. + + Parameters + ---------- + vertices : PcgVertexSet + The vertices of the proper cluster graph (modified in-place). + edges : PcgEdgeMap + The edges of the proper cluster graph (modified in-place). + edge_weights : dict[EdgeTuple, float] + The weights of the edges of the proper cluster graph (modified in-place). + taxa_occurrences : dict[PcgVertex, int] + The number of times each taxon appears in any of the input trees. + taxa_co_occurrences : dict[EdgeTuple, int] + The number of times two taxa appear as a proper cluster.
""" # Construct a new graph containing only the edges of maximal weight. # The components of this graph are the vertices following contraction @@ -307,14 +363,17 @@ def _contract_proper_cluster_graph( def _connect_trees(trees: Collection[TreeNode]) -> TreeNode: - """ - Connects the input trees by making them adjacent to a new root. + """Connects the input trees by making them adjacent to a new root. - Args: - trees (Iterable[TreeNode]): The input trees to connect + Parameters + ---------- + trees : Collection[TreeNode] + The input trees to connect. - Returns: - TreeNode: A tree connecting all the input trees + Returns + ------- + TreeNode + A tree connecting all the input trees. """ if len(trees) == 1: (one,) = trees # Unpack only tree @@ -326,8 +385,7 @@ def _connect_trees(trees: Collection[TreeNode]) -> TreeNode: def _generate_induced_trees_with_weights( names: set[Taxa], trees: Sequence[TreeNode], weights: Sequence[float] ) -> tuple[list[TreeNode], list[float]]: - """ - Induces the input trees on the set of names. + """Induces the input trees on the set of names. A tree can be induced on a set by removing all leaves that are not in the set. More concisely, inducing gives a subtree @@ -336,13 +394,20 @@ def _generate_induced_trees_with_weights( The results is a list of trees only expressing the given names and a list containing their corresponding weights. - Args: - names (set): The names to induce the trees on - trees (list[TreeNode]): The original trees to be induced - weights (list[float]): The corresponding weights of the trees - - Returns: - tuple[Sequence[TreeNode], Sequence[float]]: The induced trees and corresponding weights + Parameters + ---------- + names : set[Taxa] + The taxa to induce the trees on. + trees : Sequence[TreeNode] + The trees to induce. + weights : Sequence[float] + The weights of the trees. + + Returns + ------- + tuple[list[TreeNode], list[float]] + The induced trees. + The corresponding weights. 
""" induced_trees: list[TreeNode] = [] new_weights: list[float] = [] @@ -362,17 +427,19 @@ def _generate_induced_trees_with_weights( def _get_graph_components( vertices: PcgVertexSet, edges: PcgEdgeMap ) -> list[PcgVertexSet]: - """ - Given a graph expressed as a set of vertices and a dictionary of - edges (mapping vertices to sets of other vertices), find the - components of the graph. - - Args: - vertices (set): The set of edges. - edges (dict): A mapping of vertices to sets of other vertices. - - Returns: - list[set]: A list of sets of vertices, each element a component. + """Gather the components of a graph. + + Parameters + ---------- + vertices : PcgVertexSet + The vertices of the graph. + edges : PcgEdgeMap + The edges of the graph. + + Returns + ------- + list[PcgVertexSet] + A list of components of the proper cluster graph. """ components: list[PcgVertexSet] = [] @@ -396,14 +463,11 @@ def _proper_cluster_graph_edges( pcg_vertices: PcgVertexSet, trees: Sequence[TreeNode], weights: Sequence[float], - pcg_weighting: str, + pcg_weighting: Literal["one", "branch", "depth"], normalise_pcg_weights: bool, depth_normalisation: bool, ) -> tuple[ - PcgEdgeMap, - dict[EdgeTuple, float], - dict[PcgVertex, int], - dict[EdgeTuple, int], + PcgEdgeMap, dict[EdgeTuple, float], dict[PcgVertex, int], dict[EdgeTuple, int] ]: """Constructs a proper cluster graph for a collection of weighted trees. @@ -413,16 +477,31 @@ def _proper_cluster_graph_edges( The proper cluster graph contains all the leaves of the tree as vertices. An edge connects two vertices if they belong to a proper cluster in any - of the input trees. Each edge is weighted by the sum of the weights of - the trees for which the connected vertices are a proper cluster. 
- - Args: - pcg_vertices (set): The names of all leaves in the input trees - trees (Sequence[TreeNode]): The trees expressing the proper clusters - weights (Sequence[float]): The weight of each tree - - Returns: - tuple[dict, dict[Frozenset, float]]: The edges and weights of the edges + of the input trees. Each edge is weighted by the sum of the weights + according to the given weighting strategy. + + Parameters + ---------- + pcg_vertices : PcgVertexSet + The vertices of the proper cluster graph. + trees : Sequence[TreeNode] + The trees to construct the proper cluster graph from. + weights : Sequence[float] + Associated weights of each of the trees. + pcg_weighting : Literal["one", "branch", "depth"] + The weighting strategy to use. + normalise_pcg_weights : bool + Whether the weights of the proper cluster graph should be normalised globally based on longest branch. + depth_normalisation : bool + Whether the weights of the proper cluster graph should be normalised per tree based on maximum depth. + + Returns + ------- + tuple[PcgEdgeMap, dict[EdgeTuple, float], dict[PcgVertex, int], dict[EdgeTuple, int]] + The edges of the proper cluster graph. + The weights of the edges of the proper cluster graph. + The number of times each taxon appears in any of the input trees. + The number of times two taxa appear as a proper cluster.
""" edges: PcgEdgeMap = {} edge_weights: dict[EdgeTuple, float] = {} @@ -460,7 +539,7 @@ def _proper_cluster_graph_edges( max_length = max(max_length, length) depth_normalisation_factor = max_length for side in tree: - side_taxa, max_sublength = dfs_pcg_weights( + side_taxa, max_sublength = _dfs_pcg_weights( edges, edge_weights, taxa_co_occurrences, @@ -480,7 +559,7 @@ def _proper_cluster_graph_edges( return edges, edge_weights, taxa_occurrences, taxa_co_occurrences -def dfs_pcg_weights( +def _dfs_pcg_weights( edges: PcgEdgeMap, edge_weights: dict[EdgeTuple, float], taxa_co_occurrences: dict[EdgeTuple, int], @@ -490,6 +569,37 @@ def dfs_pcg_weights( length_function: Callable[[float, PhyloNode], float], depth_normalisation_factor: int, ) -> tuple[list[PcgVertex], float]: + """Recursive helper to construct the proper cluster graph from the tree in a DFS fashion. + + As all pairs of taxa that are descendants of an internal node but on opposite sides have + the same weight, performing a DFS minimises the computational cost of constructing + the proper cluster graph. + + Parameters + ---------- + edges : PcgEdgeMap + The current edges of the proper cluster graph (modified in-place). + edge_weights : dict[EdgeTuple, float] + The weights of the edges of the proper cluster graph (modified in-place). + taxa_co_occurrences : dict[EdgeTuple, int] + The number of times two taxa appear as a proper cluster (modified in-place). + tree : PhyloNode + The tree/internal-node to construct the proper cluster graph from. + tree_weight : float + The associated weight of the tree. + length : float + The length from the internal node to the root of the tree. + length_function : Callable[[float, PhyloNode], float] + Function which applies the weighting strategy. + depth_normalisation_factor : int + Normalisation factor. + + Returns + ------- + tuple[list[PcgVertex], float] + All descendants of the current node. + The maximum root-to-tip distance.
+ """ if tree.is_tip(): tip_name: Taxa = tree.name # type: ignore return [(tip_name,)], 0.0 @@ -499,7 +609,7 @@ def dfs_pcg_weights( max_length = length children_tips: list[list[PcgVertex]] = [] for side in tree: - child_tips, normalise_length = dfs_pcg_weights( + child_tips, normalise_length = _dfs_pcg_weights( edges, edge_weights, taxa_co_occurrences, @@ -537,24 +647,57 @@ def dfs_pcg_weights( def edge_tuple(v1: PcgVertex, v2: PcgVertex) -> EdgeTuple: + """Generates an edge representing two taxa. + + Orders the taxa for consistent behaviour. + + Parameters + ---------- + v1 : PcgVertex + The first vertex. + v2 : PcgVertex + The second vertex. + + Returns + ------- + EdgeTuple + The unique edge representing these taxa. + """ if v1 < v2: return (v1, v2) return (v2, v1) def tuple_sorted(iterable: Iterable[Taxa]) -> PcgVertex: + """Generates a new vertex representing an iterable of taxa. + + Sorts the taxa then converts them into a tuple for predictable ordering. + + Parameters + ---------- + iterable : Iterable[Taxa] + An iterable of taxa. + + Returns + ------- + PcgVertex + A new vertex representing the group of taxa. + """ return tuple(sorted(iterable)) def _get_all_tip_names(trees: Iterable[TreeNode]) -> set[Taxa]: - """ - Fetch the tip names for some iterable of input trees. + """Collects all taxa names from an iterable of trees. - Args: - trees (Iterable[TreeNode]): Input trees. + Parameters + ---------- + trees : Iterable[TreeNode] + The trees to collect the taxa of. - Returns: - set: A set containing the tip names of the trees. + Returns + ------- + set[Taxa] + A set containing the tip names of the trees. """ names: set[Taxa] = set() for tree in trees: @@ -563,15 +706,19 @@ def _get_all_tip_names(trees: Iterable[TreeNode]) -> set[Taxa]: def _tip_names_to_tree(tip_names: Iterable[Taxa]) -> TreeNode: - """ - Convert an iterable of tip names to a tree. - The tip names are made adjacent to a new root. + """Generates a rooted tree of the taxa. 
+ + All tip names are made adjacent to a new root node. - Args: - tip_names (Iterable): the names of the tips. + Parameters + ---------- + tip_names : Iterable[Taxa] + The names of the tips. - Returns: - TreeNode: A star tree with a root connecting each of the tip names + Returns + ------- + TreeNode + A star tree with a root connecting each of the tip names. """ tree_builder = TreeBuilder(constructor=TreeNode).create_edge # type: ignore tips = [tree_builder([], tip_name, {}) for tip_name in tip_names]