In [None]:
import pathlib
import yaml

import ete3
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt

In [None]:
# Input: Newick file holding the tree to draw.
newick_path = pathlib.Path(
    "data/multitype_high_birth_fitness_huge/FE/trees/tree_2.nw"
)

# Output: directory that receives the rendered figures; created on first run,
# left untouched if it already exists.
fig_directory = pathlib.Path("fig")
fig_directory.mkdir(exist_ok=True, parents=True)

In [None]:
# Load the tree from the Newick file (ete3 parser format 1).
tree = ete3.Tree(str(newick_path), format=1)

# Convert the per-node annotations to numeric types so they can be used in
# arithmetic (`t`) and as a list index (`state`) further down.
# NOTE(review): presumably both arrive as strings from the file — confirm.
for node in tree.traverse():
    node.t, node.state = float(node.t), int(node.state)

In [None]:
# --- Figure geometry -------------------------------------------------------
# All pixel quantities below are inches * dpi.

# Largest node time in the tree; together with the root's `t` it gives the
# time span that has to fit into `h`.
max_t = max(node.t for node in tree.traverse())
color_by = "state"
n_leaves = len(tree)

w = 6  # extent across the leaves, in inches (converted to pixels via dpi)
h = 1.5  # extent along the time axis, in inches
dpi = 600

# Pixels per unit of time. A tree whose nodes all share the root's time has no
# extent to scale, which previously surfaced as a bare ZeroDivisionError.
time_span = max_t - tree.t
if time_span <= 0:
    raise ValueError(
        "Cannot scale the tree: no node has a time later than the root "
        f"(root t={tree.t}, max t={max_t})."
    )
my_scale = dpi * h / time_span

# Line width in whole pixels: chosen so that n_leaves lines separated by gaps
# `spacing_to_linewidth` times as wide fill the width `w`, capped at
# 0.0125 in.
# NOTE(review): int() truncates, so for trees with many leaves this becomes 0.
# The ete3 NodeStyle docs describe a zero width as a one-pixel cosmetic pen —
# confirm that this is the intended look for huge trees.
spacing_to_linewidth = 6.0
linewidth = int(
    min(w / (n_leaves + spacing_to_linewidth * (n_leaves - 1)) * dpi, 0.0125 * dpi)
)

# Gap between neighbouring leaves so that lines plus gaps span w * dpi pixels.
# A single-leaf tree has no gaps; use 0 instead of dividing by zero.
if n_leaves > 1:
    branch_spacing = (w * dpi - n_leaves * linewidth) / (n_leaves - 1)
else:
    branch_spacing = 0.0
tree_width = linewidth * n_leaves + branch_spacing * (n_leaves - 1)

# Branch colours, indexed by `state - 1`.
# NOTE(review): a state of 0 would silently wrap to the last colour through
# negative indexing — confirm that states are 1-based.
colors = ['#4c72b0', '#dd8452']

# Name -> colour lookup, still built in case other cells read it. It is NOT
# used for styling below: node names are not guaranteed to be unique (e.g.
# unnamed internal nodes would all share one key), and a collision would paint
# every such node with the colour of whichever one was traversed last.
colormap = {
    node.name: colors[node.state - 1]
    for node in tree.traverse()
}

for node in tree.traverse():
    # Colour each branch from the node's own state so that nodes sharing a
    # name can never borrow each other's colour.
    branch_color = colors[node.state - 1]

    nstyle = ete3.NodeStyle()
    nstyle["hz_line_width"] = linewidth  # Horizontal line width
    nstyle["vt_line_width"] = linewidth  # Vertical line width
    nstyle["size"] = 0.0  # zero-size node marker: only the branches are drawn
    nstyle["hz_line_color"] = branch_color
    nstyle["vt_line_color"] = branch_color
    nstyle["draw_descendants"] = True
    node.set_style(nstyle)

# --- Tree-wide rendering options -------------------------------------------
tree_style = ete3.TreeStyle()

# No text decorations: leaf labels and the scale bar are both switched off.
tree_style.show_scale = False
tree_style.show_leaf_name = False

# Layout: rotated by 90 degrees, using the scale and spacing computed earlier.
tree_style.rotation = 90
tree_style.scale = my_scale
tree_style.tree_width = tree_width
tree_style.min_leaf_separation = branch_spacing
tree_style.allow_face_overlap = True

# No margin on any side of the drawing.
for side in ("left", "right", "top", "bottom"):
    setattr(tree_style, "margin_" + side, 0)

# Write the tree to a PDF; the trailing semicolon suppresses the cell output.
tree.render(
    file_name=str(fig_directory / "multitype_huge.pdf"),
    tree_style=tree_style,
    units="in",
    dpi=dpi,
);

# Stand-alone legend, saved as its own PDF next to the tree figure.
colors = ['#4c72b0', '#dd8452']
labels = ["Low fitness", "High fitness"]

# Each entry is a single square marker on an otherwise white line, so only the
# coloured square shows up in the legend.
marker_kwargs = dict(marker="s", color="w", markersize=10)
handles = [
    mpl.lines.Line2D([0], [0], label=label, markerfacecolor=color, **marker_kwargs)
    for color, label in zip(colors, labels)
]

# The figure carries nothing but the legend; bbox_inches="tight" is passed so
# the saved file is cropped around it.
fig = plt.figure(figsize=(3, 3))
fig.legend(handles=handles, labels=labels, loc="center", ncol=2)
fig.savefig(fig_directory / "multitype_huge_legend.pdf", bbox_inches="tight");