# Schematic

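Illustrate pruning: simulate a two-state birth–death–mutation tree, render its full history, then prune the unsampled lineages and render the result.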

In [None]:
import numpy as np
import bdms
import ete3
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import pathlib

In [None]:
# Load the shared paper plot settings, apply their matplotlib rcParams and
# preview the seaborn palette whose colors are used for the tree states.
with open("plot_settings.yml") as settings_file:
    config = yaml.safe_load(settings_file)

# RcParams.update assigns each key in turn, exactly like a manual loop.
plt.rcParams.update(config["paper"]["rcParams"])

colors = config["paper"]["legend_info"]["sns_palette"]
sns.palplot(colors)

In [None]:
# Simulation processes.
# Birth: constant rate 1.0.
birth = bdms.poisson.ConstantProcess(1.0)
# Death: a discrete process with values [0.0, 2.0]; presumably indexed by the
# node's state (state 0 never dies, state 1 dies at rate 2.0) — TODO confirm.
death = bdms.poisson.DiscreteProcess([0.0, 2.0])
# Mutation: constant rate 1.0.
mutation = bdms.poisson.ConstantProcess(1.0)
# Mutator over states (0, 1). Presumably row i of the matrix gives transition
# probabilities out of state i, so every mutation flips 0 <-> 1 — verify.
mutator = bdms.mutators.DiscreteMutator((0, 1), np.array([[0, 1], [1, 0]]))

In [None]:
rng = np.random.default_rng(2)  # single seeded generator shared by all random steps below

In [None]:
# Start from a single root node in state 0.
tree = bdms.Tree(state=0)

In [None]:
# Simulation horizon: how long the tree is evolved before survivors are sampled.
time_to_sampling = 4.9

# Grow the tree using the birth, death and mutation processes defined above.
# The shared generator (not an int seed) is passed, presumably so that this
# step and the later sampling step both draw from the single seed=2 stream.
tree.evolve(
    time_to_sampling,
    birth_process=birth,
    death_process=death,
    mutation_process=mutation,
    mutator=mutator,
    seed=rng,
)

In [None]:
# Sample 5 of the surviving leaves. Presumably this sets their `event` to
# "sampling", which the leaf-labelling cell below checks for — confirm.
tree.sample_survivors(n=5, seed=rng)

In [None]:
# Reorder children by subtree size (ete3 ladderize) so the drawing is tidier.
tree.ladderize()

In [None]:
def rotate_labels(node):
    """ete3 layout function: attach a rotated text label to each named leaf.

    Leaves whose ``name`` is not ``None`` get a ``TextFace`` showing the name.
    The face is rotated by -90 degrees, presumably to undo the 90-degree
    ``ts.rotation`` set below. It gets 2.0 margins on both sides and is
    placed in column 0 of the "aligned" face area. Internal nodes and leaves
    named ``None`` are left untouched.

    NOTE(review): ``node.add_face`` attaches a fixed face that stays on the
    node across renders (ete3 docs), rather than a per-render layout face.
    Every render that runs this layout therefore adds another label. Later in
    the notebook ``ts.layout_fn`` is cleared before the second render.
    """
    if not node.is_leaf() or node.name is None:
        return
    label = ete3.TextFace(node.name)
    label.rotation = -90
    label.margin_right = label.margin_left = 2.0
    node.add_face(label, 0, position="aligned")


# Tree style: draw rotated by 90 degrees, hide ete3's own leaf names and
# scale bar (labels come from rotate_labels instead).
ts = ete3.TreeStyle()
ts.branch_vertical_margin = 2
ts.show_leaf_name = False
ts.show_scale = False
ts.rotation = 90
ts.scale = 40
ts.layout_fn = rotate_labels

# States 0 and 1 are drawn with the first two palette colors.
color_map = {0: colors[0], 1: colors[1]}

# Keyword arguments for render(): state colors, 4 inches high, the style above.
viz_kwargs = dict(
    color_map=color_map,
    h=4, units="in",
    tree_style=ts,
)

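Define a customized render function that styles each branch by whether its subtree contains survival or sampling events, then draws the tree with `ete3.Tree.render`.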

In [None]:
import ete3

def render(tree, file_name: str, color_by="state", color_map=None, **kwargs):
    """Apply schematic node styles to ``tree`` in place, then draw it with ete3.

    Two styling modes are used.

    Unless ``tree._pruned`` and ``tree._mutations_removed`` are both truthy,
    each branch is styled by the ``event`` values found in its subtree:

    * a sampling event below -> line type 0 (solid), width 3;
    * a survival event but no sampling event -> line type 0, width 1;
    * neither -> line type 1 (dashed), width 1.

    In this mode, when ``color_map`` is given, a node's vertical line and
    foreground use the color of its own ``color_by`` value. Its horizontal
    line uses the color of its parent's value. Node markers get size 0.

    Otherwise every branch gets width 3 and the foreground color follows
    ``color_by``. Every non-root node without a "branch-bottom" face gets
    ``tree._mutation_face`` attached there.

    Args:
        tree: Tree to draw; node styles are set on it in place. Assumed to be
            a ``bdms.Tree``, since bdms-specific private attributes are read.
        file_name: Output target, forwarded to ``ete3.Tree.render``.
        color_by: Node attribute whose value is looked up in ``color_map``.
        color_map: Mapping from ``color_by`` values to colors, or ``None``
            to leave colors unset.
        **kwargs: Extra keyword arguments for ``ete3.Tree.render``.

    Returns:
        Whatever ``ete3.Tree.render`` returns.
    """
    # Maps each node to the set of `event` values present in its subtree.
    subtree_events = tree.get_cached_content(store_attr="event", leaves_only=False)

    if tree._pruned and tree._mutations_removed:
        for node in tree.traverse():
            style = ete3.NodeStyle()
            style["hz_line_width"] = style["vt_line_width"] = 3
            if color_map is not None:
                assert color_by is not None
                style["fgcolor"] = color_map[getattr(node, color_by)]
            if not node.is_root() and not getattr(node.faces, "branch-bottom"):
                node.add_face(tree._mutation_face, 0, position="branch-bottom")
            node.set_style(style)
        return ete3.Tree.render(tree, file_name, **kwargs)

    for node in tree.traverse():
        events_below = subtree_events[node]
        if tree._SAMPLING_EVENT in events_below:
            line_type, line_width = 0, 3
        elif tree._SURVIVAL_EVENT in events_below:
            line_type, line_width = 0, 1
        else:
            line_type, line_width = 1, 1

        style = ete3.NodeStyle()
        style["hz_line_type"] = style["vt_line_type"] = line_type
        style["hz_line_width"] = style["vt_line_width"] = line_width
        if color_map is not None:
            assert color_by is not None
            own_color = color_map[getattr(node, color_by)]
            style["vt_line_color"] = own_color
            if not node.is_root():
                assert node.up is not None
                # The horizontal branch leads into this node, so it uses the
                # parent's color.
                style["hz_line_color"] = color_map[getattr(node.up, color_by)]
            style["fgcolor"] = own_color
        style["size"] = 0
        node.set_style(style)
    return ete3.Tree.render(tree, file_name, **kwargs)

In [None]:
# Number the sampled leaves 1, 2, 3, ... while walking the leaves in reverse
# iteration order. Every other leaf is named None so that no label is drawn
# for it (rotate_labels only labels leaves whose name is not None).
i = 0
for leaf in list(tree.iter_leaves())[::-1]:
    if leaf.event != "sampling":
        leaf.name = None
        continue
    i += 1
    leaf.name = i


In [None]:
# Output directory for the rendered PDFs; created along with any missing
# parents, and left alone if it already exists.
fig_directory = pathlib.Path("fig/schematic")
fig_directory.mkdir(exist_ok=True, parents=True)

In [None]:
# Render the full (unpruned) history to fig/schematic/full.pdf; the trailing
# ";" suppresses the notebook's echo of the return value.
render(tree, f"{fig_directory}/full.pdf", **viz_kwargs);
# To display in the notebook instead of writing a file, use ete3's
# "%%inline" target:
# render(tree, "%%inline", **viz_kwargs)

In [None]:
# Drop lineages that did not lead to a sampled leaf (bdms); the pruned tree
# is rendered below.
tree.prune_unsampled()

In [None]:
# Turn off the label layout function for the pruned render. rotate_labels
# attached its labels with node.add_face, which ete3 documents as a fixed
# face that stays on the node. Leaves kept by pruning should therefore still
# carry their labels, and running the layout again would presumably add a
# second copy. NOTE(review): confirm the pruned figure shows one label per leaf.
ts.layout_fn = []

In [None]:
# Render options for the pruned tree: same colors and tree style, but no
# fixed size (a width of 10 inches was tried previously).
viz_kwargs = {
    "color_map": color_map,
    "tree_style": ts,
}

In [None]:
# Render the pruned tree to fig/schematic/pruned.pdf (";" suppresses output).
render(tree, f"{fig_directory}/pruned.pdf", **viz_kwargs);
# Inline alternative for interactive use:
# render(tree, "%%inline", **viz_kwargs)