In [12]:
import numpy as np
from copy import deepcopy
import itertools
import matplotlib
from numpy import radians as rad
from matplotlib.animation import FuncAnimation
from scipy.ndimage import convolve,convolve1d
import cmasher
import seaborn as sns
from collections import defaultdict
from manim import *
import networkx as nx
from scipy.interpolate import interp1d
import itertools
import networkx as nx
import matplotlib.pyplot as plt
import random
from networkx.drawing.nx_pydot import graphviz_layout
from scipy.special import softmax

# Hierarchical loss

In [13]:
def create_random_tree_with_fixed_depth(number_of_children, branch_length=7, root_children=2, seed=None):
    """Build a random rooted tree whose leaves all sit at depth ``branch_length``.

    Nodes are consecutive integers assigned level by level (root is 0), and
    every node carries a ``value`` attribute drawn uniformly from [0, 1] and
    rounded to 2 decimals. Because every non-root node gets at least one
    child, all leaves end up at the same depth.

    Parameters:
    number_of_children (int): upper bound on the children of each non-root
        node; each gets ``randint(1, number_of_children)`` children.
    branch_length (int): number of levels below the root (root-to-leaf edges).
    root_children (int): exact number of children of the root (default 2).
    seed (int | None): if given, draw from a private ``random.Random(seed)`` so
        the tree is reproducible without touching the global RNG; if None, use
        the global ``random`` module state.

    Returns:
    networkx.DiGraph: edges point from parent to child.
    """
    # `random` (module) and `random.Random` expose the same uniform/randint API.
    rng = random if seed is None else random.Random(seed)

    def _random_value():
        # Node payload shown by plot_tree.
        return round(rng.uniform(0, 1), 2)

    tree = nx.DiGraph()

    root = 0
    tree.add_node(root, value=_random_value())

    next_node = 1
    parent_nodes = [root]

    for level in range(branch_length):
        new_parent_nodes = []

        for parent in parent_nodes:
            # The root's fan-out is fixed; deeper nodes get a random fan-out >= 1.
            if level == 0:
                num_children = root_children
            else:
                num_children = rng.randint(1, number_of_children)

            for _ in range(num_children):
                tree.add_node(next_node, value=_random_value())
                tree.add_edge(parent, next_node)
                new_parent_nodes.append(next_node)
                next_node += 1

        parent_nodes = new_parent_nodes

    return tree

def plot_tree(G, branch_length=None):
    """
    Plots the tree graph using networkx and matplotlib, labelling each node
    with its ``value`` attribute.

    Parameters:
    G (networkx.DiGraph): The tree graph to plot; every node must carry a
        numeric ``value`` attribute (e.g. from create_random_tree_with_fixed_depth).
    branch_length (int | None): Depth shown in the title. If None, it is
        computed as the longest root-to-leaf path of ``G`` in edges, which
        equals the ``branch_length`` used by create_random_tree_with_fixed_depth.
    """
    # Derive the depth from the graph itself instead of reading a global that
    # only exists as a parameter of the tree builder.
    if branch_length is None:
        branch_length = nx.dag_longest_path_length(G)

    # Node values formatted to 2 decimals, used as labels.
    labels = {node: f'{G.nodes[node]["value"]:.2f}' for node in G.nodes()}

    # Hierarchical layout from Graphviz's "dot" program (via pydot).
    pos = graphviz_layout(G, prog="dot")

    _, ax = plt.subplots(figsize=(10, 8))
    nx.draw(G, pos, ax=ax, with_labels=False, node_size=10, node_color='lightgreen', edge_color='gray')
    nx.draw_networkx_labels(G, pos, labels=labels, ax=ax)

    ax.set_title(f'Tree with Fixed Branch Length of {branch_length} and {G.number_of_nodes()} Nodes')
    plt.show()

In [14]:
# NOTE(review): the tree is random and unseeded, yet TaxonomicTree below picks
# leaves[2] and the nodes two levels above it (needs >= 3 leaves, depth >= 2).
# Re-run this cell, or seed `random`, if the rendered scene looks wrong.
G = create_random_tree_with_fixed_depth(
    number_of_children=2,
    branch_length=3,
)

In [15]:
def assign_parents(likelihoods, children, graph=None):
    """Sum child likelihoods into their parents, moving one level up the tree.

    Parameters:
    likelihoods (sequence of float): likelihood of each node in ``children``,
        paired positionally (extra items on either side are ignored, as with zip).
    children (iterable of nodes): nodes whose likelihoods are propagated.
    graph (networkx.DiGraph | None): tree with parent -> child edges. Defaults
        to the notebook-level ``G`` when None.

    Returns:
    collections.defaultdict: parent node -> summed likelihood of its listed
        children, ordered by first appearance of each parent.

    Raises:
    ValueError: if a node in ``children`` has no parent (e.g. the root).
    """
    if graph is None:
        graph = G  # notebook-level tree, kept as the default for existing callers

    par_dict = defaultdict(int)
    for child, likelihood in zip(children, likelihoods):
        # In a tree every non-root node has exactly one predecessor.
        parent = next(iter(graph.predecessors(child)), None)
        if parent is None:
            raise ValueError(f"node {child!r} has no parent; cannot propagate above the root")
        par_dict[parent] += likelihood
    return par_dict


class TaxonomicTree(Scene):
    """Animate a hierarchical negative log-likelihood loss on the notebook tree ``G``.

    The scene draws ``G``, highlights a ground-truth path, turns random leaf
    logits into softmax likelihoods, sums them up the tree with
    ``assign_parents`` and finally shows the loss for the ground-truth leaf and
    for its two nearest ancestors, whose probability is the sum over the
    leaves below them.

    Assumes (not checked here) that ``G`` is a DiGraph rooted at node 0 whose
    leaves all sit at the same depth >= 2, with at least 3 leaves.
    """

    def construct(self):
        LAYOUT_CONFIG = {"vertex_spacing": (1, 1)}
        VERTEX_CONF = {"radius": 0.25, "color": BLACK, "fill_opacity": 1, "stroke_opacity": 1, "stroke_color": WHITE, "stroke_width": 2}

        PLAY_CREATION = 2
        ROOT = 0
        SHIFT_CONST = 3
        EQ_FONT_SIZE = 40
        EQ_POSITION = DOWN * 2.8
        # Loss symbol. Raw string because "\L" is an invalid Python escape, and
        # \mathcal{L} because plain \L is LaTeX's text-mode letter (invalid in math mode).
        LOSS_TEX = r"\mathcal{L} = -"

        def make_graph():
            # A fresh manim Graph mirroring G with a top-down tree layout.
            return Graph(
                list(G.nodes),
                list(G.edges),
                layout="tree",
                root_vertex=ROOT,
                layout_config=LAYOUT_CONFIG,
                vertex_config=VERTEX_CONF,
            )

        def loss_tex(prob_tex):
            # "L = -log(<prob_tex>)" placed under both graphs.
            return MathTex(LOSS_TEX, "log(", prob_tex, ")", font_size=EQ_FONT_SIZE).move_to(EQ_POSITION)

        graph = make_graph()
        # Leaves are the childless nodes, kept in G's node order.
        leaves = [node for node in G.nodes() if G.out_degree(node) == 0]
        N_leaves = len(leaves)
        self.play(Create(graph, run_time=PLAY_CREATION))
        self.wait(1)

        # Ground-truth path, leaf first: [leaf, parent, grandparent, ..., root].
        GT = leaves[2]
        ground_truth = nx.shortest_path(G, ROOT, GT)[::-1]

        self.play(*[graph[node].animate.set_fill(WHITE) for node in ground_truth])
        self.wait(1)

        self.play(*[graph[node].animate.set_fill(BLACK) for node in ground_truth])
        self.wait(1)

        squares = [Square().scale(0.5).move_to(graph[leaf].get_center()) for leaf in leaves]
        boxes = VGroup(*squares)

        self.play(FadeIn(boxes))
        self.wait(1)

        logit_text = Text("logits", font_size=30).next_to(boxes, RIGHT).shift(DOWN)
        lik_text = Text("likelihoods", font_size=30).next_to(boxes, RIGHT)

        arrow = CurvedArrow(start_point=logit_text.get_right(), end_point=lik_text.get_right()).shift(RIGHT*0.2)

        softmax_text = MathTex(r"\frac{e^{l_{i}}}{\sum_{j=1}^K e^{l_{j}}}", font_size=30).move_to(arrow.get_right()).shift(RIGHT*0.6)

        self.play(boxes.animate.shift(DOWN))
        self.play(Write(logit_text), Write(lik_text))
        self.wait(1)
        self.play(Write(softmax_text), Write(arrow))
        self.wait(1)

        self.play(Unwrite(logit_text), Unwrite(softmax_text), Unwrite(lik_text), Unwrite(arrow))
        self.wait(1)

        # Random logits, shown as the red fill opacity of each leaf box.
        logits = [random.uniform(0, 1) for _ in squares]

        self.play(*[square.animate.set_fill(RED, opacity=logit) for square, logit in zip(squares, logits)])
        self.wait(1)

        cmap = cmasher.get_sub_cmap(sns.dark_palette("#9CDCEB", as_cmap=True), 0, 1)

        def color_nodes(nodes, probs):
            # Fill each node with the colormap color of its (summed) likelihood.
            self.play(*[graph[node].animate.set_fill(rgba_to_color(cmap(p))) for node, p in zip(nodes, probs)])
            self.wait(1)

        likelihoods = softmax(logits)
        color_nodes(leaves, likelihoods)

        # Sum child likelihoods into parents one level at a time until the root
        # is reached (all leaves share one depth, so each step is one level).
        nodes, node_likelihoods = leaves, likelihoods
        while ROOT not in nodes:
            parents = assign_parents(node_likelihoods, nodes)
            nodes, node_likelihoods = zip(*parents.items())
            color_nodes(nodes, node_likelihoods)

        graph_copy = make_graph()
        for node in ground_truth:
            graph_copy[node].set_fill(WHITE)
        self.play(
            graph_copy.animate.move_to(LEFT*SHIFT_CONST),
            graph.animate.move_to(RIGHT*SHIFT_CONST),
            boxes.animate.shift(RIGHT*SHIFT_CONST),
        )
        gt_text = Text("Ground truth", font_size=30).next_to(graph_copy, UP)
        lh_text = Text("Likelihoods", font_size=30).next_to(graph, UP)
        self.play(FadeIn(gt_text), FadeIn(lh_text))
        self.wait(1)

        self.play(FadeOut(boxes))
        self.wait(1)

        # Leaf k (0-based, in G's node order) is labelled N_leaves - k.
        leaf_label = {leaf: N_leaves - k for k, leaf in enumerate(leaves)}

        probs = VGroup(*[
            MathTex("{{p_" + str(leaf_label[leaf]) + "}}", font_size=EQ_FONT_SIZE).next_to(graph[leaf], DOWN)
            for leaf in leaves
        ])
        ys = VGroup(*[
            MathTex("{{y_" + str(leaf_label[leaf]) + "}}", font_size=EQ_FONT_SIZE).next_to(graph_copy[leaf], DOWN)
            for leaf in leaves
        ])
        self.play(FadeIn(probs), FadeIn(ys))
        self.wait(1)

        def surround(node):
            # Yellow boxes around `node` in both the ground-truth and likelihood graphs.
            return (
                SurroundingRectangle(graph_copy[node], color=YELLOW),
                SurroundingRectangle(graph[node], color=YELLOW),
            )

        def summed_prob_tex(node):
            # "{{p_a + p_b + ...}}" over the labels of the leaves at or below `node`.
            below = nx.descendants(G, node) | {node}
            labels = sorted(leaf_label[leaf] for leaf in leaves if leaf in below)
            return "{{" + " + ".join(f"p_{k}" for k in labels) + "}}"

        surbox, surbox2 = surround(GT)
        self.play(FadeIn(surbox, surbox2))
        self.wait(1)

        # Full NLL over all leaves: -(y_1 log(p_1) + ... + y_N log(p_N)).
        # `kept` records the parts that survive for the one-hot ground truth:
        # the loss symbol and the ground-truth term's "log(", "p_k", ")".
        gt_label = leaf_label[GT]
        eq = [LOSS_TEX, "("]
        kept = {0}
        for k in range(1, N_leaves + 1):
            if k == gt_label:
                kept.update(range(len(eq) + 1, len(eq) + 4))
            eq += ["{{y_" + str(k) + "}}", "log(", "{{p_" + str(k) + "}}", ")"]
            if k < N_leaves:
                eq.append(" + ")
        eq.append(")")

        nll = MathTex(*eq, font_size=EQ_FONT_SIZE).move_to(EQ_POSITION)
        self.play(FadeIn(nll))
        self.wait(1)

        self.play(*[ShrinkToCenter(nll[i]) for i in range(len(eq)) if i not in kept])

        nll2 = loss_tex(summed_prob_tex(GT))
        self.play(TransformMatchingShapes(nll, nll2))
        self.wait(1)

        self.play(FadeOut(surbox, surbox2, nll2))
        self.play(graph_copy[GT].animate.set_fill(BLACK))

        # One level up: the probability is the sum over the leaves below the parent.
        GT_GENUS = ground_truth[1]
        surbox, surbox2 = surround(GT_GENUS)
        self.play(FadeIn(surbox, surbox2))

        nll3 = loss_tex(summed_prob_tex(GT_GENUS))
        self.play(FadeIn(nll3))
        self.wait(1)

        self.play(FadeOut(surbox, surbox2, nll3))
        self.play(graph_copy[GT_GENUS].animate.set_fill(BLACK))

        # Two levels up.
        GT_FAMILY = ground_truth[2]
        surbox, surbox2 = surround(GT_FAMILY)
        self.play(FadeIn(surbox, surbox2))

        nll4 = loss_tex(summed_prob_tex(GT_FAMILY))
        self.play(FadeIn(nll4))
        self.wait(1)

In [16]:
# Render the TaxonomicTree scene with manim's IPython magic (uses G and assign_parents defined above).
%manim TaxonomicTree

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  