In [None]:
import os
from glob import glob
from importlib import reload

import dendropy as dp
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from jacksonii_analyses import trees, plotting

# Re-import the project helper modules so edits to them take effect without
# restarting the kernel (reload updates the module object in place and returns it).
trees = reload(trees)
plotting = reload(plotting)

In [None]:
# Sample -> population assignments: a tab-separated file without a header row.
POPULATION_COLUMNS = ["sample", "population"]
pops = pd.read_csv(
    "../data/samples/populations.txt",
    names=POPULATION_COLUMNS,
    header=None,
    sep="\t",
)
pops.head()

In [None]:
# Species-tree check: for each population, report whether its samples form a
# clade in the CASTER tree, together with the annotation on that clade's node.
tree = trees.read_tree("../data/phylo/caster_site_br_rerooted.nwk")
# Treat the tree as rooted (the input file is the "_rerooted" version).
tree.is_rooted = True

sample_col, pop_col = pops["sample"], pops["population"]
for species in pop_col.unique():
    tips = sample_col[pop_col.eq(species)].tolist()
    is_mono, support = (
        trees.check_if_clade_is_monophyletic(tree, tips),
        trees.get_node_annotation(tree, tips),
    )
    print(f"{species}: {is_mono}: {support}")

In [None]:
# Per-locus gene tree files. glob() returns paths in filesystem order, which
# varies between machines, so sort them to keep the locus order (and any
# table or figure derived from it) reproducible.
genetree_paths = sorted(glob("../data/phylo/loci/*.treefile"))
# NOTE(review): the per-locus support loop re-reads each file from
# genetree_paths, so these parsed trees are not reused there.
genetrees = [trees.read_tree(p) for p in genetree_paths]

In [None]:
# Per-locus monophyly table. Each gene tree is midpoint-rooted. Then, for every
# population, record whether its samples form a clade and the annotation on
# that clade's node.
SUPPORT_COLUMNS = ["locus", "population", "is_monophyletic", "support"]

records = []
for tree_path in genetree_paths:
    # basename() is separator-aware; splitting on "/" breaks on Windows paths.
    locus = os.path.basename(tree_path).replace(".treefile", "")
    tree = trees.read_tree(tree_path)
    tree.reroot_at_midpoint()
    for species in pops["population"].unique():
        tips = pops.loc[pops["population"] == species, "sample"].tolist()
        is_mono = trees.check_if_clade_is_monophyletic(tree, tips)
        support = trees.get_node_annotation(tree, tips)
        records.append({
            "locus": locus,
            "population": species,
            "is_monophyletic": is_mono,
            "support": support,
        })

# Build the frame once. Concatenating row by row is quadratic. Passing explicit
# columns keeps the expected schema even when no gene trees were found.
df = pd.DataFrame(records, columns=SUPPORT_COLUMNS)

In [None]:
# Clean the support column:
#   - coerce node annotations to numbers (unparseable values become NaN);
#   - score populations that are not monophyletic as 0, so they count as unsupported.
df["support"] = pd.to_numeric(df["support"], errors="coerce")
# .eq(False) matches the original `== False` test exactly (e.g. None is not False).
# Unlike a row-wise apply, it also works on an empty frame.
df["support"] = df["support"].mask(df["is_monophyletic"].eq(False), 0)
# Human-readable species label per population. Direct indexing keeps the
# KeyError on any population missing from the mapping, instead of silently
# producing NaN.
df["species"] = [plotting.map_population_names[p] for p in df["population"]]
df.head()

In [None]:
# Heatmap of clade support: one row per species, one column per locus.
# Rows and columns are both sorted by mean support (descending), so the
# best-supported species and loci appear top-left.
pivot_df = df.pivot(columns="locus", index="species", values="support")
pop_order = pivot_df.mean(axis=1).sort_values(ascending=False).index
locus_order = pivot_df.mean(axis=0).sort_values(ascending=False).index

fig, ax = plt.subplots(figsize=(12, 4))
sns.heatmap(
    pivot_df.loc[pop_order, locus_order],
    cmap="Greys",
    vmin=0,
    vmax=100,
    annot=False,
    cbar_kws={"label": "bootstrap support (%)"},
    yticklabels=True,
    xticklabels=False,
    ax=ax,
)
ax.set_ylabel("")
# Lay out only after the final labels are set, so no margin is reserved for
# the removed y-axis label.
fig.tight_layout()
fig.savefig("../data/figs/clade_support_heatmap.svg")
plt.show()

In [None]:
# Per species: fraction (and count) of loci whose clade support exceeds the
# threshold. Loci where the clade has NaN support count as not supported.
# Sort on the numeric count rather than the formatted string, so the ordering
# does not depend on how floats happen to print.
SUPPORT_THRESHOLD = 70  # support (%) above which a clade counts as well supported

n_loci = pivot_df.shape[1]
n_supported = (pivot_df > SUPPORT_THRESHOLD).sum(axis=1).sort_values(ascending=False)
n_supported.map(lambda k: f"{round(k / n_loci, 4)}, ({k}/{n_loci})")