In [1]:
from src.datasets.load_trees import load_trees_from_file
from pathlib import Path
from Bio.Phylo.BaseTree import Clade
import networkx as nx
from tqdm import tqdm
import numpy as np
import logging
import pandas as pd
from random import sample

In [2]:
logging.getLogger().setLevel(logging.INFO)

In [3]:
TREE_FILE = Path("data/mcmc_runs/yule-10_140.trees")
NUM_TAXA = 10

In [65]:
def is_leaf(name: str) -> bool:
    return type(name) == int

In [4]:
def construct_tree_graph(
    graph: nx.Graph, vertex: Clade, is_root: bool, running_internal_node_id: list[int]
) -> str | int:
    vertex_name = None

    if vertex.is_terminal():
        assert vertex.name
        vertex_name = int(vertex.name)
    elif is_root:
        vertex_name = "root"
    else:
        vertex_name = f"internal_{running_internal_node_id[0]}"

    for child in vertex.clades:
        running_internal_node_id[0] += 1
        child_name = construct_tree_graph(graph, child, False, running_internal_node_id)
        graph.add_edge(vertex_name, child_name)

    return vertex_name

In [5]:
def ranking_score(pairwise_scores: np.ndarray, ranking: list[int]) -> float:
    score = 0
    for u, v in zip(ranking, ranking[1:]):
        score += np.log(pairwise_scores[u, v])
    return score

In [6]:
def is_subranking_compatible(sub_ranking: list[int], ranking: list[int]) -> bool:
    for i, _ in enumerate(ranking):
        compatible = True
        for j, _ in enumerate(sub_ranking):
            if i + j >= len(ranking):
                compatible = False
                break

            if ranking[i + j] != sub_ranking[j]:
                compatible = False
                break
        if compatible:
            return True
    return False

In [7]:
import itertools

def is_compatible(ranking, node):
    if node.is_terminal():
        return [[int(node.name)]]
    
    child_1 = node.clades[0]
    child_2 = node.clades[1]

    compatible_subrankings_1 = is_compatible(ranking, child_1)
    compatible_subrankings_2 = is_compatible(ranking, child_2)

    compatible_rankings = []

    for sub_ranking_1 in compatible_subrankings_1:
        for sub_ranking_2 in compatible_subrankings_2:
            if is_subranking_compatible(sub_ranking_1 + sub_ranking_2, ranking):
                compatible_rankings.append(sub_ranking_1 + sub_ranking_2)
            
            if is_subranking_compatible(sub_ranking_2 + sub_ranking_1, ranking):
                compatible_rankings.append(sub_ranking_2 + sub_ranking_1)

            if is_subranking_compatible(sub_ranking_1[::-1] + sub_ranking_2, ranking):
                compatible_rankings.append(sub_ranking_1[::-1] + sub_ranking_2)
            
            if is_subranking_compatible(sub_ranking_2 + sub_ranking_1[::-1], ranking):
                compatible_rankings.append(sub_ranking_2 + sub_ranking_1[::-1])

            if is_subranking_compatible(sub_ranking_1 + sub_ranking_2[::-1], ranking):
                compatible_rankings.append(sub_ranking_1 + sub_ranking_2[::-1])
            
            if is_subranking_compatible(sub_ranking_2[::-1] + sub_ranking_1, ranking):
                compatible_rankings.append(sub_ranking_2[::-1] + sub_ranking_1)

            if is_subranking_compatible(sub_ranking_1[::-1] + sub_ranking_2[::-1], ranking):
                compatible_rankings.append(sub_ranking_1[::-1] + sub_ranking_2[::-1])
            
            if is_subranking_compatible(sub_ranking_2[::-1] + sub_ranking_1[::-1], ranking):
                compatible_rankings.append(sub_ranking_2[::-1] + sub_ranking_1[::-1])

    compatible_rankings.sort()
    compatible_rankings = list(k for k,_ in itertools.groupby(compatible_rankings))
    
    return compatible_rankings

In [91]:
import logging
import itertools


def get_coverage(file: Path) -> float:
    logging.info(f"Load trees from {file}...")
    
    trees = load_trees_from_file(TREE_FILE)
    trees = sample(trees, k=10000)

    logging.info("Build tree graph...")

    tree_graphs: list[nx.Graph] = []

    for tree in tqdm(trees):
        graph = nx.Graph()
        running_internal_node_id = [0]
        construct_tree_graph(graph, tree.root, True, running_internal_node_id)
        tree_graphs.append(graph)

    logging.info("Calculate pairwise scores...")

    pairwise_scores = np.zeros((NUM_TAXA, NUM_TAXA))

    for tree_graph in tqdm(tree_graphs):
        for leaf1 in list(tree_graph.nodes):
            if not is_leaf(leaf1):
                continue

            for leaf2 in list(tree_graph.nodes):
                if not is_leaf(leaf2) or leaf1 == leaf2:
                    continue
                
                num_internal_nodes = nx.shortest_path_length(tree_graph, leaf1, leaf2) - 1
                score = 2 ** (-num_internal_nodes - 1)
                pairwise_scores[leaf1, leaf2] += score
                pairwise_scores[leaf2, leaf1] += score

    pairwise_scores /= len(tree_graphs)

    logging.info("Build score graph...")

    score_graph = nx.Graph()

    for i in range(pairwise_scores.shape[0]):
        for j in range(pairwise_scores.shape[1]):
            if i < j:
                score_graph.add_edge(i, j, weight=pairwise_scores[i, j])

    logging.info("Find hamiltonian paths...")

    all_hamiltonian_paths = []

    for i in tqdm(list(range(NUM_TAXA))):
        for j in range(NUM_TAXA):
            all_hamiltonian_paths += [
                path for path in nx.simple_paths.all_simple_paths(score_graph, i, j)
                if len(path) == NUM_TAXA
            ]

    all_hamiltonian_paths.sort()
    all_hamiltonian_paths = list(k for k,_ in itertools.groupby(all_hamiltonian_paths))

    logging.info("Find max hamiltonian path...")

    max_hamiltonian_path = max(
        all_hamiltonian_paths,
        key=lambda ranking: ranking_score(pairwise_scores, ranking)
    )

    logging.info("Calculate coverage...")

    num_compatible_trees = 0

    for tree in tqdm(trees):
        for t in range(NUM_TAXA):
            if is_compatible(max_hamiltonian_path[t:] + max_hamiltonian_path[:t], tree.root):
                num_compatible_trees += 1
                break


    return num_compatible_trees / len(trees)

In [92]:
entropy_paths = Path("data/entropy_data").glob("*.log")

df_dict = {
    "dataset": [],
    "coverage": [],
    "entropy": [],
}

for i, entropy_path in enumerate(list(entropy_paths)[:10]):
    logging.info(f"Process file nr {i}")

    dataset_name = entropy_path.name.removesuffix(".log")
    
    with open(entropy_path, "r") as handle:
        entropy = float(handle.readline().split()[-1])
    
    coverage = get_coverage(Path("data/entropy_data") / f"{dataset_name}.trees")

    df_dict["dataset"].append(dataset_name)
    df_dict["coverage"].append(coverage)
    df_dict["entropy"].append(entropy)

df = pd.DataFrame(df_dict)

INFO:root:Process file nr 0
INFO:root:Load trees from data/entropy_data/yule-10_182_entropy.trees...
INFO:root:Build tree graph...
100%|██████████| 10000/10000 [00:00<00:00, 64372.07it/s]
INFO:root:Calculate pairwise scores...
100%|██████████| 10000/10000 [00:03<00:00, 2534.77it/s]
INFO:root:Build score graph...
INFO:root:Find hamiltonian paths...
100%|██████████| 10/10 [00:38<00:00,  3.86s/it]
INFO:root:Find max hamiltonian path...
INFO:root:Calculate coverage...
100%|██████████| 10000/10000 [00:03<00:00, 3297.25it/s]
INFO:root:Process file nr 1
INFO:root:Load trees from data/entropy_data/yule-10_36_entropy.trees...
INFO:root:Build tree graph...
100%|██████████| 10000/10000 [00:00<00:00, 64113.97it/s]
INFO:root:Calculate pairwise scores...
100%|██████████| 10000/10000 [00:03<00:00, 2516.14it/s]
INFO:root:Build score graph...
INFO:root:Find hamiltonian paths...
100%|██████████| 10/10 [00:39<00:00,  3.97s/it]
INFO:root:Find max hamiltonian path...
INFO:root:Calculate coverage...
100%|██

In [93]:
df

Unnamed: 0,dataset,coverage,entropy
0,yule-10_182_entropy,0.615,5.177165
1,yule-10_36_entropy,0.6293,1.943433
2,yule-10_224_entropy,0.6324,5.669321
3,yule-10_68_entropy,0.6184,3.669731
4,yule-10_121_entropy,0.6216,3.592645
5,yule-10_95_entropy,0.6173,0.724975
6,yule-10_144_entropy,0.6226,3.306
7,yule-10_241_entropy,0.6228,3.003408
8,yule-10_53_entropy,0.6214,1.800727
9,yule-10_106_entropy,0.6221,3.269502


In [8]:
tree = load_trees_from_file(TREE_FILE)

In [9]:
tree[0].

Tree(name='STATE_0', rooted=False, weight=1.0)

In [11]:
is_compatible([4, 3, 9, 7, 5, 6, 8, 0, 2, 1], tree[0].root)

[[4, 3, 9, 7, 5, 6, 8, 0, 2, 1]]

In [13]:
len(
    [1 for t in tree if is_compatible([4, 3, 9, 7, 5, 6, 8, 0, 2, 1], t.root)]
)

17631

In [14]:
len(tree)

50001