In [57]:
import seaborn as sns
import matplotlib.pyplot as plt
from src.distribution_analysis.process_tree import get_observed_nodes

import seaborn as sns
from src.distribution_analysis.process_tree import get_observed_nodes, get_clade_split_df
from src.datasets.load_trees import load_trees, TreeDataset
from src.utils.tree_utils import get_taxa_names

import pandas as pd
from tqdm import tqdm

from src.distribution_analysis.clade import ObservedCladeSplit

In [58]:
import warnings
warnings.filterwarnings('ignore')

sns.set_style("whitegrid")

In [59]:
DATA_SET = TreeDataset.YULE_10
DATA_SET_NAME = DATA_SET.value

### Load trees

In [61]:
trees = load_trees(DATA_SET, max_files=2)

2it [00:05,  2.61s/it]


In [62]:
taxa_names = get_taxa_names(trees[0])
observed_nodes, observed_clade_splits = get_observed_nodes(trees, taxa_names)
df_clade_splits = get_clade_split_df(observed_clade_splits)

35001it [00:01, 28541.77it/s]
100%|██████████| 315009/315009 [00:00<00:00, 1406793.43it/s]


In [63]:
df_dict = {
    "clade_split": [],
    "min_branch_0": [],
    "min_branch_1": [],
}

for clade_split in tqdm(observed_clade_splits):
    min_branch_0 = clade_split.min_branch_length

    min_branch_1 = 0.0
    min_branch_2 = 0.0

    if isinstance(clade_split.min_branch_clade, ObservedCladeSplit):
        min_branch_1 = clade_split.min_branch_clade.min_branch_length

        if isinstance(clade_split.min_branch_clade.min_branch_clade, ObservedCladeSplit):
            min_branch_2 = clade_split.min_branch_clade.min_branch_clade.min_branch_length


    df_dict["clade_split"].append(clade_split.bitstring)
    df_dict["min_branch_0"].append(min_branch_0)
    df_dict["min_branch_1"].append(min_branch_1)

df_extended_branches = pd.DataFrame(df_dict)

100%|██████████| 315009/315009 [00:00<00:00, 1154809.98it/s]


In [64]:
dict_branches_per_split = dict(tuple(df_extended_branches.groupby("clade_split")))
dict_branches_per_split = dict(
    item for item in dict_branches_per_split.items() if len(item[1]) > 1
)

In [65]:
from scipy.optimize import minimize
import numpy as np

In [69]:
def mle(x):
    num_splits = len(dict_branches_per_split)

    mle = 0.0
    for i, df in enumerate(dict_branches_per_split.values()):
        mu = x[i]
        sigma = x[num_splits + i]
        beta = x[-1]

        for _, row in df.iterrows():
            b = np.log(row["min_branch_0"])
            bDown = np.log(row["min_branch_1"]) if row["min_branch_1"] else 0.0
            
            mle += 0.5 * (-np.log(sigma) - (b - mu - beta * bDown)**2 / sigma)

    return mle


def mle_gradient(x):
    num_splits = len(dict_branches_per_split)
    grad = np.zeros(len(x))

    for i, df in enumerate(dict_branches_per_split.values()):
        mu = x[i]
        sigma = x[num_splits + i]
        beta = x[-1]

        for _, row in df.iterrows():
            b = np.log(row["min_branch_0"])
            bDown = np.log(row["min_branch_1"]) if row["min_branch_1"] else 0.0

            grad[i] += (b - mu - beta * bDown) / sigma
            grad[num_splits + i] += 0.5 * (
                (-1 / sigma) + (b - mu - beta * bDown)**2 / np.pow(sigma, 2)
            )
            grad[-1] += (b - mu - beta * bDown) * bDown / sigma

    return grad

def get_initial():
    num_splits = len(dict_branches_per_split)
    x0 = np.zeros(2*num_splits + 1)

    for i, df in enumerate(dict_branches_per_split.values()):
        mu = np.mean(df["min_branch_0"])
        sigma = np.std(df["min_branch_0"])
        x0[i] = mu
        x0[num_splits + i] = sigma**2

    return x0

minimize(mle, x0=get_initial(), jac=mle_gradient)

KeyboardInterrupt: 

In [105]:
x0 = get_initial()
print(mle(x0))

x0[-1] = 0.5
print(mle(x0))

x0[-1] = 1.065
print(mle(x0))

x0[-1] = 2
print(mle(x0))

-754527744546.4204
-453943666811.4805
-335939808372.32196
-657295007837.3887


In [86]:
def get_beta(x):
    num_splits = len(dict_branches_per_split)

    nominator = 0
    denominator = 0

    for i, df in enumerate(dict_branches_per_split.values()):
        mu = x[i]
        sigma = x[num_splits + i]

        for _, row in df.iterrows():
            b = np.log(row["min_branch_0"])
            bDown = np.log(row["min_branch_1"]) if row["min_branch_1"] else 0.0

            nominator += (b - mu) * bDown / sigma
            denominator += bDown**2 / sigma

    print(nominator, denominator)

    return nominator / denominator

In [87]:
get_beta(get_initial())

785352084508.3461 736735716153.8264


np.float64(1.0659888848722094)

In [96]:
x = get_initial()
num_splits = len(dict_branches_per_split)
a = 0
d = 0

for i, df in enumerate(dict_branches_per_split.values()):
    mu = x[i]
    sigma = x[num_splits + i]

    for _, row in df.iterrows():
        b = np.log(row["min_branch_0"])
        bDown = np.log(row["min_branch_1"]) if row["min_branch_1"] else 0.0

        a += (b - mu - 1.0659888848722094 * bDown) * bDown / sigma

a

np.float64(-0.009890630841255188)