In [None]:
import ete3

In [None]:
# Tree to draw and the node attribute used to colour its branches.
tree = tree_info["tree"]
color_by = "state"
n_leaves = len(tree)  # ete3 defines len(tree) as the number of leaves
max_t = max(node.t for node in tree.traverse())

# Target drawing size; multiplied by `dpi` below to get lengths in dots.
# NOTE(review): presumably inches, matching units="in" at render time.
w = 6
h = 1.5
dpi = 600

# Scale branch lengths so the span from the root's `t` to the largest `t`
# on the tree covers `h * dpi` dots. A tree with no extent in `t` would
# divide by zero, so fall back to None, which leaves the scale for ete3 to
# choose automatically.
tree_height = max_t - tree.t
my_scale = dpi * h / tree_height if tree_height > 0 else None

# Line width in dots: chosen so that, with the leaves filling the full width,
# each gap between neighbouring leaves is `spacing_to_linewidth` line widths.
# Capped at 0.0125 * dpi so small trees do not get overly thick branches.
spacing_to_linewidth = 6.0
linewidth = min(
    w / (n_leaves + spacing_to_linewidth * (n_leaves - 1)) * dpi, 0.0125 * dpi
)

# Share the width not covered by lines evenly between the n_leaves - 1 gaps.
# A single-leaf tree has no gaps, so avoid the division by zero.
if n_leaves > 1:
    branch_spacing = (w * dpi - n_leaves * linewidth) / (n_leaves - 1)
else:
    branch_spacing = 0.0
tree_width = linewidth * n_leaves + branch_spacing * (n_leaves - 1)

# Diverging colour scale over the `color_by` attribute, centred midway
# between the smallest and the largest value found on the tree.
cmap = mpl.colormaps["coolwarm_r"]

# Walk the tree once and reuse the nodes and their values below.
traversed_nodes = list(tree.traverse())
color_values = [getattr(node, color_by) for node in traversed_nodes]

minrange = min(color_values)
maxrange = max(color_values)
vcenter = (maxrange + minrange) / 2
halfrange = (maxrange - minrange) / 2

# If every node carries the same value the half-range is zero; use 1 instead.
if halfrange > 0:
    norm = mpl.colors.CenteredNorm(vcenter=vcenter, halfrange=halfrange)
else:
    norm = mpl.colors.CenteredNorm(vcenter=vcenter, halfrange=1)

# Hex colour of each node, keyed by node name.
colormap = dict(
    zip(
        (node.name for node in traversed_nodes),
        (mpl.colors.to_hex(cmap(norm(value))) for value in color_values),
    )
)

# Give every node its own NodeStyle: branch lines `linewidth` wide, coloured
# by the node's `color_by` value, and a node marker of size zero.
for node in tree.traverse():
    # Derive the colour from the node itself instead of looking it up by name
    # in `colormap`: nodes that share a name share one dictionary entry there,
    # so all of them would be drawn in the colour of whichever came last.
    branch_color = mpl.colors.to_hex(cmap(norm(getattr(node, color_by))))

    nstyle = ete3.NodeStyle()
    nstyle["hz_line_width"] = linewidth  # Horizontal line width
    nstyle["vt_line_width"] = linewidth  # Vertical line width
    nstyle["size"] = 0.0
    nstyle["hz_line_color"] = branch_color
    nstyle["vt_line_color"] = branch_color
    nstyle["draw_descendants"] = True
    node.set_style(nstyle)

# Drawing options: no leaf names or scale bar, no margins, rotated by 90.
tree_style = ete3.TreeStyle()
tree_style.show_leaf_name = tree_style.show_scale = False
tree_style.rotation = 90
tree_style.allow_face_overlap = True
tree_style.margin_left = tree_style.margin_right = 0
tree_style.margin_top = tree_style.margin_bottom = 0

# Geometry computed above from the target size and the number of leaves.
tree_style.scale = my_scale
tree_style.tree_width = tree_width
tree_style.min_leaf_separation = branch_spacing

# NOTE(review): render is invoked through ete3.Tree rather than as
# tree.render(...) - presumably on purpose (e.g. `tree` may be a subclass
# that overrides render); confirm before simplifying this call.
render_options = {"units": "in", "dpi": dpi, "tree_style": tree_style}
ete3.Tree.render(tree, file_name="fig/multitype_huge.svg", **render_options)

# Make legend
labels = ["Low fitness", "High fitness"]
# NOTE(review): the values 1 and 2 are hard-coded - presumably the two
# states that occur on the tree; confirm against the data.
colors = [mpl.colors.to_hex(cmap(norm(state))) for state in (1, 2)]

# One square marker per entry, filled with the entry's colour; the line
# colour itself is white.
swatch_style = {"marker": "s", "color": "w", "markersize": 10}
handles = [
    mpl.lines.Line2D([0], [0], label=text, markerfacecolor=swatch, **swatch_style)
    for swatch, text in zip(colors, labels)
]

# The figure holds nothing but the legend; it is saved as its own file.
fig = plt.figure(figsize=(3, 3))
fig.legend(handles=handles, labels=labels, loc="center", ncol=2)
fig.savefig("fig/multitype_huge_legend.pdf", bbox_inches="tight")