In [1]:
import time
import numpy as np
import os
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Any
import itertools
import json
from typing import Optional
import hashlib

from ete3.coretype.tree import TreeError
from scipy.stats import mannwhitneyu, kstest
import pathlib

import bdms
from bdms import mutators, poisson
import ete3
import modulators
import my_bdms
import utils

import traceback

# The modulator and poisson classes trigger divide-by-zero warnings. Having
# those expressions evaluate to np.inf is the intended result, so the warnings
# are suppressed for the whole notebook.
np.seterr(divide='ignore');

In [2]:
class NumpyArrayEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy arrays and NumPy scalars.

    Arrays are written as (nested) lists via ``ndarray.tolist()``; NumPy
    booleans, integers and floats are converted to the matching built-in
    Python type. Anything else is delegated to ``json.JSONEncoder.default``,
    which raises ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj`` or raise ``TypeError``."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # np.bool_ is not a subclass of np.integer, so it needs its own branch.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)

In [3]:
# Smoke test: evolve a small tree under a toy model whose mutator works on the
# two states 0 and 1.
# NOTE(review): the two rates given to DiscreteProcess are presumably the
# per-state birth rates -- confirm against the bdms.poisson API.
birth_process = poisson.DiscreteProcess([1.0, 2.0])
death_process = poisson.ConstantProcess(1.0)
mutation_process = poisson.ConstantProcess(1.0)

# Off-diagonal transition matrix: every mutation switches to the other state.
swap_states = np.array([[0, 1], [1, 0]])
mutator = mutators.DiscreteMutator((0, 1), swap_states)

rng = np.random.default_rng(seed=0)

# Root node at time 0, starting in state 0.
tree = bdms.TreeNode(t=0.0, dist=0, name='root')
tree.state = 0

evolve_options = dict(
    birth_process=birth_process,
    death_process=death_process,
    mutation_process=mutation_process,
    mutator=mutator,
    seed=rng,
)
start_node = tree.evolve(1, **evolve_options)

## Define simulator

In [4]:
def simulate_tree(
        seed: int,
        state_space: my_bdms.ListOfHashables,
        sampling_probability: float,
        init_dist: np.ndarray,
        birth_process: poisson.Process,
        death_process: poisson.Process,
        mutation_process: poisson.Process,
        mutator: mutators.Mutator,
        t_max: float,
        tree_id: Optional[int] = None,
        t_min: float = 0.0,
        do_prune: bool = True,
        min_survivors: int = 1,
        capacity: int = 10000,
) -> dict[str, Any]:
    """Simulate one tree and return it together with bookkeeping information.

    A root node is created at time ``t_min`` with a state drawn from
    ``state_space`` according to ``init_dist``, then evolved up to ``t_max``
    with the given processes and mutator. When ``do_prune`` is true, survivors
    are sampled with probability ``sampling_probability`` and the tree is
    pruned; if no node ends up with a "sampling" event the tree is discarded
    (``None``).

    Args:
        seed: Seed for the random generator that drives the initial state
            draw, the evolution and the survivor sampling.
        state_space: Candidate states for the root.
        sampling_probability: Passed as ``p`` to ``tree.sample_survivors``.
        init_dist: Probabilities over ``state_space`` for the root state.
        birth_process: Forwarded to ``tree.evolve``.
        death_process: Forwarded to ``tree.evolve``.
        mutation_process: Forwarded to ``tree.evolve``.
        mutator: Forwarded to ``tree.evolve``.
        t_max: End time handed to ``tree.evolve``.
        tree_id: Identifier echoed back in the result; defaults to 1.
        t_min: Time assigned to the root node.
        do_prune: Whether to sample survivors and prune afterwards.
        min_survivors: Forwarded to ``tree.evolve``.
        capacity: Forwarded to ``tree.evolve``.

    Returns:
        A dict with keys ``tree`` (the tree, or ``None`` when it was discarded
        or a ``TreeError`` occurred), ``tree_id``, ``seed``, ``start_time``,
        ``end_time`` (both from ``time.time()``), ``unpruned_tree_size``,
        ``unpruned_tree_leaf_count`` (both 0 on error) and ``is_error``.

    A ``TreeError`` is not propagated; it is reported through ``is_error``.
    """
    if tree_id is None:
        tree_id = 1

    rng = np.random.default_rng(seed=seed)
    start_time = time.time()
    try:
        tree = bdms.TreeNode()
        tree.t = t_min
        tree.state = rng.choice(state_space, p=init_dist)
        tree.evolve(
            t_max,
            birth_process=birth_process,
            death_process=death_process,
            mutation_process=mutation_process,
            mutator=mutator,
            min_survivors=min_survivors,
            capacity=capacity,
            birth_mutations=False,
            seed=rng,
        )
        # Record the size before pruning removes unsampled lineages.
        unpruned_tree_size = utils.tree_size(tree)
        unpruned_tree_leaf_count = len(tree)

        if do_prune:
            # WARNING: does not do phenotype specific sampling.
            # The seeded generator is passed along so that the sampling step is
            # reproducible from `seed`, just like the evolution step.
            tree.sample_survivors(p=sampling_probability, seed=rng)
            if any(node.event == "sampling" for node in tree.traverse()):
                tree.prune()
            else:
                # Nothing was sampled: there is no tree left to report.
                tree = None

        is_error = False

    except TreeError:
        # NOTE(review): presumably raised by evolve/prune when the
        # min_survivors or capacity constraints are violated -- confirm
        # against bdms.
        tree = None
        unpruned_tree_size = 0
        unpruned_tree_leaf_count = 0
        is_error = True

    end_time = time.time()

    return {"tree": tree,
            "tree_id": tree_id,
            "seed": seed,
            "start_time": start_time,
            "end_time": end_time,
            "unpruned_tree_size": unpruned_tree_size,
            "unpruned_tree_leaf_count": unpruned_tree_leaf_count,
            "is_error": is_error}



In [5]:
def compute_metrics(
    tree_info: dict[str, Any],
    tree_metric_info: dict[str, dict[str, Any]],
    clade_metric_info: dict[str, dict[str, Any]],
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Summarize one simulated tree as a tree-level row and clade-level rows.

    Args:
        tree_info: Result dict of ``simulate_tree`` (keys ``tree``,
            ``tree_id``, ``seed``, ``start_time``, ``end_time``,
            ``unpruned_tree_size``, ``unpruned_tree_leaf_count``,
            ``is_error``).
        tree_metric_info: Mapping from a metric label to its description. Each
            description provides ``fnc`` (called as ``fnc(tree, **fnc_arg)``),
            ``fnc_args`` (list of keyword-argument dicts), ``result_keys``
            (keys looked up in the result), ``result_key_name`` (keyword under
            which each key is passed to ``col_name``), ``default_value`` (used
            when a key is missing from the result) and ``col_name`` (callable
            producing the column name).
        clade_metric_info: Mapping from a metric label to its description.
            Each description provides ``fnc``, ``fnc_args`` and
            ``metric_name``; ``fnc`` must return a sequence of dicts with
            ``phenotype`` and ``value`` entries.

    Returns:
        ``(tree_df, clade_df)``. ``tree_df`` has a single row indexed by the
        tree id. ``clade_df`` has one row per clade and metric with columns
        ``tree_id``, ``metric``, ``phenotype`` and ``value``; it is an empty
        DataFrame when there is no tree or no clade metric.
    """
    tree_id = tree_info["tree_id"]
    tree = tree_info["tree"]

    # Bookkeeping columns are filled in whether or not a tree is available.
    tree_df = pd.DataFrame()
    tree_df.loc[tree_id, "tree_id"] = tree_id
    tree_df.loc[tree_id, "seed"] = tree_info["seed"]
    tree_df.loc[tree_id, "unpruned tree size"] = tree_info["unpruned_tree_size"]
    tree_df.loc[tree_id, "unpruned tree leaf count"] = tree_info["unpruned_tree_leaf_count"]
    tree_df.loc[tree_id, "pruned tree size"] = utils.tree_size(tree) if tree is not None else 0
    tree_df.loc[tree_id, "pruned tree leaf count"] = len(tree) if tree is not None else 0
    tree_df.loc[tree_id, "time"] = tree_info["end_time"] - tree_info["start_time"]
    tree_df.loc[tree_id, "is_error"] = int(tree_info["is_error"])

    clade_frames = []
    if tree is not None:

        # Tree-level metrics: one column per (argument set, result key).
        for info in tree_metric_info.values():
            for fnc_arg in info["fnc_args"]:
                data = info["fnc"](tree, **fnc_arg)
                for key in info["result_keys"]:
                    column = info["col_name"](**fnc_arg, **{info["result_key_name"]: key})
                    tree_df.loc[tree_id, column] = data[key] if key in data else info["default_value"]

        # Clade-level metrics: one long-format frame per (metric, argument set).
        for info in clade_metric_info.values():
            for fnc_arg in info["fnc_args"]:
                data = info["fnc"](tree, **fnc_arg)
                clade_frames.append(
                    pd.DataFrame({"tree_id": [tree_id] * len(data),
                                  "metric": [info["metric_name"]] * len(data),
                                  "phenotype": [clade["phenotype"] for clade in data],
                                  "value": [clade["value"] for clade in data]})
                )

    # Concatenate once instead of repeatedly appending to an empty frame.
    clade_df = pd.concat(clade_frames, ignore_index=True) if clade_frames else pd.DataFrame()

    return tree_df, clade_df

## Run Simulations

In [29]:
def _make_spec(*, state_space, sampling_probability, init_dist, birth_rates,
               death_rates, mutation_rates, transition_matrix, t_max, mode,
               capacity, tag, directory, N_trees):
    """Assemble one simulation spec dict.

    The keys are always emitted in the same order, and the entries shared by
    every spec in this notebook (t_min, min_survivors, dt) are filled in here.
    Rate vectors and the transition matrix are converted to fresh NumPy arrays.
    """
    return {
        "state_space": state_space,
        "sampling_probability": sampling_probability,
        "init_dist": np.array(init_dist),
        "birth_rates": np.array(birth_rates),
        "death_rates": np.array(death_rates),
        "mutation_rates": np.array(mutation_rates),
        "transition_matrix": np.array(transition_matrix),
        "t_min": 0.0,
        "t_max": t_max,
        "mode": mode,
        "min_survivors": 1,
        "capacity": capacity,
        "dt": .01,
        "tag": tag,
        "directory": directory,
        "N_trees": N_trees,
    }


# Both simulation modes are run for every sampling probability of a grid.
_MODES = ["full", "FE"]

# Single type, pure birth.
# (Quick debugging alternative: iterate over [(.5, "FE")] instead of the grid.)
single_type_no_death_specs = [
    _make_spec(
        state_space=[1],
        sampling_probability=rho,
        init_dist=[1.0],
        birth_rates=[1.0],
        death_rates=[0.0],
        mutation_rates=[0.0],
        transition_matrix=[[1.0]],
        t_max=5.0,
        mode=mode,
        capacity=10000,
        tag=f"single_type_no_death_rho{int(10*rho)}",
        directory="data/single_type_no_death",
        N_trees=1000,
    )
    for rho, mode in itertools.product(np.linspace(.1, 1.0, 10), _MODES)
]

# Single type with death rate equal to the birth rate.
single_type_w_death_specs = [
    _make_spec(
        state_space=[1],
        sampling_probability=rho,
        init_dist=[1.0],
        birth_rates=[1.0],
        death_rates=[1.0],
        mutation_rates=[0.0],
        transition_matrix=[[1.0]],
        t_max=10,
        mode=mode,
        capacity=10000,
        tag=f"single_type_w_death_rho{int(10*rho)}",
        directory="data/single_type_w_death",
        N_trees=50,
    )
    for rho, mode in itertools.product(np.linspace(.1, 1.0, 10), _MODES)
]

# Single type with death, long time horizon and tiny sampling probability (FE only).
single_type_w_death_huge_specs = [
    _make_spec(
        state_space=[1],
        sampling_probability=1e-9,
        init_dist=[1.0],
        birth_rates=[1.0],
        death_rates=[.5],
        mutation_rates=[0.0],
        transition_matrix=[[1.0]],
        t_max=50,
        mode="FE",
        capacity=10000,
        tag="single_type_w_death_huge",
        directory="data/single_type_w_death_huge",
        N_trees=50,
    )
]

# Two types; type 2 has the higher birth rate.
multitype_high_birth_fitness_specs = [
    _make_spec(
        state_space=[1, 2],
        sampling_probability=rho,
        init_dist=[1.0, 0.0],
        birth_rates=[.25, 1.0],
        death_rates=[.25, .25],
        mutation_rates=[.1, .8],
        transition_matrix=[[0.0, 1.0], [1.0, 0.0]],
        t_max=20.0,
        mode=mode,
        capacity=15000,
        tag=f"multitype_high_birth_fitness_rho{int(10*rho)}",
        directory="data/multitype_high_birth_fitness",
        N_trees=50,
    )
    for rho, mode in itertools.product(np.linspace(.1, 1.0, 5), _MODES)
]

# Two types, long time horizon and tiny sampling probability (FE only).
multitype_high_birth_fitness_huge_specs = [
    _make_spec(
        state_space=[1, 2],
        sampling_probability=1e-9,
        init_dist=[0.0, 1.0],
        birth_rates=[.25, 1.0],
        death_rates=[.5, .25],
        mutation_rates=[.1, .25],
        transition_matrix=[[0.0, 1.0], [1.0, 0.0]],
        t_max=47.0,
        mode="FE",
        capacity=15000,
        tag="multitype_high_birth_fitness_huge",
        directory="data/multitype_high_birth_fitness_huge",
        N_trees=3,
    )
]

# Same two-type model as above, 1000 trees at a single sampling probability.
multitype_high_birth_fitness_1000_specs = [
    _make_spec(
        state_space=[1, 2],
        sampling_probability=rho,
        init_dist=[1.0, 0.0],
        birth_rates=[.25, 1.0],
        death_rates=[.25, .25],
        mutation_rates=[.1, .8],
        transition_matrix=[[0.0, 1.0], [1.0, 0.0]],
        t_max=20.0,
        mode=mode,
        capacity=15000,
        tag=f"multitype_high_birth_fitness_1000_rho{int(10*rho)}",
        directory="data/multitype_high_birth_fitness_1000",
        N_trees=1000,
    )
    for rho, mode in itertools.product([.5], _MODES)
]

# Every spec family, in the order they would run if all were selected.
_all_sim_specs = (
    single_type_no_death_specs
    + single_type_w_death_specs
    + single_type_w_death_huge_specs
    + multitype_high_birth_fitness_specs
    + multitype_high_birth_fitness_1000_specs
    + multitype_high_birth_fitness_huge_specs
)

# Active selection for the run below. Other choices used so far:
#   multitype_high_birth_fitness_1000_specs
#   multitype_high_birth_fitness_huge_specs
#   single_type_no_death_specs
sim_specs = single_type_w_death_specs

In [30]:
def _concat_nonempty(frames):
    """Concatenate the non-empty frames of ``frames`` (empty DataFrame if none)."""
    kept = [frame for frame in frames if not frame.empty]
    return pd.concat(kept, ignore_index=True) if kept else pd.DataFrame()


tree_ids = list(range(10000))   # upper bound on the number of simulation attempts per spec
verbose = True
save_to_file = True

for sim_spec in sim_specs:

    # Per-tree results are collected in lists and concatenated once after the
    # simulation loop; growing a DataFrame with pd.concat on every iteration
    # copies all previous rows each time.
    tree_frames = []
    clade_frames = []

    # Set up simulators
    if sim_spec["mode"] == "full":

        # Simulate the full process, then sample survivors and prune.
        do_prune = True
        birth_process = my_bdms.DiscreteProcess(sim_spec["state_space"],sim_spec["birth_rates"])
        death_process = my_bdms.DiscreteProcess(sim_spec["state_space"],sim_spec["death_rates"])
        mutation_process = my_bdms.DiscreteProcess(sim_spec["state_space"],sim_spec["mutation_rates"])
        mutator = mutators.DiscreteMutator(state_space=sim_spec["state_space"],transition_matrix=sim_spec["transition_matrix"])

    elif sim_spec["mode"] == "FE":

        # Rates come from the FE modulator; no pruning step and a zero death
        # process are used in this mode.
        do_prune = False
        modulator = modulators.FEModulator(
            state_space = sim_spec["state_space"],
            birth_rates = sim_spec["birth_rates"],
            death_rates = sim_spec["death_rates"],
            mutation_rates = sim_spec["mutation_rates"],
            transition_matrix = sim_spec["transition_matrix"],
            rhos = np.full(len(sim_spec["state_space"]),sim_spec["sampling_probability"]),
            t_min = sim_spec["t_min"],
            t_max = sim_spec["t_max"],
            dt = sim_spec["dt"],
        )
        birth_process = my_bdms.CustomProcess(modulator.λ,modulator.Λ,modulator.Λ_inv)
        death_process = poisson.ConstantProcess(0.0)
        mutation_process = my_bdms.CustomProcess(modulator.m,modulator.M,modulator.M_inv) # Divide by zero issue?
        mutator = my_bdms.CustomMutator(modulator)

    else:
        # Without this guard an unknown mode would silently reuse the
        # processes of the previous spec (or fail later with a NameError).
        raise ValueError(f'Unknown simulation mode: {sim_spec["mode"]!r}')

    # Set up metrics we will compute
    tree_metric_info = {
        "subtree_counts": {
            "fnc": utils.num_subtrees_by_size,
            "fnc_args": [{}],
            "result_keys": list(range(1,11)),
            "result_key_name": "subtree_size",
            "default_value": 0,
            "col_name": lambda subtree_size: f"subtree_count_w_size={subtree_size}"
        },
        "total_branch_length": {
            "fnc": utils.total_branch_length,
            "fnc_args": [{}],
            "result_keys": ["total"],
            "result_key_name": "total",
            "default_value": 0,
            "col_name": lambda total: "total_branch_length"
        },
        "num_lineages_by_phenotype_and_time": {
            "fnc": utils.num_lineages_by_phenotype_and_time,
            # Four interior time points, excluding t_min and t_max themselves.
            "fnc_args": [{"t": t} for t in np.linspace(sim_spec["t_min"],sim_spec["t_max"],num=6)[1:-1]],
            "result_keys": sim_spec["state_space"],
            "result_key_name": "phenotype",
            "default_value": 0,
            "col_name": lambda phenotype,t: f"num_lineages_phenotype={phenotype}_t={t}"
        },
        "num_leaves_by_phenotype": {
            "fnc": utils.num_leaves_by_phenotype,
            "fnc_args": [{}],
            "result_keys": sim_spec["state_space"],
            "result_key_name": "phenotype",
            "default_value": 0,
            "col_name": lambda phenotype: f"num_leaves_phenotype={phenotype}"
        },
        "branch_length_by_phenotype": {
            "fnc": utils.branch_length_by_phenotype,
            "fnc_args": [{}],
            "result_keys": list(sim_spec["state_space"]),
            "result_key_name": "phenotype",
            "default_value": 0,
            "col_name": lambda phenotype: f"branch_length_phenotype={phenotype}"
        }
    }

    clade_metric_info = {
        "num_nodes": {
            "fnc": utils.clade_sizes_by_phenotype,
            "fnc_args": [{}],
            "phenotypes": sim_spec["state_space"],
            "metric_name": "number of nodes"
        },
        "branch_length": {
            "fnc": utils.clade_lengths_by_phenotype,
            "fnc_args": [{}],
            "phenotypes": sim_spec["state_space"],
            "metric_name": "branch length"
        }
    }

    # Run simulations until N_trees non-empty trees have been collected.
    nonempty_tree_count = 0
    for tree_id in tree_ids:

        # Derive the seed from the mode as well as the tree id so that the two
        # modes never share seeds, which could violate independence
        # assumptions in distributional comparisons.
        seed = int(hashlib.sha256(f'{sim_spec["mode"]}_{tree_id}'.encode()).hexdigest(),16)

        # Simulate tree
        tree_info = simulate_tree(
            seed = seed,
            state_space = sim_spec["state_space"],
            sampling_probability = sim_spec["sampling_probability"],
            init_dist = sim_spec["init_dist"],
            birth_process = birth_process,
            death_process = death_process,
            mutation_process = mutation_process,
            mutator = mutator,
            t_max = sim_spec["t_max"],
            t_min = sim_spec["t_min"],
            tree_id = tree_id,
            do_prune = do_prune,
            min_survivors = sim_spec["min_survivors"],
            capacity = sim_spec["capacity"],
        )

        tree_df, clade_df = compute_metrics(
            tree_info = tree_info,
            tree_metric_info = tree_metric_info,
            clade_metric_info = clade_metric_info,
        )
        tree_frames.append(tree_df)
        clade_frames.append(clade_df)

        if tree_info["tree"] is not None:
            nonempty_tree_count += 1

        if verbose:
            print(f'{sim_spec["tag"]}, tree_id {tree_id + 1}, nonempty tree count: {nonempty_tree_count}/{sim_spec["N_trees"]}, tip count: {tree_df.loc[tree_id,"pruned tree leaf count"]}, tree size: {tree_df.loc[tree_id,"pruned tree size"]}',end='\r')

        if nonempty_tree_count >= sim_spec["N_trees"]:
            break
    else:
        # The loop ran out of tree ids before reaching the requested count.
        print(f'\nWarning: {sim_spec["tag"]} ({sim_spec["mode"]}): only {nonempty_tree_count}/{sim_spec["N_trees"]} nonempty trees after {len(tree_ids)} attempts')

    all_trees_df = _concat_nonempty(tree_frames)
    all_clades_df = _concat_nonempty(clade_frames)

    # Save to file
    if save_to_file:

        directory_path = pathlib.Path(sim_spec["directory"])
        mode_dir = directory_path/f'{sim_spec["mode"]}'

        # Use the first version number for which neither output file exists,
        # so earlier runs are never overwritten.
        counter = 1
        while True:
            tree_path = mode_dir/f'tree_metrics_{sim_spec["tag"]}_v{counter}_df.csv'
            clade_path = mode_dir/f'clade_metrics_{sim_spec["tag"]}_v{counter}_df.csv'
            if not (tree_path.exists() or clade_path.exists()):
                break
            counter += 1

        mode_dir.mkdir(parents=True, exist_ok=True)

        # The specs file in `directory_path` has no mode in its name, so runs
        # of different modes with the same tag and version overwrite each
        # other there. It is still written for backward compatibility; a
        # second copy is stored next to the CSV files of this mode.
        spec_filename = f'{sim_spec["tag"]}_v{counter}_specs.json'
        for spec_path in (directory_path/spec_filename, mode_dir/spec_filename):
            with open(spec_path, 'w') as outfile:
                json.dump(sim_spec, outfile, cls=NumpyArrayEncoder)

        all_trees_df.to_csv(tree_path,index=False)
        all_clades_df.to_csv(clade_path,index=False)

single_type_w_death_rho10, tree_id 50, nonempty tree count: 50/50, tip count: 31.0, tree size: 62.000

## Draw tree

In [26]:
# Show the tree from the most recent simulate_tree call, i.e. whatever
# `tree_info` was left holding when the simulation loop finished.
tree_info["tree"]

Tree node '0' (0x17d4975d)

In [28]:
# Render the most recently simulated tree to a PDF, coloured by node state,
# and write a matching two-entry legend as a separate PDF.
tree = tree_info["tree"]
if tree is None:
    raise ValueError("tree_info['tree'] is None: the last simulation produced no tree to draw")

max_t = max(node.t for node in tree.traverse())
color_by = "state"
n_leaves = len(tree)

# Target figure size (inches) and resolution.
w = 6
h = 1.5
dpi = 600

# Scale so that the time span of the tree fills the figure height. A tree
# whose nodes all share the root time has no span; leave the scale unset then.
time_span = max_t - tree.t
my_scale = dpi * h / time_span if time_span > 0 else None

# Choose a line width so that n_leaves lines, separated by gaps of
# spacing_to_linewidth line widths, fit into the figure width; capped at
# .0125 inch.
spacing_to_linewidth = 6.0
linewidth = int(min( w / (n_leaves + spacing_to_linewidth * (n_leaves - 1)) * dpi , .0125 * dpi ))
# A single leaf has no gaps, which would otherwise divide by zero.
branch_spacing = ( w * dpi - n_leaves * linewidth ) / (n_leaves - 1) if n_leaves > 1 else 0
tree_width = linewidth * n_leaves + branch_spacing * (n_leaves - 1)

# Map node states to colours, centred on the middle of the observed range.
cmap = mpl.colormaps["coolwarm_r"]
minrange = min(getattr(node, color_by) for node in tree.traverse())
maxrange = max(getattr(node, color_by) for node in tree.traverse())
halfrange = (maxrange - minrange)/2
vcenter = (maxrange + minrange)/2
norm = mpl.colors.CenteredNorm(
    vcenter=vcenter,
    # A zero half range (single state present) is replaced by 1.
    halfrange=halfrange if halfrange > 0 else 1,
)
colormap = {
            node.name: mpl.colors.to_hex(cmap(norm(getattr(node, color_by))))
            for node in tree.traverse()
        }

for node in tree.traverse():
    nstyle = ete3.NodeStyle()
    nstyle["hz_line_width"] = linewidth  # Horizontal line width
    nstyle["vt_line_width"] = linewidth  # Vertical line width
    nstyle["size"] = 0.0
    nstyle["hz_line_color"] = colormap[node.name]
    nstyle["vt_line_color"] = colormap[node.name]
    nstyle["draw_descendants"] = True
    node.set_style(nstyle)

tree_style = ete3.TreeStyle()
tree_style.show_leaf_name = False
tree_style.show_scale = False
tree_style.min_leaf_separation = branch_spacing
tree_style.rotation = 90
tree_style.tree_width = tree_width
tree_style.margin_left = 0
tree_style.margin_right = 0
tree_style.margin_top = 0
tree_style.margin_bottom = 0
tree_style.scale = my_scale
tree_style.allow_face_overlap = True

# Make sure the output directory exists before anything is written to it.
pathlib.Path("fig").mkdir(parents=True, exist_ok=True)

# NOTE(review): ete3.Tree.render is called unbound, presumably to bypass a
# render override on the bdms node class -- keep this form unless confirmed.
ete3.Tree.render(tree,file_name = "fig/multitype_huge.pdf", units = "in", dpi = dpi, tree_style = tree_style);

# Make legend
# NOTE(review): the legend is hard-coded for the two states 1 and 2.
colors = [mpl.colors.to_hex(cmap(norm(x))) for x in [1,2]]
labels = ['Low fitness', 'High fitness']

handles = [mpl.lines.Line2D([0], [0], marker='s', color='w', label=label,
                  markerfacecolor=color, markersize=10) for color, label in zip(colors, labels)]

fig = plt.figure(figsize=(3, 3))
fig.legend(handles=handles, labels=labels, loc='center', ncol = 2)
fig.savefig("fig/multitype_huge_legend.pdf", bbox_inches='tight')


<Figure size 300x300 with 0 Axes>

In [17]:
# Inspect the branch line width computed in the drawing cell.
# NOTE(review): the stored output below (7.5) is not an integer, although the
# drawing cell now truncates the width with int(); the output looks stale.
linewidth

7.5