In [1]:
import numpy as np
from copy import deepcopy
import itertools
import matplotlib
from numpy import radians as rad
from matplotlib.animation import FuncAnimation
from scipy.ndimage import convolve,convolve1d
import cmasher
import seaborn as sns
from collections import defaultdict
from manim import *
import networkx as nx
from scipy.interpolate import interp1d
import itertools
import networkx as nx
import matplotlib.pyplot as plt
import random
from networkx.drawing.nx_pydot import graphviz_layout
from scipy.special import softmax
import itertools as it

In [2]:
import hloss_misc as _hloss

In [187]:
from collections import deque

# Softmax-margin loss

In [3]:
def create_exact_graph():
    """Build the fixed example taxonomy tree used by the animation.

    The tree is rooted at node 0; edges point from parent to child and the
    leaves are 9, 7, 8, 3 and 4.  The edge order below is significant: it
    determines the node insertion (iteration) order of the returned graph.

    Returns:
        nx.DiGraph: a fresh directed tree.
    """
    parent_child_pairs = (
        (0, 2), (0, 1),
        (2, 6), (2, 5), (1, 666),
        (6, 9), (5, 7), (5, 8), (666, 3), (666, 4),
    )
    tree = nx.DiGraph()
    tree.add_edges_from(parent_child_pairs)
    return tree

In [4]:
# Module-level example tree; read as a global by assign_parents and SoftmaxMargin.
G = create_exact_graph()

In [134]:
def make_graph_hloss():
    """Build the ``hloss_misc`` graph representation of the example tree.

    Each contig string is one root-to-leaf path (semicolon-separated node
    labels) of the same tree that ``create_exact_graph`` builds.

    Returns:
        Whatever ``_hloss.make_graph`` returns; the caller unpacks it into
        three values ``(GRAPH_NODES, IND_NODES, TABLE_PARENT)``.
    """
    root_to_leaf_paths = (
        '0;2;6;9',
        '0;2;5;7',
        '0;2;5;8',
        '0;1;666;3',
        '0;1;666;4',
    )
    # NOTE(review): meaning of the positional ``False`` flag lives in
    # hloss_misc.ContigTaxonomy.from_semicolon_sep — confirm there.
    taxonomies = [
        _hloss.ContigTaxonomy.from_semicolon_sep(path, False)
        for path in root_to_leaf_paths
    ]
    return _hloss.make_graph(taxonomies)

In [135]:
# Build the hloss graph for the example tree. GRAPH_NODES and IND_NODES are not
# used in this part of the notebook; TABLE_PARENT (presumably a per-node parent
# table — confirm in hloss_misc) is turned into an _hloss.Hierarchy.
GRAPH_NODES, IND_NODES, TABLE_PARENT = make_graph_hloss()

In [109]:
# Hierarchy object passed to _hloss.MarginLoss inside SoftmaxMargin.
# NOTE(review): this cell ran as In [109], before the In [135] cell that defines
# TABLE_PARENT — re-run the notebook top-to-bottom to rule out stale state.
hier = _hloss.Hierarchy(np.array(TABLE_PARENT))

In [287]:
def assign_parents(likelihoods, children, graph=None):
    """Aggregate child likelihoods onto their parent nodes.

    For every ``(child, likelihood)`` pair the likelihood is added to the
    running total of the child's first predecessor in ``graph``.

    Args:
        likelihoods: Iterable of numbers, aligned element-wise with
            ``children``.
        children: Iterable of node ids present in ``graph``.
        graph: Directed graph exposing ``predecessors(node)``. Defaults to
            the module-level tree ``G`` (looked up at call time).

    Returns:
        ``defaultdict(int)`` mapping each parent node to the summed
        likelihood of its listed children.

    Raises:
        ValueError: if ``likelihoods`` and ``children`` differ in length
            (plain ``zip`` would silently drop the extra items), or if a
            child has no parent (e.g. the root).
    """
    if graph is None:
        graph = G
    likelihoods = list(likelihoods)
    children = list(children)
    if len(likelihoods) != len(children):
        raise ValueError(
            f"got {len(likelihoods)} likelihoods for {len(children)} children"
        )

    no_parent = object()  # sentinel: distinguishes "no parent" from any node id
    par_dict = defaultdict(int)
    for node, likelihood in zip(children, likelihoods):
        parent = next(iter(graph.predecessors(node)), no_parent)
        if parent is no_parent:
            raise ValueError(f"node {node!r} has no parent")
        par_dict[parent] += likelihood
    return par_dict


class SoftmaxMargin(Scene):
    """Manim scene walking through softmax likelihoods and margin losses on a tree.

    Stages, in order:

    1. Lay the example tree ``G`` out as a row, attach random logits and show
       their softmax likelihoods on the nodes.
    2. Move the nodes back into tree form, shrink the tree into the top-right
       corner and tease a two-part loss ("1.", "2. ?").
    3. Swap to an undirected copy of the tree and animate a BFS that labels
       every node with its hop distance from the ground-truth node ``NEW_GT``.
    4. Colour/label every node with the margin ``_hloss.MarginLoss`` assigns
       it relative to ``NEW_GT``: first the ``'incorrect'`` margin, then
       ``'info_dist'``, each shown with its defining equations.
    """

    # Seed for the random logits. ``None`` draws from the global ``random``
    # module (non-deterministic unless seeded elsewhere); an int makes the
    # logits, and therefore the render, reproducible.
    LOGIT_SEED = None

    def construct(self):
        """Build and play the whole animation (manim entry point)."""
        LAYOUT_CONFIG = {"vertex_spacing": (1, 1)}
        VERTEX_CONF = {"radius": 0.25, "color": BLACK, "fill_opacity": 1, "stroke_opacity": 1, "stroke_color": WHITE, "stroke_width": 2}

        graph = Graph(
            list(G.nodes),
            list(G.edges),
            layout="tree",
            root_vertex=0,
            layout_config=LAYOUT_CONFIG,
            vertex_config=VERTEX_CONF,
        )
        # Position of each node in G's iteration order; play_margins adds 1 to
        # it when indexing hloss margin arrays.
        # NOTE(review): assumes hloss orders nodes like G with one extra leading
        # entry — confirm against hloss_misc.make_graph.
        INDEX_MAP = {n: i for i, n in enumerate(G.nodes())}

        self.play(FadeIn(graph))
        self.wait(1)

        # --- Stage 1: nodes in a row, logits -> softmax likelihoods ----------
        all_nodes = VGroup(*[graph[l].copy() for l in G.nodes()])

        all_nodes.arrange(buff=0.5)

        self.play(*[Transform(graph[l], all_nodes[i]) for i, l in enumerate(G.nodes())] + [FadeOut(graph.edges[edge]) for edge in graph.edges])

        self.wait(1)

        squares = [Square().scale(0.5).move_to(n.get_center()) for n in all_nodes]
        boxes = VGroup(*squares)

        self.play(FadeIn(boxes))
        self.wait(1)

        self.play(boxes.animate.shift(DOWN))

        cmap = cmasher.get_sub_cmap(sns.dark_palette("#9CDCEB", as_cmap=True), 0, 1)

        rng = random if self.LOGIT_SEED is None else random.Random(self.LOGIT_SEED)
        logits = [rng.uniform(0, 1) for _ in squares]
        logit_text = Text("logits", font_size=20).next_to(boxes, RIGHT)
        lik_text = Text("likelihoods", font_size=20).next_to(all_nodes, RIGHT)
        arrow = CurvedArrow(start_point=logit_text.get_right(), end_point=lik_text.get_right()).shift(RIGHT*0.05)
        softmax_text = MathTex(r"\frac{e^{l_{i}}}{\sum_{j=1}^K e^{l_{j}}}", font_size=20).move_to(arrow.get_right()).shift(RIGHT*0.6)

        # Box fill opacity encodes the raw logit value (logits lie in [0, 1]).
        color_anim = [square.animate.set_fill(RED, opacity=logits[i]) for i, square in enumerate(squares)]
        text_objs = [Text(f"{l:.2f}", font_size=30).move_to(b.get_center()) for l, b in zip(logits, boxes)]  # remove later
        text_anim = [Write(t) for t in text_objs]
        self.play(*(color_anim + text_anim))
        self.play(Write(logit_text))
        self.wait(1)

        likelihoods = softmax(logits)
        text_objs_softmax1 = [Text(f"{l:.2f}", font_size=15).move_to(b.get_center()) for l, b in zip(likelihoods, all_nodes)]
        text_anim_softmax = [Write(t) for t in text_objs_softmax1]

        color_anim_softmax = [graph[l].animate.set_fill(rgba_to_color(cmap(likelihoods[i]))) for i, l in enumerate(G.nodes())]

        self.play(*(color_anim_softmax + text_anim_softmax + [Write(softmax_text), Write(arrow), Write(lik_text)]))
        self.wait(1)

        # Clear the logit row and its annotations. softmax_text is removed
        # together with the arrow it hangs off; otherwise the formula would
        # linger on screen for the rest of the scene.
        all_texts = text_objs + [logit_text, lik_text, arrow, softmax_text]
        self.play(*[Unwrite(x) for x in all_texts] + [FadeOut(boxes)])

        self.wait(1)

        # --- Stage 2: back to the tree, shrink it into the corner ------------
        # Private manim attribute; presumably still the initial "tree" layout
        # positions, since the graph's layout is never changed.
        original_layout = graph._layout
        labels = text_objs_softmax1

        # Animate nodes and their likelihood labels back to their tree positions.
        for i, l in enumerate(G.nodes()):
            graph[l].generate_target()
            graph[l].target.move_to(original_layout[l])
            labels[i].generate_target()
            labels[i].target.move_to(graph[l].target.get_center())

        self.play(
            *[MoveToTarget(graph[l]) for l in G.nodes()],
            *[MoveToTarget(labels[i]) for i, l in enumerate(G.nodes())],
            run_time=2,
            rate_func=smooth
        )
        self.play(*[FadeIn(graph.edges[edge]) for edge in graph.edges])

        self.wait(1)

        anims = []

        SCALE_COEF = 0.5

        def get_new_position(obj):
            # Scale positions about the origin, then offset to the top-right.
            return obj.get_center() * SCALE_COEF + UP * 2 + RIGHT * 5

        # Move and scale each node together with its label.
        for i, l in enumerate(G.nodes()):
            new_position = get_new_position(graph[l])
            anims.append(graph[l].animate.scale(SCALE_COEF).move_to(new_position))
            anims.append(labels[i].animate.scale(SCALE_COEF).move_to(new_position))

        # Re-anchor every edge on the new node centres (computed from the
        # nodes' current, pre-animation positions).
        for edge in graph.edges:
            start_pos = get_new_position(graph[edge[0]])
            end_pos = get_new_position(graph[edge[1]])
            anims.append(graph.edges[edge].animate.put_start_and_end_on(start_pos, end_pos))

        # Play all animations (nodes, labels, and edges) together.
        self.play(*anims)
        self.wait(1)

        text_loss = Text("Loss", font_size=40).next_to(graph, UP)
        text_1 = Text("1.", font_size=40).next_to(graph, LEFT)
        text_2 = Text("2.", font_size=40).next_to(text_1, DOWN*6)
        text_q = Text("?", font_size=40).next_to(text_2, RIGHT)

        self.play(*[Write(t) for t in (text_loss, text_1)])
        self.wait(1)

        self.play(Write(text_2))
        self.play(Write(text_q))
        self.wait(1)

        # --- Stage 3: BFS hop distances from the ground-truth node -----------
        G_copy = G.to_undirected()

        graph_margin = Graph(
            list(G_copy.nodes),
            list(G_copy.edges),
            layout="tree",
            root_vertex=0,
            layout_config=LAYOUT_CONFIG,
            vertex_config=VERTEX_CONF,
        )

        self.play(*[Unwrite(t) for t in [text_loss, text_1, text_2, text_q] + labels] + [FadeOut(graph), FadeIn(graph_margin)])
        self.wait(1)

        NEW_GT = 5
        new_gt = graph_margin[NEW_GT]
        surbox = SurroundingRectangle(new_gt, color=YELLOW)

        self.play(FadeIn(surbox))
        self.wait(1)

        distances = {node: None for node in G_copy.nodes}  # filled in by the BFS below
        distances[NEW_GT] = 0

        # One (initially empty) text mobject per node; only NEW_GT shows "0".
        dist_labels = {node: Text("", font_size=30).move_to(graph_margin[node].get_center()) for node in G_copy.nodes()}
        dist_labels[NEW_GT].become(Text("0", font_size=30).move_to(graph_margin[NEW_GT].get_center()))

        self.play(*[FadeIn(label) for label in dist_labels.values()])
        self.wait(1)

        queue = deque([NEW_GT])
        visited = {NEW_GT}

        margin_cmap = cmasher.get_sub_cmap(sns.dark_palette("#FC6255", as_cmap=True), 0, 1)
        # Normalise colours by the farthest node from NEW_GT rather than a
        # hard-coded depth, so the full colormap range is used for any tree.
        max_distance = max(max(nx.single_source_shortest_path_length(G_copy, NEW_GT).values()), 1)

        texts = [dist_labels[NEW_GT]]

        while queue:
            current = queue.popleft()
            current_distance = distances[current]

            # Undirected copy: neighbours include both parents and children.
            for neighbor in G_copy.neighbors(current):
                if neighbor in visited:
                    continue
                visited.add(neighbor)
                queue.append(neighbor)
                distances[neighbor] = current_distance + 1

                t = Text(str(distances[neighbor]), font_size=30).move_to(graph_margin[neighbor].get_center())
                texts.append(t)
                color_anim_margin = graph_margin[neighbor].animate.set_fill(rgba_to_color(margin_cmap(distances[neighbor] / max_distance)))

                self.play(color_anim_margin, Write(t), run_time=0.5)

        self.wait(1)

        reverse_colors = [graph_margin[l].animate.set_fill(BLACK) for l in G_copy.nodes()]

        self.play(*(reverse_colors + [Unwrite(t) for t in texts]))

        self.wait(1)

        # --- Stage 4: margins from _hloss.MarginLoss --------------------------
        def info_dist_eqs():
            """Equations explaining the 'info_dist' margin."""
            SHIFT_LEFT = LEFT*6.5

            eq_pt = [
                r"p - predicted, t - truth",
            ]
            pt = MathTex(*eq_pt, font_size=30).align_to(SHIFT_LEFT, LEFT).shift(UP*3.5)

            eq_lca = [
                r"lca(p, t) - last\ common\ ancestor",
            ]
            lca = MathTex(*eq_lca, font_size=30).align_to(SHIFT_LEFT, LEFT).shift(UP*3)

            eq_c = [
                r"C(p) - number\ of\ descendant\ leaves",
            ]
            count = MathTex(*eq_c, font_size=30).align_to(SHIFT_LEFT, LEFT).shift(UP*2.5)

            eq = [
                r"margin(p, t) = log(\frac{C(lca(p, t))}{C(p)C(t)})",
            ]
            nll_coef = MathTex(*eq, font_size=35).align_to(SHIFT_LEFT, LEFT).shift(UP*1.5)
            return [pt, lca, count, nll_coef]

        def incorrect_eqs():
            """Equation explaining the 0/1 'incorrect' margin."""
            SHIFT_LEFT = LEFT*6.5

            eq_pt = [
                r"margin(p, t) = \begin{cases} 0\ if\ predicted\ is\ an\ ancestor\ of\ truth \\ 1\ otherwise \end{cases}",
            ]
            nll_coef = MathTex(*eq_pt, font_size=35).align_to(SHIFT_LEFT, LEFT).shift(UP*3)

            return [nll_coef]

        def play_margins(margin_type, text_fun, text_size, eqs_func):
            """Colour and label every node with its margin w.r.t. NEW_GT.

            Args:
                margin_type: ``margin`` argument forwarded to ``_hloss.MarginLoss``.
                text_fun: formats one margin value as label text.
                text_size: font size of the labels.
                eqs_func: returns the equation mobjects shown alongside.
            """
            loss_fn = _hloss.MarginLoss(
                    hier,
                    with_leaf_targets=False,
                    hardness="soft",
                    margin=margin_type,
                    tau=0.01,
            )

            # Row of the margin table for the ground-truth node (see INDEX_MAP note).
            margins = np.array(loss_fn.margin[INDEX_MAP[NEW_GT] + 1])
            peak = np.max(margins)
            # Guard 0/0 -> NaN colours when every margin is zero.
            margins_norm = margins / peak if peak != 0 else np.zeros_like(margins, dtype=float)

            color_anim_margin = [graph_margin[l].animate.set_fill(rgba_to_color(margin_cmap(margins_norm[INDEX_MAP[l] + 1]))) for l in G.nodes()]
            text_objs_margin = [Text(text_fun(margins[INDEX_MAP[l] + 1]), font_size=text_size).move_to(graph_margin[l].get_center())
                                for l in G.nodes()]
            text_anim_margin = [Write(t) for t in text_objs_margin]
            eqs = eqs_func()
            eqs_anims = [FadeIn(e) for e in eqs]
            self.play(*(color_anim_margin + text_anim_margin + eqs_anims))
            self.wait(2)
            reverse_colors = [graph_margin[l].animate.set_fill(BLACK) for l in G_copy.nodes()]
            self.play(*(reverse_colors + [Unwrite(t) for t in text_objs_margin] + [FadeOut(e) for e in eqs]))
            self.wait(1)

        # Played in this order: the simple 0/1 margin first, then information distance.
        margin_configs = [
            ("incorrect", lambda t: f"{int(t)}", 30, incorrect_eqs),
            ("info_dist", lambda t: f"{t:.1f}", 20, info_dist_eqs),
        ]
        for margin_type, fmt, font_size, eqs_func in margin_configs:
            play_margins(margin_type, fmt, font_size, eqs_func)

        # TODO (author's note, not yet animated): the loss looks for the p_i that
        # make the total sum smallest, so nodes with large margins are
        # penalised more.

In [288]:
# Render the SoftmaxMargin scene via manim's Jupyter line magic.
%manim SoftmaxMargin

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  