Skip to content

Commit

Permalink
enhance typing with additional aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
rmcar17 committed Feb 21, 2024
1 parent 1f9105f commit f6126e0
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions src/spectral_cluster_supertree/scs/scs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

Taxa = NewType("Taxa", str)
PcgVertex: TypeAlias = tuple[Taxa, ...]
PcgVertexSet: TypeAlias = set[PcgVertex]
PcgEdgeMap: TypeAlias = dict[PcgVertex, PcgVertexSet]
EdgeTuple: TypeAlias = tuple[PcgVertex, PcgVertex]


Expand Down Expand Up @@ -68,7 +70,7 @@ def spectral_cluster_supertree(
tree = _tip_names_to_tree(all_names)
return tree

pcg_vertices: set[PcgVertex] = set((name,) for name in all_names)
pcg_vertices: PcgVertexSet = set((name,) for name in all_names)

(
pcg_edges,
Expand Down Expand Up @@ -145,18 +147,18 @@ def _denamify(tree: TreeNode):
node.name = None


def _component_to_names_set(component: set[PcgVertex]) -> set[Taxa]:
def _component_to_names_set(component: PcgVertexSet) -> set[Taxa]:
names_set: set[Taxa] = set()
for c in component:
names_set.update(c)
return names_set


def spectral_cluster_graph(
vertices: set[PcgVertex],
vertices: PcgVertexSet,
edge_weights: dict[EdgeTuple, float],
random_state: np.random.RandomState,
) -> list[set[PcgVertex]]:
) -> list[PcgVertexSet]:
"""
Given the proper cluster graph, perform Spectral Clustering
to find the best partition of the vertices.
Expand Down Expand Up @@ -187,16 +189,16 @@ def spectral_cluster_graph(

idxs = sc.fit_predict(edges)

partition: list[set[PcgVertex]] = [set(), set()]
partition: list[PcgVertexSet] = [set(), set()]
for vertex, idx in zip(vertex_list, idxs):
partition[idx].add(vertex)

return partition


def _contract_proper_cluster_graph(
vertices: set[PcgVertex],
edges: dict[PcgVertex, set[PcgVertex]],
vertices: PcgVertexSet,
edges: PcgEdgeMap,
edge_weights: dict[EdgeTuple, float],
taxa_occurrences: dict[PcgVertex, int],
taxa_co_occurrences: dict[EdgeTuple, int],
Expand All @@ -223,8 +225,8 @@ def _contract_proper_cluster_graph(
"""
# Construct a new graph containing only the edges of maximal weight.
# The components of this graph are the vertices following contraction
max_vertices: set[PcgVertex] = set()
max_edges: dict[PcgVertex, set[PcgVertex]] = {}
max_vertices: PcgVertexSet = set()
max_edges: PcgEdgeMap = {}
for pair, count in taxa_co_occurrences.items():
u, v = pair
max_possible_count = max(taxa_occurrences[u], taxa_occurrences[v])
Expand Down Expand Up @@ -359,8 +361,8 @@ def _generate_induced_trees_with_weights(


def _get_graph_components(
vertices: set[PcgVertex], edges: dict[PcgVertex, set[PcgVertex]]
) -> list[set[PcgVertex]]:
vertices: PcgVertexSet, edges: PcgEdgeMap
) -> list[PcgVertexSet]:
"""
Given a graph expressed as a set of vertices and a dictionary of
edges (mapping vertices to sets of other vertices), find the
Expand All @@ -373,7 +375,7 @@ def _get_graph_components(
Returns:
list[set]: A list of sets of vertices, each element a component.
"""
components: list[set[PcgVertex]] = []
components: list[PcgVertexSet] = []

unexplored = vertices.copy()
while unexplored:
Expand All @@ -392,14 +394,14 @@ def _get_graph_components(


def _proper_cluster_graph_edges(
pcg_vertices: set[PcgVertex],
pcg_vertices: PcgVertexSet,
trees: Sequence[TreeNode],
weights: Sequence[float],
pcg_weighting: str,
normalise_pcg_weights: bool,
depth_normalisation: bool,
) -> tuple[
dict[PcgVertex, set[PcgVertex]],
PcgEdgeMap,
dict[EdgeTuple, float],
dict[PcgVertex, int],
dict[EdgeTuple, int],
Expand All @@ -423,7 +425,7 @@ def _proper_cluster_graph_edges(
Returns:
tuple[dict, dict[Frozenset, float]]: The edges and weights of the edges
"""
edges: dict[PcgVertex, set[PcgVertex]] = {}
edges: PcgEdgeMap = {}
edge_weights: dict[EdgeTuple, float] = {}
taxa_occurrences: dict[PcgVertex, int] = {}
taxa_co_occurrences: dict[EdgeTuple, int] = {} # Number of times a taxa appears
Expand Down Expand Up @@ -480,7 +482,7 @@ def _proper_cluster_graph_edges(


def dfs_pcg_weights(
edges: dict[PcgVertex, set[PcgVertex]],
edges: PcgEdgeMap,
edge_weights: dict[EdgeTuple, float],
taxa_co_occurrences: dict[EdgeTuple, int],
tree: PhyloNode,
Expand Down

0 comments on commit f6126e0

Please sign in to comment.