In [1]:
from src.datasets.load_trees import load_trees_from_file
from pathlib import Path
from Bio.Phylo.BaseTree import Clade
import networkx as nx
from tqdm import tqdm
import numpy as np

In [2]:
TREE_FILE = Path("data/mcmc_runs/yule-10_2.trees")
NUM_TAXA = 10

In [3]:
trees = load_trees_from_file(TREE_FILE)

In [4]:
def is_leaf(name: str) -> bool:
    return type(name) == int

In [5]:
tree_graphs: list[nx.Graph] = []

def construct_tree_graph(graph: nx.Graph, vertex: Clade, is_root: bool, running_internal_node_id: list[int]) -> str | int :
    vertex_name = None
    
    if vertex.is_terminal():
        assert vertex.name
        vertex_name = int(vertex.name)
    elif is_root:
        vertex_name = "root"
    else:
        vertex_name = f"internal_{running_internal_node_id[0]}"

    for child in vertex.clades:
        running_internal_node_id[0] += 1
        child_name = construct_tree_graph(graph, child, False,running_internal_node_id)
        graph.add_edge(vertex_name, child_name)

    return vertex_name

for tree in tqdm(trees):
    graph = nx.Graph()
    running_internal_node_id = [0]
    construct_tree_graph(graph, tree.root, True, running_internal_node_id)
    tree_graphs.append(graph)

100%|██████████| 50001/50001 [00:01<00:00, 44851.39it/s]


In [6]:
pairwise_scores = np.zeros((NUM_TAXA, NUM_TAXA))

for tree_graph in tqdm(tree_graphs):
    for leaf1 in list(tree_graph.nodes):
        if not is_leaf(leaf1):
            continue

        for leaf2 in list(tree_graph.nodes):
            if not is_leaf(leaf2) or leaf1 == leaf2:
                continue
            
            num_internal_nodes = nx.shortest_path_length(tree_graph, leaf1, leaf2) - 1
            score = 2 ** (-num_internal_nodes - 1)
            pairwise_scores[leaf1, leaf2] += score
            pairwise_scores[leaf2, leaf1] += score

pairwise_scores /= len(tree_graphs)

100%|██████████| 50001/50001 [00:17<00:00, 2912.74it/s]


In [7]:
score_graph = nx.Graph()

for i in range(pairwise_scores.shape[0]):
    for j in range(pairwise_scores.shape[1]):
        if i < j:
            score_graph.add_edge(i, j, weight=pairwise_scores[i, j])

In [181]:
# prune score graph

# sorted_edges = sorted(
#     list(score_graph.edges),
#     key=lambda x: score_graph.get_edge_data(*x)["weight"],
# )

# print(f"Num edges before pruning: {len(score_graph.edges)}")

# last_removed_edge = None
# while nx.is_connected(score_graph):
#     last_removed_edge = sorted_edges.pop()
#     score_graph.remove_edge(*last_removed_edge)

# score_graph.add_edge(*last_removed_edge, weight=pairwise_scores[*last_removed_edge])

# print(f"Num edges after pruning: {len(score_graph.edges)}")

In [182]:
def ranking_score(ranking: list[int]) -> float:
    score = 0
    for u, v in zip(ranking, ranking[1:]):
        score += np.log(pairwise_scores[u, v])
    return score

In [183]:
all_hamiltonian_paths = []

for i in tqdm(list(range(NUM_TAXA))):
    for j in range(j, NUM_TAXA):
        all_hamiltonian_paths += [
            path for path in nx.simple_paths.all_simple_paths(score_graph, i, j)
            if len(path) == NUM_TAXA
        ]

100%|██████████| 10/10 [00:03<00:00,  2.64it/s]


In [184]:
max_hamiltonian_path = max(
    all_hamiltonian_paths,
    key=ranking_score
)

In [185]:
max_hamiltonian_path

[1, 3, 2, 5, 8, 0, 6, 7, 4, 9]

In [186]:
def is_subranking_compatible(sub_ranking: list[int], ranking: list[int]) -> bool:
    for i, _ in enumerate(ranking):
        compatible = True
        for j, _ in enumerate(sub_ranking):
            if ranking[(i + j) % len(ranking)] != sub_ranking[j]:
                compatible = False
                break
        if compatible:
            return True
    return False

In [187]:
def is_compatible(ranking, node):
    if node.is_terminal():
        return [[int(node.name)]]
    
    child_1 = node.clades[0]
    child_2 = node.clades[1]

    compatible_subrankings_1 = is_compatible(ranking, child_1)
    compatible_subrankings_2 = is_compatible(ranking, child_2)

    compatible_rankings = []

    for sub_ranking_1 in compatible_subrankings_1:
        for sub_ranking_2 in compatible_subrankings_2:
            if is_subranking_compatible(sub_ranking_1 + sub_ranking_2, ranking):
                compatible_rankings.append(sub_ranking_1 + sub_ranking_2)
            
            if is_subranking_compatible(sub_ranking_2 + sub_ranking_1, ranking):
                compatible_rankings.append(sub_ranking_2 + sub_ranking_1)

            if is_subranking_compatible(sub_ranking_1[::-1] + sub_ranking_2, ranking):
                compatible_rankings.append(sub_ranking_1[::-1] + sub_ranking_2)
            
            if is_subranking_compatible(sub_ranking_2 + sub_ranking_1[::-1], ranking):
                compatible_rankings.append(sub_ranking_2 + sub_ranking_1[::-1])

            if is_subranking_compatible(sub_ranking_1 + sub_ranking_2[::-1], ranking):
                compatible_rankings.append(sub_ranking_1 + sub_ranking_2[::-1])
            
            if is_subranking_compatible(sub_ranking_2[::-1] + sub_ranking_1, ranking):
                compatible_rankings.append(sub_ranking_2[::-1] + sub_ranking_1)

            if is_subranking_compatible(sub_ranking_1[::-1] + sub_ranking_2[::-1], ranking):
                compatible_rankings.append(sub_ranking_1[::-1] + sub_ranking_2[::-1])
            
            if is_subranking_compatible(sub_ranking_2[::-1] + sub_ranking_1[::-1], ranking):
                compatible_rankings.append(sub_ranking_2[::-1] + sub_ranking_1[::-1])

    return compatible_rankings

In [188]:
abs(0 - 9) % 10

9

In [189]:
num_compatible_trees = 0

for i in range(1000):
    if is_compatible(max_hamiltonian_path, trees[i].root):
        num_compatible_trees += 1

num_compatible_trees / 1000

0.757