In [None]:
import warnings

import subprocess
import dendropy

from ete3 import Tree, TreeStyle, NodeStyle, TextFace
from IPython.display import Image, display

warnings.filterwarnings("ignore", category=UserWarning, module="ete3")

In [None]:
# --- Sample metadata -------------------------------------------------------
# NOTE: `tips_to_species_mapping_path` and `pops_path` point at the same file.
tips_to_species_mapping_path = "../data/samples/populations.txt"
pops_path = "../data/samples/populations.txt"
admixed_individuals_table = "../data/var/admixture/admixed_individuals.csv"

# --- Inputs to the tree-building tools -------------------------------------
genetrees_path = "../data/phylo/genetrees.nwk"
snp_fasta = "../data/phylo/snp_concat.fasta"

# --- Trees written by this notebook ----------------------------------------
species_tree_path = "../data/phylo/guide_species_tree.nwk"
species_tree_rerooted_path = "../data/phylo/guide_species_tree_rerooted.nwk"
caster_output = "../data/phylo/caster_site_br.nwk"
caster_output_rooted = "../data/phylo/caster_site_br_rerooted.nwk"

In [None]:
# Run astral4 on the gene trees with the tip -> species mapping and write the
# result to `species_tree_path`. check=True raises if the tool exits non-zero.
astral_args = ["astral4"]
astral_args += ["-i", genetrees_path]
astral_args += ["--mapping", tips_to_species_mapping_path]
astral_args += ["-o", species_tree_path]
subprocess.run(astral_args, check=True)

In [None]:
# Run caster-site_branchlength on the concatenated SNP alignment and write the
# result to `caster_output`. check=True raises if the tool exits non-zero.
caster_args = ["caster-site_branchlength"]
caster_args += ["-i", snp_fasta]
caster_args += ["-o", caster_output]
subprocess.run(caster_args, check=True)

In [None]:
# Reroot the species tree with nw_reroot. The two labels after the tree path
# are presumably the outgroup tips -- confirm against the nw_reroot usage.
reroot_species_args = ["nw_reroot", species_tree_path, "AmanitaspF11", "AmanitaspT31"]
cmd = subprocess.run(reroot_species_args, check=True, capture_output=True)

# Decode the captured stdout once; it is both saved to disk and shown below.
species_tree_rerooted_newick = cmd.stdout.decode("utf-8")
with open(species_tree_rerooted_path, "w") as f:
    f.write(species_tree_rerooted_newick)
species_tree_rerooted_newick

In [None]:
# Space-separated labels passed to nw_reroot after the tree path.
outgroup_string = "SRR30172829 SRR30172830 SRR30172831 SRR30172832 SRR30172788 SRR30172789 SRR30172790 SRR30172792"
outgroup_list = outgroup_string.split(" ")

# Reroot the CASTER tree: nw_reroot <tree> <label> [<label> ...]
reroot_caster_args = ["nw_reroot", caster_output, *outgroup_list]
cmd = subprocess.run(reroot_caster_args, check=True, capture_output=True)

# Decode the captured stdout once; it is both saved to disk and shown below.
caster_rerooted_newick = cmd.stdout.decode("utf-8")
with open(caster_output_rooted, "w") as f:
    f.write(caster_rerooted_newick)
caster_rerooted_newick

In [None]:
# Quick sanity check: load the rerooted species tree with DendroPy and print
# its ASCII plot. `tree` is rebound to an ete3 tree in a later cell.
tree = dendropy.Tree.get_from_path(species_tree_rerooted_path, schema="newick")
tree.print_plot()

In [None]:
import pandas as pd
from ete3 import Tree, TreeNode

def expand_tree_polytomies(base_tree: "Tree", samples_mapping: dict) -> "Tree":
    """Return a copy of ``base_tree`` with species leaves expanded into polytomies.

    Every leaf whose name is a key of ``samples_mapping`` receives one child
    per sample name listed for it, turning that leaf into an internal node
    with the samples as its leaves. All other nodes are left as they are.

    The sample children are attached to the existing species node rather than
    to a freshly created replacement, so the species node keeps its branch
    length, its support value and its position among its siblings. This also
    works when the matching leaf is the root of a single-node tree.

    Args:
        base_tree: Tree whose leaves are named after species/populations.
            It is not modified; all changes are made on a copy.
        samples_mapping: Maps a leaf name to an iterable of sample names.
            A leaf mapped to an empty iterable is left unchanged.

    Returns:
        The expanded copy of ``base_tree``.
    """
    expanded = base_tree.copy()

    # Collect the target leaves before touching the tree: adding children
    # during the traversal would change what the traversal visits.
    species_leaves = [
        node
        for node in expanded.traverse("postorder")
        if node.is_leaf() and node.name in samples_mapping
    ]

    for species_node in species_leaves:
        for sample_name in samples_mapping[species_node.name]:
            species_node.add_child(name=sample_name)

    return expanded

In [None]:
# Rerooted species tree, loaded with ete3 this time.
tree = Tree(species_tree_rerooted_path)

# Sample -> population table (no header row in the file) and the table of
# admixed individuals.
pops_admixed = pd.read_csv(admixed_individuals_table)
pops_all = pd.read_table(pops_path, header=None, names=["sample", "population"])

# Keep only the samples that are not listed as admixed.
is_admixed = pops_all["sample"].isin(pops_admixed["sample"])
pops = pops_all[~is_admixed]

# population -> list of its remaining sample names
samples_by_population = pops.groupby("population")["sample"]
pops_dict = samples_by_population.apply(list).to_dict()
pops

In [None]:
# Replace each population leaf of the species tree by a polytomy of its samples.
expanded_tree = expand_tree_polytomies(
    base_tree=tree,
    samples_mapping=pops_dict,
)

In [None]:
# ete3 Newick format code 9 is passed to both calls (leaf names only, per the
# ete3 format table -- confirm if branch lengths are needed downstream).
expanded_tree_path = "../data/phylo/guide_species_tree_rerooted_expanded.nwk"
newick_format = 9

# Save to disk, then show the same Newick string as the cell output.
expanded_tree.write(outfile=expanded_tree_path, format=newick_format)
expanded_tree.write(format=newick_format)