In [5]:
import random
from collections import defaultdict
from itertools import chain
from pathlib import Path

import networkx as nx
import numpy as np
from Bio.Phylo.BaseTree import Clade, Tree
from tqdm import tqdm

from src.datasets.load_trees import load_trees_from_file

In [65]:
# Expected number of taxa (leaves) in each sampled tree.
NUM_TAXA = 10
# File of sampled trees to analyse, relative to the notebook's working directory.
TREE_FILE = Path("data") / "mcmc_runs" / "yule-10_2.trees"

In [66]:
# Load every sampled tree from TREE_FILE via the project helper.
# Later cells index the result (trees[0], trees[i].root), so it is presumably
# a list of Bio.Phylo Tree objects — confirm against load_trees_from_file.
trees = load_trees_from_file(TREE_FILE)

In [68]:
# Cherry counts across all sampled trees, keyed by the sorted pair of leaf names.
cherries = defaultdict(int)


def collect_cherries(node: "Clade") -> None:
    """Add one count to ``cherries`` for every cherry below ``node``.

    A cherry is an internal node whose only two children are both leaves.
    The key is the *sorted* pair of leaf names, so the same cherry is always
    counted under one key regardless of child order in the tree file.

    Every internal child is descended into, not just the first two. A
    multifurcating node (e.g. an unrooted tree's trifurcating root) is
    therefore fully explored. A leaf or childless root simply contributes
    nothing instead of raising IndexError.
    """
    # Explicit stack instead of recursion: same visit set, no recursion limit.
    stack = [node]
    while stack:
        current = stack.pop()
        children = current.clades
        if len(children) == 2 and all(child.is_terminal() for child in children):
            cherries[tuple(sorted(child.name for child in children))] += 1
        stack.extend(child for child in children if not child.is_terminal())

# Accumulate cherry counts over every sampled tree, starting at each root.
for sampled_tree in trees:
    collect_cherries(sampled_tree.root)

In [69]:
# Show the cherry counts accumulated by the cell above.
cherries

defaultdict(int,
            {('0', '6'): 50001,
             ('4', '9'): 48397,
             ('1', '3'): 49988,
             ('2', '5'): 50000,
             ('4', '7'): 528,
             ('7', '9'): 878})

In [60]:
# NOTE(review): this cell's execution count (In [60]) predates the cells that
# set TREE_FILE (In [65]) and rebuild `cherries` (In [68]). Its recorded
# output appears to come from an earlier run on a different tree file and is
# stale — re-run or delete this cell.
cherries

defaultdict(int,
            {('0', '7'): 49989,
             ('4', '8'): 12763,
             ('5', '9'): 50001,
             ('1', '6'): 50001,
             ('0', '2'): 4,
             ('2', '7'): 8})

In [199]:
# Split the taxa into disjoint chains that will be concatenated into a ranking.
# For each connected component of the cherry graph `G`, keep its heaviest
# simple path (sum of the "w" edge attribute). Every taxon not on such a path
# becomes a singleton chain, so `parts` covers all NUM_TAXA taxa.
# NOTE(review): `G` and `components` are built in cells not shown here, and
# taxa are assumed to be named "0" .. str(NUM_TAXA - 1) — confirm both.


def _path_weight(path):
    """Return the summed "w" weight of the edges along ``path`` in ``G``."""
    return sum(G.get_edge_data(u, v)["w"] for u, v in zip(path, path[1:]))


parts = []

for component in components:
    longest_path = max(
        (
            path
            for source in component
            for target in component
            for path in nx.all_simple_paths(G, source, target)
        ),
        key=_path_weight,
        default=[],
    )
    if longest_path:
        parts.append(longest_path)
    # Nodes the chosen path skips (e.g. the extra leaf of a star-shaped
    # component) would otherwise vanish from the ranking, because the loop
    # below only adds taxa that are missing from `G` altogether.
    on_path = set(longest_path)
    parts.extend([node] for node in sorted(component) if node not in on_path)

for taxon in map(str, range(NUM_TAXA)):
    if taxon not in G:
        parts.append([taxon])

In [200]:
def score_ranking(ranking: list[str], terminals: "list[Clade]", graph=None) -> float:
    """Score a ranking by the total edge weight between neighbouring taxa.

    For each pair of consecutive taxa in ``ranking``, the ``"w"`` attribute of
    the edge joining them in ``graph`` is added. Pairs with no edge add 0.
    Empty or single-element rankings score 0.

    NOTE(review): the previous body (``score += G.get_path()``) was
    unfinished. It raised UnboundLocalError and never returned a value.
    Summing the ``"w"`` edge weights (the attribute used to pick chains when
    building ``parts``) is an assumed intent — confirm. ``terminals`` is
    currently unused and kept only so existing calls keep working.

    Args:
        ranking: taxon names in ranked order.
        terminals: unused (see note).
        graph: a networkx-style graph with ``get_edge_data``. It defaults to
            the notebook-level cherry graph ``G``.

    Returns:
        The summed weight as a float.
    """
    if graph is None:
        graph = G

    total_score = 0.0
    for a, b in zip(ranking, ranking[1:]):
        edge_data = graph.get_edge_data(a, b, default={})
        total_score += edge_data.get("w", 0)
    return total_score

[['0', '7', '2'], ['4', '8'], ['9', '5'], ['6', '1'], ['3']]

In [201]:
# Randomly order the chains (each chain keeps its internal order) and
# concatenate them into a single ranking of taxa.
# Set RANKING_SEED to an int for a reproducible ranking. None seeds from
# system entropy, matching the previous unseeded behaviour.
RANKING_SEED = None
ranking_rng = random.Random(RANKING_SEED)

# sample() returns a shuffled copy, so `parts` is left untouched and
# re-running this cell always starts from the same chains.
shuffled_parts = ranking_rng.sample(parts, k=len(parts))
ranking = list(chain.from_iterable(shuffled_parts))

In [202]:
# Show the concatenated ranking of taxa.
ranking

['6', '1', '3', '4', '8', '0', '7', '2', '9', '5']

In [203]:
# Pick the first sampled tree for closer inspection.
tree = trees[0]

In [204]:
# Display the first sampled tree's repr.
tree

Tree(name='STATE_0', rooted=False, weight=1.0)

In [205]:
def is_compatible(node_to_rank, node):
    """Return the possible end-leaf pairs of ``node``'s subtree under a ranking.

    ``node_to_rank`` maps each leaf name to its integer position in a linear
    ranking. The subtree's leaves can occupy one contiguous block of that
    ranking only if the returned set is non-empty. Each pair holds the names
    of the two leaves at the ends of such a block.

    A leaf is its own block, so it returns ``{(name, name)}``. Two child
    blocks can be merged when an end of one sits directly next to an end of
    the other (ranks differ by exactly 1). The merged block's ends are then
    the two remaining ends.

    Adjacency is tested with ``== 1``. The earlier ``% 9 == 1`` gave the same
    answer only for 10 taxa; with more taxa it also accepted rank gaps of 10.

    Raises:
        ValueError: if an internal node does not have exactly two children.
        KeyError: if a leaf name is missing from ``node_to_rank``.
    """
    if node.is_terminal():
        return {(node.name, node.name)}

    if len(node.clades) != 2:
        raise ValueError(
            f"is_compatible expects a binary tree; node has {len(node.clades)} children"
        )

    left_child, right_child = node.clades
    left_borders = is_compatible(node_to_rank, left_child)
    right_borders = is_compatible(node_to_rank, right_child)

    compatible_borders = set()
    for left_pair in left_borders:
        for right_pair in right_borders:
            # Try every choice of which end of each child block touches the
            # other block; the untouched ends become the merged block's ends.
            for inner_left, outer_left in (left_pair, left_pair[::-1]):
                for inner_right, outer_right in (right_pair, right_pair[::-1]):
                    if abs(node_to_rank[inner_left] - node_to_rank[inner_right]) == 1:
                        compatible_borders.add((outer_left, outer_right))

    return compatible_borders

In [206]:
# Report sampled trees whose leaves cannot be laid out contiguously in
# `ranking` order, i.e. `is_compatible` returns an empty set at the root.
# Emptiness is tested with `not`: comparing a set with `{}` (an empty dict)
# is always False, so such a check would never report anything.
rank_of_taxon = {taxon: rank for rank, taxon in enumerate(ranking)}

for sampled_tree in trees:
    if not is_compatible(rank_of_taxon, sampled_tree.root):
        print("HURRAY")

In [207]:
# Leaf order of the first sampled tree, in postorder.
trees[0].get_terminals(order="postorder")

[Clade(branch_length=0.0018137667043779978, name='0'),
 Clade(branch_length=0.0018137667043779978, name='7'),
 Clade(branch_length=0.007344551860021729, name='2'),
 Clade(branch_length=0.03919349863276149, name='3'),
 Clade(branch_length=0.024114672292863383, name='4'),
 Clade(branch_length=0.024114672292863383, name='8'),
 Clade(branch_length=0.012416811531248986, name='5'),
 Clade(branch_length=0.012416811531248986, name='9'),
 Clade(branch_length=0.03704886004558185, name='1'),
 Clade(branch_length=0.03704886004558185, name='6')]

In [210]:
# Sanity check: a tree's leaves always form contiguous blocks in the tree's
# own leaf order, so this should return a non-empty set of end-leaf pairs.
# Ranks are taken from the reversed postorder leaf list; for the current
# TREE_FILE this gives "6" -> 0, "1" -> 1, ..., "0" -> 9. Deriving the map
# keeps the check valid if TREE_FILE changes.
first_tree_ranks = {
    leaf.name: rank
    for rank, leaf in enumerate(reversed(trees[0].get_terminals(order="postorder")))
}
is_compatible(first_tree_ranks, trees[0].root)

{('0', '6')}