In [12]:
import numpy as np
from copy import deepcopy
import itertools
import matplotlib
from numpy import radians as rad
from matplotlib.animation import FuncAnimation
from scipy.ndimage import convolve,convolve1d
import cmasher
import seaborn as sns
from collections import defaultdict
from manim import *
import networkx as nx
from scipy.interpolate import interp1d
import itertools
import networkx as nx
import matplotlib.pyplot as plt
import random
from networkx.drawing.nx_pydot import graphviz_layout
from scipy.special import softmax

# Hierarchical loss

In [128]:
def create_exact_graph():
    """Build the small, fixed taxonomy tree used by the animation.

    The tree is described as a parent -> children mapping and expanded into
    directed parent->child edges. Node 0 is the root; nodes 9, 7, 8, 3 and 4
    have no children.

    Returns
    -------
    networkx.DiGraph
        Directed tree with edges pointing from parent to child.
    """
    # Dict order (and the list order inside it) fixes the node insertion
    # order, which later cells rely on when they enumerate ``G.nodes``.
    children_of = {
        0: [2, 1],
        2: [6, 5],
        1: [666],
        6: [9],
        5: [7, 8],
        666: [3, 4],
    }
    edge_list = [(parent, child) for parent, kids in children_of.items() for child in kids]

    tree = nx.DiGraph()
    tree.add_edges_from(edge_list)
    return tree

In [129]:
# Module-level taxonomy tree. Later cells (`assign_parents`, the `TaxonomicTree`
# scene) look it up as the global `G`, so re-run this cell before them after changes.
G = create_exact_graph()

In [130]:
def create_random_tree_with_fixed_depth(number_of_children, branch_length=7, rng=None, root_children=2):
    """Build a random rooted tree grown level by level for ``branch_length`` levels.

    Every node gets a ``value`` attribute drawn uniformly from [0, 1] and
    rounded to two decimals. Node ids are consecutive integers assigned in
    level-by-level creation order, starting with the root ``0``. Because every
    non-root internal node gets at least one child, all leaves sit at depth
    ``branch_length`` whenever ``root_children >= 1``.

    Parameters
    ----------
    number_of_children : int
        Inclusive upper bound on the child count of every non-root node; each
        such node draws its count uniformly from ``[1, number_of_children]``.
    branch_length : int, default 7
        Number of levels grown below the root (0 gives a single-node tree).
    rng : random.Random or None, default None
        Source of randomness (anything with ``uniform`` and ``randint``).
        ``None`` uses the global ``random`` module, as before; pass
        ``random.Random(seed)`` for a reproducible tree.
    root_children : int, default 2
        Fixed number of children of the root (previously hard-coded to 2).

    Returns
    -------
    networkx.DiGraph
        Tree with edges pointing from parent to child.
    """
    if rng is None:
        rng = random

    def random_value():
        # Same draw as before: uniform in [0, 1], rounded to 2 decimals.
        return round(rng.uniform(0, 1), 2)

    G = nx.DiGraph()

    root = 0
    G.add_node(root, value=random_value())

    next_id = 1
    frontier = [root]

    for level in range(branch_length):
        next_frontier = []
        for parent in frontier:
            # The root's fan-out is fixed; deeper nodes draw theirs at random.
            # The RNG call order (randint, then one uniform per child) is
            # unchanged, so a seeded rng reproduces earlier trees.
            if level == 0:
                num_children = root_children
            else:
                num_children = rng.randint(1, number_of_children)

            for _ in range(num_children):
                G.add_node(next_id, value=random_value())
                G.add_edge(parent, next_id)
                next_frontier.append(next_id)
                next_id += 1

        frontier = next_frontier

    return G

def plot_tree(G):
    """Draw a tree with Graphviz's ``dot`` layout and show the figure.

    Nodes are drawn as small unlabelled dots on a 10x8-inch figure whose
    title reports the node count.

    Parameters
    ----------
    G : networkx.Graph
        Tree to draw.
    """
    node_positions = graphviz_layout(G, prog="dot")

    draw_options = dict(
        with_labels=False,
        node_size=10,
        node_color='lightgreen',
        edge_color='gray',
    )
    figure_title = f'Tree with {G.number_of_nodes()} Nodes'

    plt.figure(figsize=(10, 8))
    nx.draw(G, node_positions, **draw_options)
    plt.title(figure_title)
    plt.show()

In [309]:
def assign_parents(likelihoods, children, graph=None):
    """Sum each node's likelihood into its parent's total.

    Parameters
    ----------
    likelihoods : iterable of float
        One value per entry of ``children``, paired positionally. As with
        ``zip``, extra items in the longer iterable are ignored.
    children : iterable
        Nodes of ``graph``; each must have at least one predecessor. Only the
        first predecessor is used.
    graph : object with ``predecessors(node)`` (e.g. networkx.DiGraph), optional
        Tree used to look up parents. ``None`` (the default) falls back to the
        notebook-global ``G``, which was the only behaviour before.

    Returns
    -------
    collections.defaultdict(int)
        Maps each parent to the sum of its children's likelihoods. Keys appear
        in order of first occurrence; callers unpack ``.items()`` positionally.

    Raises
    ------
    IndexError
        If a node in ``children`` has no predecessor (e.g. the root).
    """
    if graph is None:
        graph = G  # notebook-global tree; kept as the default for existing callers
    totals = defaultdict(int)
    for child, likelihood in zip(children, likelihoods):
        parents = list(graph.predecessors(child))
        if not parents:
            raise IndexError(f"node {child!r} has no parent")
        totals[parents[0]] += likelihood
    return totals


class TaxonomicTree(Scene):
    """Manim scene illustrating a hierarchical (taxonomy-aware) loss.

    Stages, all on the notebook-global tree ``G``:
      1. Five example species names on the leaves, with Family/Genus/Species
         annotations shown for a few of them.
      2. Random logits on the leaves turned into softmax likelihoods.
      3. Likelihoods summed upward level by level ("= 1" shown per level).
      4. The negative log-likelihood for one ground-truth leaf at species,
         genus and family level.
    """

    LAYOUT_CONFIG = {"vertex_spacing": (1, 1)}
    VERTEX_CONF = {"radius": 0.25, "color": BLACK, "fill_opacity": 1, "stroke_opacity": 1, "stroke_color": WHITE, "stroke_width": 2}
    # Where the loss equations are drawn.
    EQ_POSITION = DOWN * 2.8

    # ------------------------------------------------------------------ helpers

    def _build_graph(self):
        """Return a manim ``Graph`` of ``G`` in tree layout, rooted at node 0."""
        return Graph(
            list(G.nodes),
            list(G.edges),
            layout="tree",
            root_vertex=0,
            layout_config=self.LAYOUT_CONFIG,
            vertex_config=self.VERTEX_CONF,
        )

    @staticmethod
    def _path_to_root(node):
        """Return ``[node, parent, grandparent, ..., root]`` following predecessors in ``G``.

        Only the first predecessor is followed. A node already on the path
        stops the walk, so a malformed (cyclic) graph cannot loop forever.
        """
        path = [node]
        while True:
            preds = list(G.predecessors(path[-1]))
            if not preds or preds[0] in path:
                return path
            path.append(preds[0])

    @staticmethod
    def _rank_labels(names):
        """Family / genus / species names stacked on the right, one rank per row."""
        return [Text(name, font_size=30).move_to(RIGHT*2 + UP*(3 - row)) for row, name in enumerate(names)]

    @staticmethod
    def _value_labels(graph, nodes, values):
        """Small two-decimal labels centred on each vertex of ``nodes``."""
        return [Text(f"{v:.2f}", font_size=15).move_to(graph[n].get_center()) for v, n in zip(values, nodes)]

    @staticmethod
    def _fill_anims(graph, nodes, values, cmap):
        """Animations colouring each vertex of ``nodes`` by ``cmap(value)``."""
        return [graph[n].animate.set_fill(rgba_to_color(cmap(v))) for v, n in zip(values, nodes)]

    @staticmethod
    def _plus_between(graph, nodes, i):
        """A '+' sign halfway between vertices ``nodes[i-1]`` and ``nodes[i]``."""
        middle = (graph[nodes[i]].get_center() + graph[nodes[i - 1]].get_center()) / 2
        return MathTex(r"+", font_size=30).move_to(middle)

    def _play_level_loss(self, highlighted, eq_with_labels, eq_with_coefs, keep, reduced_eq, pause_after_highlight=0):
        """Highlight the ground truth at one taxonomic level and reduce its NLL.

        Plays: yellow rectangles around ``highlighted``, then the loss with
        symbolic targets, then a ``TransformMatchingTex`` to the loss with
        numeric coefficients. Every piece whose index is not in ``keep`` is
        shrunk away, and a ``TransformMatchingShapes`` morphs the rest into
        ``reduced_eq``.

        Parameters
        ----------
        highlighted : list of Mobject
            Vertices to surround with yellow rectangles.
        eq_with_labels, eq_with_coefs : list of str
            ``MathTex`` pieces of the loss with symbolic / numeric targets.
        keep : collection of int
            Indices of ``eq_with_coefs`` pieces that survive the shrink step.
        reduced_eq : list of str
            ``MathTex`` pieces of the simplified loss.
        pause_after_highlight : float, default 0
            Seconds to wait after the rectangles appear (0 = no wait).

        Returns
        -------
        tuple
            ``(rectangles, reduced_equation)``, both still on screen.
        """
        rectangles = [SurroundingRectangle(m, color=YELLOW) for m in highlighted]
        self.play(FadeIn(*rectangles))
        if pause_after_highlight:
            self.wait(pause_after_highlight)

        nll = MathTex(*eq_with_labels, font_size=40).move_to(self.EQ_POSITION)
        self.play(FadeIn(nll))
        self.wait(2)

        nll_coef = MathTex(*eq_with_coefs, font_size=40).move_to(self.EQ_POSITION)
        self.play(TransformMatchingTex(nll, nll_coef))
        self.wait(2)

        self.play(*[ShrinkToCenter(nll_coef[i]) for i in range(len(eq_with_coefs)) if i not in keep])

        nll2 = MathTex(*reduced_eq, font_size=40).move_to(self.EQ_POSITION)
        self.play(TransformMatchingShapes(nll_coef, nll2))
        self.wait(1)
        return rectangles, nll2

    # -------------------------------------------------------------- entry point

    def construct(self):
        """Build and play the whole animation (manim's scene entry point)."""
        graph = self._build_graph()

        # Leaves are nodes without children. (``degree == 1`` would also match
        # a root with a single child in a directed graph.)
        leaves = [node for node in G.nodes() if G.out_degree(node) == 0]
        N_leaves = len(leaves)

        # Worked ground-truth example. The hard-coded loss equations below
        # (p_3, p_3 + p_4, p_3 + p_4 + p_5) are written for this leaf.
        GT = leaves[2]
        ground_truth = self._path_to_root(GT)

        leaves_vertices = VGroup(*[graph[l].copy() for l in leaves])

        species = ['S. elodea', 'S. melonis', 'E. coli', 'P. aeruginosa', 'M. albus']

        rank_titles = [
            Text("Family", font_size=30, weight=BOLD).move_to(LEFT*2 + UP*3),
            Text("Genus", font_size=30, weight=BOLD).move_to(LEFT*2 + UP*2),
            Text("Species", font_size=30, weight=BOLD).move_to(LEFT*2 + UP*1),
        ]

        ##### Intro: species names and their taxonomy #####
        leaves_vertices.arrange(buff=1.0).shift(DOWN)
        intro_anims = [FadeIn(leaves_vertices)]
        intro_anims += [Write(Text(t, font_size=20).next_to(leaves_vertices[i], DOWN)) for i, t in enumerate(species)]
        intro_anims += [FadeIn(t) for t in rank_titles]
        self.play(*intro_anims)
        self.wait(1)

        # Show the full taxonomy of two known species, highlighting each leaf.
        examples = [
            (1, ["Sphingomonadaceae", "Sphingomonas", "S. melonis"]),
            (3, ["Pseudomonadaceae", "Pseudomonas", "P. aeruginosa"]),
        ]
        for leaf_idx, names in examples:
            new_text = self._rank_labels(names)
            self.play(*[Write(t) for t in new_text] + [leaves_vertices[leaf_idx].animate.set_fill(WHITE)])
            self.wait(1)
            self.play(*[FadeOut(t) for t in new_text] + [leaves_vertices[leaf_idx].animate.set_fill(BLACK)])
            self.wait(1)

        # A sample whose species is unknown but whose family and genus are known.
        new_text = self._rank_labels(["Sphingomonadaceae", "Sphingomonas", "???"])
        self.play(*[Write(t) for t in new_text])
        self.wait(1)

        self.play(*[FadeOut(mob) for mob in self.mobjects])

        # Only the key order matters here: genus nodes (parents of leaves) and
        # family nodes (parents of genera), in first-seen order.
        genus = tuple(assign_parents([0 for _ in leaves], leaves))
        family = tuple(assign_parents([0 for _ in genus], genus))

        sp_boxes = VGroup(*[graph[l] for l in leaves])
        genus_boxes = VGroup(*[graph[b] for b in genus])
        family_boxes = VGroup(*[graph[b] for b in family])
        squares = [Square().scale(0.5).move_to(graph[l].get_center()) for l in leaves]
        boxes = VGroup(*squares)

        sp_text = Text("Species", font_size=30).next_to(sp_boxes, LEFT)
        ge_text = Text("Genus", font_size=30).next_to(genus_boxes, LEFT)
        fa_text = Text("Family", font_size=30).next_to(family_boxes, LEFT)

        ##### Graph appears #####
        self.play(FadeIn(graph))
        self.play(FadeIn(sp_text), FadeIn(ge_text), FadeIn(fa_text))
        self.wait(1)

        self.play(FadeOut(sp_text), FadeOut(ge_text), FadeOut(fa_text))
        self.play(FadeIn(boxes))
        self.wait(1)

        self.play(boxes.animate.shift(DOWN))

        ##### Logits -> softmax likelihoods on the leaves #####
        # NOTE: drawn from the unseeded global `random`, so values differ per render.
        logits = [random.uniform(0, 1) for _ in squares]

        cmap = cmasher.get_sub_cmap(sns.dark_palette("#9CDCEB", as_cmap=True), 0, 1)
        nn_text = Text("Neural network\noutput layer", font_size=30).next_to(boxes, LEFT)
        logit_text = Text("logits", font_size=30).next_to(boxes, RIGHT)
        lik_text = Text("likelihoods", font_size=30).next_to(graph[leaves[0]], RIGHT)
        arrow = CurvedArrow(start_point=logit_text.get_right(), end_point=lik_text.get_right()).shift(RIGHT*0.2)
        softmax_text = MathTex(r"\frac{e^{l_{i}}}{\sum_{j=1}^K e^{l_{j}}}", font_size=30).move_to(arrow.get_right()).shift(RIGHT*0.6)

        color_anim = [square.animate.set_fill(RED, opacity=logits[i]) for i, square in enumerate(squares)]
        text_objs = [Text(f"{l:.2f}", font_size=30).move_to(b.get_center()) for l, b in zip(logits, boxes)]
        self.play(Write(nn_text))
        self.play(*(color_anim + [Write(t) for t in text_objs]))
        self.wait(1)
        self.play(Write(logit_text), Unwrite(nn_text))
        self.wait(1)

        likelihoods = softmax(logits)
        text_objs_softmax1 = self._value_labels(graph, leaves, likelihoods)
        self.play(*(
            self._fill_anims(graph, leaves, likelihoods, cmap)
            + [Write(t) for t in text_objs_softmax1]
            + [Write(softmax_text), Write(arrow), Write(lik_text)]
        ))
        self.wait(1)

        self.play(Unwrite(logit_text), Unwrite(softmax_text), Unwrite(lik_text), Unwrite(arrow))
        self.wait(1)

        # Leaf likelihoods sum to 1.
        plus1 = self._plus_between(graph, leaves, 4)
        plus2 = self._plus_between(graph, leaves, 3)
        plus3 = self._plus_between(graph, leaves, 2)
        plus4 = self._plus_between(graph, leaves, 1)
        equals_1 = MathTex(r"= 1", font_size=30).next_to(graph[leaves[0]], RIGHT)

        self.play(*[Write(t) for t in [plus1, plus2, plus3, plus4, equals_1]])
        self.wait(1)
        self.play(*[Unwrite(t) for t in [plus2, plus4, equals_1]])
        self.wait(1)

        ##### Sum up to genus level #####
        parents = assign_parents(likelihoods, leaves)
        nods, likelihoods_p = zip(*parents.items())
        text_objs_pars2 = self._value_labels(graph, nods, likelihoods_p)
        self.play(*(
            self._fill_anims(graph, nods, likelihoods_p, cmap)
            + [Write(t) for t in text_objs_pars2]
            + [Unwrite(t) for t in [plus1, plus3]]
        ))

        plusp1 = self._plus_between(graph, nods, 2)
        plusp2 = self._plus_between(graph, nods, 1)
        equals_1 = MathTex(r"= 1", font_size=30).next_to(graph[nods[0]], RIGHT)

        self.play(*[Write(t) for t in [plusp1, plusp2, equals_1]])
        self.play(*[Unwrite(t) for t in [plusp1, equals_1]])

        ##### Sum up to family level #####
        parents = assign_parents(likelihoods_p, nods)
        nods, likelihoods_p = zip(*parents.items())
        text_objs_pars3 = self._value_labels(graph, nods, likelihoods_p)
        self.play(*(
            self._fill_anims(graph, nods, likelihoods_p, cmap)
            + [Write(t) for t in text_objs_pars3]
            + [Unwrite(plusp2)]
        ))

        plusp1 = self._plus_between(graph, nods, 1)
        self.play(Write(plusp1))

        ##### Sum up to the root #####
        parents = assign_parents(likelihoods_p, nods)
        nods, likelihoods_p = zip(*parents.items())
        # Softmax outputs sum to 1, so the root's total is shown as exactly "1".
        text_objs_pars4 = [Text("1", font_size=15).move_to(graph[b].get_center()) for b in nods]
        self.play(*(
            self._fill_anims(graph, nods, likelihoods_p, cmap)
            + [Write(t) for t in text_objs_pars4]
            + [Unwrite(plusp1)]
        ))
        self.wait(1)

        ##### Second graph: ground truth vs likelihoods #####
        graph_copy = self._build_graph()
        for node in ground_truth:
            graph_copy[node].set_fill(WHITE)
        SHIFT_CONST = 3
        all_texts = text_objs + text_objs_softmax1 + text_objs_pars2 + text_objs_pars3 + text_objs_pars4
        self.play(*[FadeOut(x) for x in all_texts])
        self.play(
            graph_copy.animate.move_to(LEFT*SHIFT_CONST),
            graph.animate.move_to(RIGHT*SHIFT_CONST),
            boxes.animate.shift(RIGHT*SHIFT_CONST),
            FadeOut(boxes),
        )
        gt_text = Text("Ground truth", font_size=30).next_to(graph_copy, UP)
        lh_text = Text("Likelihoods", font_size=30).next_to(graph, UP)
        self.play(FadeIn(gt_text), FadeIn(lh_text))
        self.wait(1)

        # Leaf i (in ``leaves`` order) is labelled p_{N-i} / y_{N-i}, matching the equations.
        probs = VGroup(*[
            MathTex(f"p_{{{N_leaves - i}}}", font_size=40).next_to(graph[leaf], DOWN)
            for i, leaf in enumerate(leaves)
        ])
        ys = VGroup(*[
            MathTex(f"y_{{{N_leaves - i}}}", font_size=40).next_to(graph_copy[leaf], DOWN)
            for i, leaf in enumerate(leaves)
        ])
        self.play(FadeIn(probs), FadeIn(ys))
        self.wait(1)

        LOSS = r"\mathcal{L} = -"
        LOG = r"\log("

        ##### Species-level loss #####
        rectangles, reduced = self._play_level_loss(
            highlighted=[graph_copy[GT], graph[GT]],
            eq_with_labels=[
                LOSS, "(", "{{y_1}}", LOG, "{{p_1}}", ") + ",
                "{{y_2}}", LOG, "{{p_2}}", ") + ",
                "{{y_3}}", LOG, "{{p_3}}", ")", " + ",
                "{{y_4}}", LOG, "{{p_4}}", ") + ",
                "{{y_5}}", LOG, "{{p_5}}", ")", ")",
            ],
            eq_with_coefs=[
                LOSS, "(", r"0\cdot", LOG, "{{p_1}}", ") + ",
                r"0\cdot", LOG, "{{p_2}}", ") + ",
                r"1\cdot", LOG, "{{p_3}}", ")", " + ",
                r"0\cdot", LOG, "{{p_4}}", ") + ",
                r"0\cdot", LOG, "{{p_5}}", ")", ")",
            ],
            keep=(0, 11, 12, 13),
            reduced_eq=[LOSS, LOG, "{{p_3}}", ")"],
            pause_after_highlight=1,
        )
        self.play(FadeOut(*rectangles, reduced))
        self.play(graph_copy[GT].animate.set_fill(BLACK))

        ##### Genus-level loss #####
        GT_GENUS = ground_truth[1]
        rectangles, reduced = self._play_level_loss(
            highlighted=[graph_copy[GT_GENUS], graph[GT_GENUS]],
            eq_with_labels=[
                LOSS, "(",
                "{{y_{12}}}", LOG, "{{p_1 + p_2}}", ") + ",
                "{{y_{34}}}", LOG, "{{p_3 + p_4}}", ")", " + ",
                "{{y_5}}", LOG, "{{p_5}}", ")",
                ")",
            ],
            eq_with_coefs=[
                LOSS, "(",
                r"0\cdot", LOG, "{{p_1 + p_2}}", ") + ",
                r"1\cdot", LOG, "{{p_3 + p_4}}", ")", " + ",
                r"0\cdot", LOG, "{{p_5}}", ")",
                ")",
            ],
            keep=(0, 7, 8, 9),
            reduced_eq=[LOSS, LOG, "{{p_3 + p_4}}", ")"],
        )
        self.play(FadeOut(*rectangles, reduced))
        self.play(graph_copy[GT_GENUS].animate.set_fill(BLACK))

        ##### Family-level loss #####
        GT_FAMILY = ground_truth[2]
        self._play_level_loss(
            highlighted=[graph_copy[GT_FAMILY], graph[GT_FAMILY]],
            eq_with_labels=[
                LOSS, "(",
                "{{y_{12}}}", LOG, "{{p_1 + p_2}}", ") + ",
                "{{y_{345}}}", LOG, "{{p_3 + p_4 + p_5}}", ")",
                ")",
            ],
            eq_with_coefs=[
                LOSS, "(",
                r"0\cdot", LOG, "{{p_1 + p_2}}", ") + ",
                r"1\cdot", LOG, "{{p_3 + p_4 + p_5}}", ")",
                ")",
            ],
            keep=(0, 7, 8, 9),
            reduced_eq=[LOSS, LOG, "{{p_3 + p_4 + p_5}}", ")"],
        )

In [None]:
# Render the TaxonomicTree scene defined above via manim's IPython magic.
%manim TaxonomicTree

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

Animation 65: FadeOut(Group):  90%|███████████████████████████████████████████████████████████████████████████████████████████████████▉           | 54/60 [00:01<00:00, 44.90it/s]