In [12]:
import numpy as np
from copy import deepcopy
import itertools
import matplotlib
from numpy import radians as rad
from matplotlib.animation import FuncAnimation
from scipy.ndimage import convolve,convolve1d
import cmasher
import seaborn as sns
from collections import defaultdict
from manim import *
import networkx as nx
from scipy.interpolate import interp1d
import itertools
import networkx as nx
import matplotlib.pyplot as plt
import random
from networkx.drawing.nx_pydot import graphviz_layout
from scipy.special import softmax

In [588]:
import itertools as it

# Hierarchical loss

In [633]:
def create_exact_graph():
    """Build the small hand-written tree used by the animations in this notebook.

    Edges point from parent to child and node 0 is the root. The edges are
    inserted level by level in a fixed order, which also determines the node
    order of the returned graph.

    Returns:
    networkx.DiGraph: the eleven-node tree.
    """
    edges_by_level = (
        [(0, 2), (0, 1)],
        [(2, 6), (2, 5), (1, 666)],
        [(6, 9), (5, 7), (5, 8), (666, 3), (666, 4)],
    )
    tree = nx.DiGraph()
    for level_edges in edges_by_level:
        tree.add_edges_from(level_edges)
    return tree

In [634]:
# Notebook-level tree; assign_parents and the TaxonomicTree scene read this global.
G = create_exact_graph()

In [130]:
def create_random_tree_with_fixed_depth(number_of_children, branch_length=7):
    """
    Grow a random directed tree in which every branch has the same depth.

    Nodes are numbered consecutively from the root (0) and each carries a
    ``value`` attribute drawn uniformly from [0, 1] and rounded to two decimals.

    Parameters:
    number_of_children (int): Upper bound on the children per node; below the
        root each node gets between 1 and this many children.
    branch_length (int): Number of levels added below the root.

    Returns:
    networkx.DiGraph: The generated tree, edges pointing from parent to child.
    """
    def random_value():
        return round(random.uniform(0, 1), 2)

    tree = nx.DiGraph()
    tree.add_node(0, value=random_value())

    next_id = 1
    frontier = [0]
    for depth in range(branch_length):
        next_frontier = []
        for parent in frontier:
            # The root always splits in two; deeper nodes get a random fan-out.
            fanout = 2 if depth == 0 else random.randint(1, number_of_children)
            for child in range(next_id, next_id + fanout):
                tree.add_node(child, value=random_value())
                tree.add_edge(parent, child)
                next_frontier.append(child)
            next_id += fanout
        frontier = next_frontier

    return tree

def plot_tree(G):
    """
    Draw a tree with networkx/matplotlib using graphviz's hierarchical "dot" layout.

    Parameters:
    G (networkx.Graph): The tree graph to plot.
    """
    # Node positions come from graphviz so that the levels line up.
    layout = graphviz_layout(G, prog="dot")

    draw_options = dict(
        with_labels=False,
        node_size=10,
        node_color='lightgreen',
        edge_color='gray',
    )

    plt.figure(figsize=(10, 8))
    nx.draw(G, layout, **draw_options)

    node_count = G.number_of_nodes()
    plt.title(f'Tree with {node_count} Nodes')
    plt.show()

In [655]:
def assign_parents(likelihoods, children, graph=None):
    """
    Sum the likelihoods of sibling nodes onto their shared parent.

    Parameters:
    likelihoods (iterable of numbers): One value per entry of ``children``.
    children (iterable of nodes): Nodes whose values are pushed one level up.
    graph (networkx.DiGraph, optional): Tree in which parents are looked up.
        Defaults to the notebook-level graph ``G``, so existing two-argument
        calls behave exactly as before.

    Returns:
    collections.defaultdict: parent node -> summed likelihood of its children,
    keyed in order of first appearance of each parent.

    Raises:
    IndexError: If one of ``children`` has no predecessor (e.g. the root).
    """
    if graph is None:
        graph = G

    par_dict = defaultdict(int)
    for child, likelihood in zip(children, likelihoods):
        parents = list(graph.predecessors(child))
        if not parents:
            raise IndexError(f"node {child!r} has no parent in the graph")
        # Only the first predecessor is used; a tree node has at most one.
        par_dict[parents[0]] += likelihood
    return par_dict


class TaxonomicTree(Scene):
    """Scene explaining a hierarchical (taxonomy-aware) negative log-likelihood.

    The animation goes from the five species outputs of a network, to the tree
    stored in the notebook-level graph ``G``, to likelihoods that are summed up
    the tree with ``assign_parents``, and finally to the loss term obtained at
    the leaf level and at the two levels above it.

    Depends on names defined outside this class: ``G``, ``assign_parents`` and
    ``NeuralNetworkMobject`` (the latter must already exist in the kernel).
    The scene is written for a tree with exactly five leaves.
    """

    @staticmethod
    def _build_tree():
        """Return a fresh manim ``Graph`` drawing of ``G`` laid out as a tree rooted at node 0."""
        layout_config = {"vertex_spacing": (1, 1)}
        vertex_config = {"radius": 0.25, "color": BLACK, "fill_opacity": 1, "stroke_opacity": 1, "stroke_color": WHITE, "stroke_width": 2}
        return Graph(
            list(G.nodes),
            list(G.edges),
            layout="tree",
            root_vertex=0,
            layout_config=layout_config,
            vertex_config=vertex_config,
        )

    @staticmethod
    def _annotation_column(lines):
        """Return one ``Text`` per entry of ``lines``, stacked top-down on the right of the frame."""
        return [
            Text(line, font_size=30).move_to(RIGHT*2 + UP*(3-row))
            for row, line in enumerate(lines)
        ]

    @staticmethod
    def _roll_up(graph, cmap, values, nodes, label=None):
        """Push ``values`` one level up the tree and prepare the matching animations.

        Parameters:
        graph: manim ``Graph`` whose vertices are recoloured and labelled.
        cmap: colormap called with each summed value to obtain an RGBA colour.
        values, nodes: per-node values of the current level.
        label (str, optional): fixed text shown on every parent instead of the
            value formatted with two decimals.

        Returns:
        tuple: (parent_nodes, parent_values, texts, fill_animations).
        """
        summed = assign_parents(values, nodes)
        parent_nodes, parent_values = zip(*summed.items())
        texts = [
            Text(f"{value:.2f}" if label is None else label, font_size=15).move_to(graph[node].get_center())
            for value, node in zip(parent_values, parent_nodes)
        ]
        fills = [
            graph[node].animate.set_fill(rgba_to_color(cmap(value)))
            for node, value in zip(parent_nodes, parent_values)
        ]
        return parent_nodes, parent_values, texts, fills

    def _play_loss_stage(self, graph, graph_copy, node, symbolic_eq, numeric_eq,
                         kept_parts, reduced_eq, pause_after_highlight=False,
                         clear_afterwards=True):
        """Highlight ``node`` in both trees and reduce the loss equation to its surviving term.

        Parameters:
        graph, graph_copy: likelihood tree and ground-truth tree.
        node: vertex framed in both trees.
        symbolic_eq: tex parts of the loss written with y/p symbols.
        numeric_eq: the same parts with the y symbols replaced by 0/1 factors.
        kept_parts: indices into ``numeric_eq`` that are NOT shrunk away.
        reduced_eq: tex parts of the final, single-term equation.
        pause_after_highlight (bool): wait one second after the frames appear.
        clear_afterwards (bool): fade out frames and equation and paint the
            node black in the ground-truth tree once the stage is done.
        """
        truth_box = SurroundingRectangle(graph_copy[node], color=YELLOW)
        pred_box = SurroundingRectangle(graph[node], color=YELLOW)
        self.play(FadeIn(truth_box, pred_box))
        if pause_after_highlight:
            self.wait(1)

        nll = MathTex(*symbolic_eq, font_size=40).move_to(DOWN*2.8)
        self.play(FadeIn(nll))
        self.wait(2)

        nll_coef = MathTex(*numeric_eq, font_size=40).move_to(DOWN*2.8)
        self.play(TransformMatchingTex(nll, nll_coef))
        self.wait(2)

        # Everything except the term belonging to the highlighted node collapses.
        self.play(*[ShrinkToCenter(nll_coef[i]) for i in range(len(numeric_eq)) if i not in kept_parts])

        nll_reduced = MathTex(*reduced_eq, font_size=40).move_to(DOWN*2.8)
        self.play(TransformMatchingShapes(nll_coef, nll_reduced))
        self.wait(1)

        if clear_afterwards:
            self.play(FadeOut(truth_box, pred_box, nll_reduced))
            self.play(graph_copy[node].animate.set_fill(BLACK))

    def construct(self):
        species = ['S. elodea', 'S. melonis', 'E. coli', 'P. aeruginosa', 'M. albus']

        ##### Neural network with annotated inputs #####
        network = NeuralNetworkMobject([10, 10, 10, 5])
        network.label_outputs_text(list(species))

        brace1 = Brace(network.layers[0][0][:4], LEFT)
        text1 = Text('abundances', font_size=40).next_to(brace1, LEFT)
        brace2 = Brace(network.layers[0][0][4:], LEFT)
        text2 = Text('tetranucleotide\nfrequencies', font_size=40).next_to(brace2, LEFT)

        self.play(Write(network), Write(brace1), Write(brace2), Write(text1), Write(text2))

        # Copies of the output layer and its labels; these are morphed into the
        # leaves below while the network itself fades out.
        output = network.layers[-1][0].copy()
        output_labels = network.output_labels.copy()

        self.wait(1)

        graph = self._build_tree()

        # Leaves are the nodes without children, listed in the node order of G.
        leaves = [node for node in G.nodes() if G.out_degree(node) == 0]

        # Ground-truth path: the third leaf followed by its ancestors up to the root.
        GT = leaves[2]
        ground_truth = [GT]
        while True:
            parents = list(G.predecessors(ground_truth[-1]))
            if not parents:
                break
            ground_truth.append(parents[0])

        ##### Output neurons become the leaves #####
        leaves_vertices = VGroup(*[graph[l].copy() for l in leaves])
        leaves_vertices.arrange(buff=1.0).shift(DOWN)
        species_texts = [
            MathTex(t).set_height(output_labels[0].get_height()).next_to(leaves_vertices[i], DOWN)
            for i, t in enumerate(species)
        ]
        fadeouts = [FadeOut(m) for m in (network, brace1, brace2, text1, text2)]

        self.play(
            Transform(output, leaves_vertices),
            *fadeouts,
            *[Transform(a, b) for a, b in zip(output_labels, species_texts)],
        )
        self.wait(1)

        ##### Example annotations per taxonomic rank #####
        rank_headings = [
            Text(rank, font_size=30, weight=BOLD).move_to(LEFT*2 + UP*(3-row))
            for row, rank in enumerate(("Family", "Genus", "Species"))
        ]

        example = self._annotation_column(["Sphingomonadaceae", "Sphingomonas", "S. melonis"])
        self.play(
            *[Write(t) for t in example],
            *[FadeIn(t) for t in rank_headings],
            leaves_vertices[1].animate.set_fill(WHITE),
        )
        self.wait(1)
        self.play(*[FadeOut(t) for t in example], leaves_vertices[1].animate.set_fill(BLACK))
        self.wait(1)

        example = self._annotation_column(["Pseudomonadaceae", "Pseudomonas", "P. aeruginosa"])
        self.play(*[Write(t) for t in example], leaves_vertices[3].animate.set_fill(WHITE))
        self.wait(1)
        self.play(*[FadeOut(t) for t in example], leaves_vertices[3].animate.set_fill(BLACK))
        self.wait(1)

        # Last example: no species is highlighted, the species row reads "???".
        example = self._annotation_column(["Sphingomonadaceae", "Sphingomonas", "???"])
        self.play(*[Write(t) for t in example])
        self.wait(1)

        self.play(*[FadeOut(mob) for mob in self.mobjects])

        # Zero likelihoods: only the parent nodes of each level are needed here.
        pars = assign_parents([0 for _ in leaves], leaves)
        genus, ls = zip(*pars.items())
        pars2 = assign_parents(ls, genus)
        family, _ = zip(*pars2.items())

        sp_boxes = VGroup(*[graph[l] for l in leaves])
        genus_boxes = VGroup(*[graph[b] for b in genus])
        family_boxes = VGroup(*[graph[b] for b in family])
        squares = [Square().scale(0.5).move_to(graph[l].get_center()) for l in leaves]
        boxes = VGroup(*squares)

        sp_text = Text("Species", font_size=30).next_to(sp_boxes, LEFT)
        ge_text = Text("Genus", font_size=30).next_to(genus_boxes, LEFT)
        fa_text = Text("Family", font_size=30).next_to(family_boxes, LEFT)

        ##### Graph appears #####
        self.play(FadeIn(graph))
        self.play(FadeIn(sp_text), FadeIn(ge_text), FadeIn(fa_text))
        self.wait(1)

        self.play(FadeOut(sp_text), FadeOut(ge_text), FadeOut(fa_text))
        self.play(FadeIn(boxes))
        self.wait(1)

        self.play(boxes.animate.shift(DOWN))

        ##### Logits in the boxes, softmax likelihoods on the leaves #####
        logits = [random.uniform(0, 1) for _ in squares]

        cmap = cmasher.get_sub_cmap(sns.dark_palette("#9CDCEB", as_cmap=True), 0, 1)
        # These labels are positioned relative to the boxes AFTER the shift above.
        nn_text = Text("Neural network\noutput layer", font_size=30).next_to(boxes, LEFT)
        logit_text = Text("logits", font_size=30).next_to(boxes, RIGHT)
        lik_text = Text("likelihoods", font_size=30).next_to(graph[leaves[0]], RIGHT)
        arrow = CurvedArrow(start_point=logit_text.get_right(), end_point=lik_text.get_right()).shift(RIGHT*0.2)
        softmax_text = MathTex(r"\frac{e^{l_{i}}}{\sum_{j=1}^K e^{l_{j}}}", font_size=30).move_to(arrow.get_right()).shift(RIGHT*0.6)

        color_anim = [square.animate.set_fill(RED, opacity=logit) for square, logit in zip(squares, logits)]
        text_objs = [Text(f"{l:.2f}", font_size=30).move_to(b.get_center()) for l, b in zip(logits, boxes)]
        text_anim = [Write(t) for t in text_objs]
        self.play(Write(nn_text))
        self.play(*color_anim, *text_anim)
        self.wait(1)
        self.play(Write(logit_text), Unwrite(nn_text))
        self.wait(1)

        likelihoods = softmax(logits)
        text_objs_softmax1 = [
            Text(f"{l:.2f}", font_size=15).move_to(graph[b].get_center())
            for l, b in zip(likelihoods, leaves)
        ]
        text_anim_softmax = [Write(t) for t in text_objs_softmax1]
        color_anim_softmax = [
            graph[l].animate.set_fill(rgba_to_color(cmap(likelihoods[i])))
            for i, l in enumerate(leaves)
        ]

        self.play(*color_anim_softmax, *text_anim_softmax, Write(softmax_text), Write(arrow), Write(lik_text))
        self.wait(1)

        self.play(Unwrite(logit_text), Unwrite(softmax_text), Unwrite(lik_text), Unwrite(arrow))
        self.wait(1)

        def midpoint(nodes, i):
            """Point halfway between the i-th entry of ``nodes`` and the one before it."""
            return (graph[nodes[i]].get_center() + graph[nodes[i-1]].get_center()) / 2

        plus1 = MathTex(r"+", font_size=30).move_to(midpoint(leaves, 4))
        plus2 = MathTex(r"+", font_size=30).move_to(midpoint(leaves, 3))
        plus3 = MathTex(r"+", font_size=30).move_to(midpoint(leaves, 2))
        plus4 = MathTex(r"+", font_size=30).move_to(midpoint(leaves, 1))

        equals_1 = MathTex(r"= 1", font_size=30).next_to(graph[leaves[0]], RIGHT)

        self.play(*[Write(t) for t in [plus1, plus2, plus3, plus4, equals_1]])
        self.wait(1)
        # Keep only the plus signs between leaves that share a parent.
        self.play(*[Unwrite(t) for t in [plus2, plus4, equals_1]])
        self.wait(1)

        ##### Sum the likelihoods one level up #####
        nods, likelihoods_p, text_objs_pars2, fills = self._roll_up(graph, cmap, likelihoods, leaves)
        self.play(*fills, *[Write(t) for t in text_objs_pars2], *[Unwrite(t) for t in [plus1, plus3]])

        plusp1 = MathTex(r"+", font_size=30).move_to(midpoint(nods, 2))
        plusp2 = MathTex(r"+", font_size=30).move_to(midpoint(nods, 1))

        equals_1 = MathTex(r"= 1", font_size=30).next_to(graph[nods[0]], RIGHT)

        self.play(*[Write(t) for t in [plusp1, plusp2, equals_1]])
        self.play(*[Unwrite(t) for t in [plusp1, equals_1]])

        ##### ... and one level further #####
        nods, likelihoods_p, text_objs_pars3, fills = self._roll_up(graph, cmap, likelihoods_p, nods)
        self.play(*fills, *[Write(t) for t in text_objs_pars3], Unwrite(plusp2))

        plusp1 = MathTex(r"+", font_size=30).move_to(midpoint(nods, 1))

        self.play(*[Write(t) for t in [plusp1]])

        ##### ... up to the root, which is labelled "1" #####
        nods, likelihoods_p, text_objs_pars4, fills = self._roll_up(graph, cmap, likelihoods_p, nods, label="1")
        self.play(*fills, *[Write(t) for t in text_objs_pars4], Unwrite(plusp1))
        self.wait(1)

        ##### Second graph: ground truth next to the likelihoods #####
        graph_copy = self._build_tree()
        for node in ground_truth:
            graph_copy[node].set_fill(WHITE)
        SHIFT_CONST = 3
        all_texts = text_objs + text_objs_softmax1 + text_objs_pars2 + text_objs_pars3 + text_objs_pars4
        self.play(*[FadeOut(x) for x in all_texts])
        # NOTE(review): `boxes` is both shifted and faded out in the same play call;
        # kept as in the original — confirm whether the shift is actually wanted.
        self.play(
            graph_copy.animate.move_to(LEFT*SHIFT_CONST),
            graph.animate.move_to(RIGHT*SHIFT_CONST),
            boxes.animate.shift(RIGHT*SHIFT_CONST),
            FadeOut(boxes),
        )
        gt_text = Text("Ground truth", font_size=30).next_to(graph_copy, UP)
        lh_text = Text("Likelihoods", font_size=30).next_to(graph, UP)
        self.play(FadeIn(gt_text), FadeIn(lh_text))
        self.wait(1)

        # Leaves are numbered right-to-left: the first leaf gets the highest index.
        n_leaves = len(leaves)
        probs_eq_str = ["{{p_%d}}" % (n_leaves - i) for i in range(n_leaves)]
        ys_eq_str = ["{{y_%d}}" % (n_leaves - i) for i in range(n_leaves)]

        # Each string is passed whole; unpacking it with * would hand MathTex
        # one argument per character.
        probs_eq = [MathTex(s, font_size=40).next_to(graph[leaves[i]], DOWN) for i, s in enumerate(probs_eq_str)]
        ys_eq = [MathTex(s, font_size=40).next_to(graph_copy[leaves[i]], DOWN) for i, s in enumerate(ys_eq_str)]

        probs = VGroup(*probs_eq)
        ys = VGroup(*ys_eq)
        self.play(FadeIn(probs), FadeIn(ys))
        self.wait(1)

        # The tex parts below are raw strings so that "\L" and "\cdot" are not
        # treated as (invalid) Python escape sequences; the text sent to LaTeX
        # is unchanged.
        # NOTE(review): "\L" is kept as written; if a calligraphic loss symbol
        # was intended it would be "\mathcal{L}" — confirm.

        ##### Loss at the leaf level #####
        self._play_loss_stage(
            graph, graph_copy, GT,
            symbolic_eq=[
                r"\L = -", "(", "{{y_1}}", "log(", "{{p_1}}", ") + ",
                "{{y_2}}", "log(", "{{p_2}}", ") + ",
                "{{y_3}}", "log(", "{{p_3}}", ")", " + ",
                "{{y_4}}", "log(", "{{p_4}}", ") + ",
                "{{y_5}}", "log(", "{{p_5}}", ")", ")",
            ],
            numeric_eq=[
                r"\L = -", "(", r"0\cdot", "log(", "{{p_1}}", ") + ",
                r"0\cdot", "log(", "{{p_2}}", ") + ",
                r"1\cdot", "log(", "{{p_3}}", ")", " + ",
                r"0\cdot", "log(", "{{p_4}}", ") + ",
                r"0\cdot", "log(", "{{p_5}}", ")", ")",
            ],
            kept_parts=(0, 11, 12, 13),
            reduced_eq=[r"\L = -", "log(", "{{p_3}}", ")"],
            pause_after_highlight=True,
        )

        ##### Loss one level up (parent of the ground-truth leaf) #####
        self._play_loss_stage(
            graph, graph_copy, ground_truth[1],
            symbolic_eq=[
                r"\L = -", "(",
                "{{y_{12}}}", "log(", "{{p_1 + p_2}}", ") + ",
                "{{y_{34}}}", "log(", "{{p_3 + p_4}}", ")", " + ",
                "{{y_5}}", "log(", "{{p_5}}", ")",
                ")",
            ],
            numeric_eq=[
                r"\L = -", "(",
                r"0\cdot", "log(", "{{p_1 + p_2}}", ") + ",
                r"1\cdot", "log(", "{{p_3 + p_4}}", ")", " + ",
                r"0\cdot", "log(", "{{p_5}}", ")",
                ")",
            ],
            kept_parts=(0, 7, 8, 9),
            reduced_eq=[r"\L = -", "log(", "{{p_3 + p_4}}", ")"],
        )

        ##### Loss two levels up; the scene ends on the reduced equation #####
        self._play_loss_stage(
            graph, graph_copy, ground_truth[2],
            symbolic_eq=[
                r"\L = -", "(",
                "{{y_{12}}}", "log(", "{{p_1 + p_2}}", ") + ",
                "{{y_{345}}}", "log(", "{{p_3 + p_4 + p_5}}", ")",
                ")",
            ],
            numeric_eq=[
                r"\L = -", "(",
                r"0\cdot", "log(", "{{p_1 + p_2}}", ") + ",
                r"1\cdot", "log(", "{{p_3 + p_4 + p_5}}", ")",
                ")",
            ],
            kept_parts=(0, 7, 8, 9),
            reduced_eq=[r"\L = -", "log(", "{{p_3 + p_4 + p_5}}", ")"],
            clear_afterwards=False,
        )

In [None]:
%manim TaxonomicTree

  label.set_height(0.75*neuron.get_height())
  label.set_height(0.75*neuron.get_height())
  label.shift((neuron.get_width() + label.get_width()/2)*RIGHT)


                                                                                                                                                                                  

  species_texts = [MathTex(t).set_height(output_labels[0].get_height()).next_to(leaves_vertices[i], DOWN) for i, t in enumerate(species)]
  species_texts = [MathTex(t).set_height(output_labels[0].get_height()).next_to(leaves_vertices[i], DOWN) for i, t in enumerate(species)]


                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

Animation 22: Write(Text('Neural network\noutput layer')):  88%|██████████████████████████████████████████████████████████████████████          | 105/120 [00:01<00:00, 67.01it/s]

# Tetranucleotide frequencies (TNFs)

In [312]:
ALPHABET = ('A', 'T', 'C', 'G')

def make_dna(length=20):
    """Return a random DNA string of ``length`` bases drawn with numpy from ALPHABET."""
    bases = np.random.choice(ALPHABET, size=length)
    return ''.join(bases)

In [385]:
class Features(Scene):
    """Animate counting tetranucleotides (4-mers) along a short DNA string.

    A brace slides over the sequence one base at a time; every window either
    adds a new "k-mer  count" row below the previous one or bumps the counter
    of a k-mer that was already seen.
    """

    def construct(self):
        dna_string = 'CTACTACG'
        k = 4  # window size: tetranucleotides
        t2c_map = {'A': RED, 'C': BLUE, 'G': YELLOW, 'T': GREEN}
        dna_text = Text(dna_string, t2c=t2c_map, font_size=50).shift(UP*2)
        br = Brace(dna_text[0:k], direction=DOWN)
        self.play(FadeIn(dna_text), Write(br))
        self.wait(1)

        counters = {}        # k-mer -> number of occurrences so far
        count_integers = {}  # k-mer -> Integer mobject showing that number
        prev_line = None     # label of the most recently added row

        def register(kmer):
            """Bump the on-screen count of ``kmer``, adding a new row on first sight."""
            nonlocal prev_line
            if kmer in counters:
                counters[kmer] += 1
                self.play(count_integers[kmer].animate.set_value(counters[kmer]))
                return
            counters[kmer] = 1
            counter_label = Text(kmer, font_size=50, t2c=t2c_map)
            if prev_line is not None:
                counter_label.next_to(prev_line, DOWN)
            # Same offset for every row (the first row used RIGHT, the rest
            # RIGHT*2), so that the counts line up in one column.
            counter = Integer(1, color=WHITE, font_size=50).next_to(counter_label, RIGHT*2)
            self.play(Write(counter_label), Write(counter))
            count_integers[kmer] = counter
            prev_line = counter_label

        # First window: the brace is already in place.
        register(dna_string[0:k])

        # Remaining windows: slide the brace, then update the table.
        for start in range(1, len(dna_string) - k + 1):
            brace_target = Brace(dna_text[start:start+k], direction=DOWN)
            self.play(
                ReplacementTransform(br, brace_target),
                run_time=1
            )
            br = brace_target
            register(dna_string[start:start+k])
            self.wait(1)

In [386]:
%manim Features

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

# Abundances

Note: ChatGPT wrote this part; I only supplied prompts such as "manim: generate random circles of three different colors that are first clustered and then randomly shuffled". It works like a charm.

In [606]:
from manim import *
import numpy as np
import random

class Abundances(Scene):
    
    def construct(self):
        """
        Show three strains as clusters of rings, shuffle the rings, break them into
        arcs, shuffle and recolour the arcs, and finally pack the arcs together.
        """
        # (title, colour, number of rings, placement of the cluster rectangle)
        strains = [
            ('Strain 1', BLUE, 4, LEFT * 4),
            ('Strain 2', GREEN, 7, ORIGIN),
            ('Strain 3', RED, 2, RIGHT * 4),
        ]

        # Ring geometry and the screen area the shuffles may use
        circle_radius = 0.3
        stroke_width = 8
        min_distance = circle_radius * 2 + 0.1  # diameter plus a small gap
        x_lim = 6
        y_lim = 3

        # One list of unfilled rings per strain, in strain order
        circle_groups = [
            [
                Circle(radius=circle_radius, color=colour, fill_opacity=0, stroke_width=stroke_width)
                for _ in range(count)
            ]
            for _title, colour, count, _offset in strains
        ]
        all_circles = [circle for group in circle_groups for circle in group]

        # Step 1: one labelled rectangle per strain
        rects = [
            Rectangle(width=3, height=2, color=colour).shift(offset)
            for _title, colour, _count, offset in strains
        ]
        titles = [
            Text(title, font_size=50, color=colour).next_to(rect, UP)
            for (title, colour, _count, _offset), rect in zip(strains, rects)
        ]
        self.play(*[Create(rect) for rect in rects], *[Write(title) for title in titles])

        # Step 2: place each strain's rings inside its rectangle
        for rect, group in zip(rects, circle_groups):
            spots = self.generate_positions_from_gaussian_in_rect(rect, len(group), circle_radius)
            for circle, spot in zip(group, spots):
                circle.move_to(spot)

        self.play(*[Create(circle) for circle in all_circles])
        self.wait(1)

        # Step 3: remove rectangles and titles, leaving only the rings
        self.play(*[FadeOut(rect) for rect in rects], *[Unwrite(title) for title in titles])
        self.wait(1)

        def scatter(mobjects):
            """Move ``mobjects`` to freshly sampled, mutually separated positions."""
            targets = self.generate_non_overlapping_positions(
                len(mobjects), np.array([0, 0]), 2, min_distance, x_lim, y_lim
            )
            self.play(
                *[mob.animate.move_to(pos[0] * RIGHT + pos[1] * UP) for mob, pos in zip(mobjects, targets)],
                run_time=2,
            )

        # Step 4: shuffle the rings across the screen
        scatter(all_circles)
        self.wait(2)

        # Step 5: replace every ring by its arcs
        all_arcs = self.split_and_move_circles(all_circles, circle_radius, stroke_width)

        # Step 6: whiten the arcs, then shuffle them
        self.play(*[arc.animate.set_color(WHITE) for arc in all_arcs])
        scatter(all_arcs)

        # Step 7: recolour the arcs by type, in stages
        self.color_arcs_in_stages(all_arcs)
        self.wait(1)

        # Step 8: pack the arcs together
        self.pack_arcs_on_left(all_arcs)
        self.wait(2)

    def generate_positions_from_gaussian_in_rect(self, rectangle, n_circles, circle_radius):
        """
        Rejection-sample ``n_circles`` centres around the middle of ``rectangle``.

        Offsets are drawn from a Gaussian whose standard deviation is one sixth of
        the rectangle's width/height; a draw is kept only if a circle of
        ``circle_radius`` centred there lies fully inside the rectangle.
        Returns a list of points (rectangle centre plus accepted offset).
        """
        spread = [rectangle.width / 6, rectangle.height / 6]
        # Largest allowed |offset| per axis so the circle does not cross the border
        max_dx = rectangle.width / 2 - circle_radius
        max_dy = rectangle.height / 2 - circle_radius

        accepted = []
        while len(accepted) < n_circles:
            dx, dy = np.random.normal([0, 0], spread, 2)
            if abs(dx) <= max_dx and abs(dy) <= max_dy:
                accepted.append(rectangle.get_center() + dx * RIGHT + dy * UP)

        return accepted

    def generate_non_overlapping_positions(self, n, mean, std, min_distance, x_lim, y_lim):
        """
        Rejection-sample ``n`` 2D points from a Gaussian (``mean``, ``std``).

        A candidate is kept only if it lies strictly inside the box
        |x| < ``x_lim``, |y| < ``y_lim`` and is at least ``min_distance`` away
        from every point kept so far. Returns the kept points as a list of arrays.
        """
        placed = []
        while len(placed) < n:
            candidate = np.random.normal(mean, std, 2)

            on_screen = abs(candidate[0]) < x_lim and abs(candidate[1]) < y_lim
            if not on_screen:
                continue

            far_enough = all(
                np.linalg.norm(candidate - other) >= min_distance for other in placed
            )
            if far_enough:
                placed.append(candidate)
        return placed

    def split_and_move_circles(self, circles, radius, stroke_width):
        """
        Replace every circle by a ring of arcs and cross-fade from circles to arcs.

        The number of arcs depends on the circle's colour (BLUE: 2, GREEN: 8, RED: 4).
        ``radius`` and ``stroke_width`` are kept for interface compatibility but are
        not used here; the arcs are built by ``split_circle_into_arcs`` from the
        circle itself.

        Returns a flat list of all arcs, in circle order.
        Raises ValueError for a circle whose colour has no arc count assigned.
        """
        arcs_per_color = ((BLUE, 2), (GREEN, 8), (RED, 4))
        all_arcs = []

        for circle in circles:
            circle_color = circle.get_color()
            n_arcs = next(
                (count for known, count in arcs_per_color if circle_color == known),
                None,
            )
            if n_arcs is None:
                # Without this guard an unknown colour silently reused the arc count
                # of the previous circle (or hit an unbound local on the first one).
                raise ValueError(f"no arc count defined for circle colour {circle_color!r}")

            all_arcs.extend(self.split_circle_into_arcs(circle, n_arcs))

        # Animate the arcs appearing and the circles disappearing
        self.play(FadeOut(VGroup(*circles)), FadeIn(VGroup(*all_arcs)))
        self.wait(1)

        return all_arcs

    def split_circle_into_arcs(self, ring, n_arcs):
        """
        Build ``n_arcs`` equal arcs that together span the full turn of ``ring``.

        Every arc copies the ring's radius, stroke width and colour, and is moved
        so that its centre sits on the bisector of its angular range, at 1.1 times
        the ring radius from the ring's centre. Returns the arcs in angular order.
        """
        sweep = TAU / n_arcs  # angle covered by a single arc
        pieces = []

        begin = 0  # start angle of the arc being built, accumulated arc by arc
        for _ in range(n_arcs):
            piece = Arc(
                radius=ring.radius,
                start_angle=begin,
                angle=sweep,
                stroke_width=ring.stroke_width,
                color=ring.get_color(),
            )

            bisector = begin + sweep / 2
            outward = np.array([np.cos(bisector), np.sin(bisector), 0])
            piece.move_to(ring.get_center() + outward * ring.radius * 1.1)

            pieces.append(piece)
            begin += sweep

        return pieces

    def color_arcs_in_stages(self, all_arcs):
        """
        Color arcs by type, where a type is the pair (arc.angle, arc.start_angle):
        - First color one type, wait.
        - Then color a second type, wait.
        - Then color all the remaining types at once.

        Types are taken in order of first appearance in ``all_arcs`` and receive
        colors from a fixed seven-color palette that repeats when there are more
        types than colors. Stages for which no arcs exist are skipped, so the
        method also works with fewer than three types (including no arcs at all).
        """
        # Group arcs by type, keeping first-appearance order
        arc_groups = {}
        for arc in all_arcs:
            arc_groups.setdefault((arc.angle, arc.start_angle), []).append(arc)

        colors = [BLUE, RED, GREEN, YELLOW, ORANGE, PURPLE, PINK]
        type_color_map = {
            arc_type: colors[i % len(colors)] for i, arc_type in enumerate(arc_groups)
        }

        arc_types = list(arc_groups)

        # The first two types get a stage of their own, each followed by a pause
        for staged_type in arc_types[:2]:
            self.play(*[arc.animate.set_color(type_color_map[staged_type]) for arc in arc_groups[staged_type]])
            self.wait(1)

        # Everything else is colored in one go; skipped when nothing is left,
        # because playing an empty list of animations is an error
        remaining = [
            (arc, type_color_map[arc_type])
            for arc_type in arc_types[2:]
            for arc in arc_groups[arc_type]
        ]
        if remaining:
            self.play(*[arc.animate.set_color(color) for arc, color in remaining])

    def pack_arcs_on_left(self, all_arcs):
        """
        Pack arcs into rows by type, collapse each row into a counted
        representative, then zoom on one arc and tally 4-mers of a DNA string.

        The animation runs in this order:
        1. Group arcs by ``(arc.angle, arc.start_angle)`` and move each group
           onto its own row, starting top-left. Once a row would start below
           the bottom bound, rows restart from the top in a right-hand column.
        2. For every group with more than one arc, stack the arcs onto the
           group's first arc and write the group size next to it.
        3. Frame the first count and show "Feature 1: abundance vector".
        4. Move one representative arc to the origin (scaled x5), remove the
           other arcs and all counts, lay a DNA string along the arc, fade the
           arc out and straighten the string.
        5. Slide a brace over every 4-letter window of the string while
           tallying each window, then frame the tallies and show
           "Feature 2: tetranucleotide frequencies".

        If no group holds more than one arc there is nothing to collapse or
        zoom into, and the method returns after step 1.

        Parameters:
        all_arcs (list): arc mobjects exposing ``angle`` and ``start_angle``.
        """
        left_bound = -4  # x where rows start in the left column
        right_bound = 2  # x where rows start in the right column
        bottom_bound = -3  # a row that would start below this y wraps to the right column
        top_bound = 3  # y of the first row in each column
        row_spacing_y = 0.7  # vertical distance between rows
        spacing_x = 0.7  # horizontal distance between arcs in a row
        kmer_size = 4  # window length for the tetranucleotide tally

        current_y = top_bound
        current_x = left_bound
        use_right_side = False

        # Classify arcs by their type (length and orientation)
        arc_groups = {}
        for arc in all_arcs:
            arc_groups.setdefault((arc.angle, arc.start_angle), []).append(arc)

        # --- Step 1: one row per arc type ---------------------------------
        actions = []
        for arcs in arc_groups.values():
            # Out of vertical room: continue from the top of the right column
            if current_y < bottom_bound:
                use_right_side = True
                current_x = right_bound
                current_y = top_bound

            for arc in arcs:
                actions.append(arc.animate.move_to(np.array([current_x, current_y, 0])))
                current_x += spacing_x

            # Next row, back at the start of the active column
            current_y -= row_spacing_y
            current_x = right_bound if use_right_side else left_bound

        # Scene.play needs at least one animation
        if actions:
            self.play(*actions)
        self.wait(1)

        # --- Step 2: collapse each multi-arc group and label it ------------
        collapse_actions = []
        text_actions = []
        count_texts = []
        first_integer = None  # count label of the first collapsed group
        featured_arc = None  # arc that is zoomed in step 4
        for arcs in arc_groups.values():
            if len(arcs) < 2:
                continue
            anchor = arcs[0]  # the arc the rest of the group is stacked on

            for arc in arcs[1:]:
                collapse_actions.append(arc.animate.move_to(anchor.get_center()))

            count_text = Text(str(len(arcs)), font_size=24)
            count_text.next_to(anchor, RIGHT)
            count_texts.append(count_text)
            text_actions.append(Write(count_text))

            if first_integer is None:
                first_integer = count_text
            # NOTE(review): this is overwritten on every pass, so the zoomed arc
            # is the anchor of the LAST collapsed group while the framed count
            # belongs to the FIRST one. Kept as-is; confirm this is intended.
            featured_arc = anchor

        if featured_arc is None:
            # Every type occurs once: nothing to collapse, count or zoom into
            return

        self.play(*collapse_actions)
        self.play(*text_actions)

        # --- Step 3: highlight the first count ------------------------------
        yellow_square = SurroundingRectangle(first_integer, color=YELLOW)
        feature_text = Text("Feature 1: abundance vector", font_size=36).to_corner(DR)

        self.play(FadeIn(yellow_square), FadeIn(feature_text))
        self.wait(1.5)
        self.play(FadeOut(yellow_square), FadeOut(feature_text))

        # --- Step 4: zoom on the featured arc, clear everything else --------
        other_arcs = [
            arc
            for arcs in arc_groups.values()
            for arc in arcs
            if arc is not featured_arc
        ]
        self.play(
            featured_arc.animate.move_to(ORIGIN).scale(5),
            FadeOut(VGroup(*other_arcs)),
            Unwrite(VGroup(*count_texts)),
        )

        arc = featured_arc

        t2c_map = {'A': RED, 'C': BLUE, 'G': YELLOW, 'T': GREEN}
        DNA_STRING = "CTACTACG"

        # Lay the letters along the arc, one per evenly spaced proportion
        dna_string = Text(DNA_STRING, font_size=50, t2c=t2c_map)
        num_letters = len(dna_string)
        for i, letter in enumerate(dna_string):
            proportion = i / (num_letters - 1)
            position_on_arc = arc.point_from_proportion(proportion)
            # Read the attribute directly: arc.get_angle() only exists through
            # Mobject's get_* compatibility shim, which emits a warning
            tangent_angle = arc.angle * proportion + arc.start_angle

            letter.move_to(position_on_arc)
            letter.rotate(tangent_angle - PI/2, about_point=letter.get_center())

        self.play(FadeIn(dna_string))

        # Remove the arc while keeping the text in place
        self.play(FadeOut(arc))

        # Flatten the curved text into a straight line above the arc's center
        flattened_text = Text(DNA_STRING, font_size=50, t2c=t2c_map).move_to(arc.get_center()).shift(UP*2)
        self.play(Transform(dna_string, flattened_text))
        self.wait(2)

        # --- Step 5: sliding-window tally -----------------------------------
        dna_text = flattened_text

        br = Brace(dna_text[0:kmer_size], direction=DOWN)
        self.play(FadeIn(dna_text), Write(br))
        self.wait(1)

        # First window: its tally row is created at the label's default position
        first_kmer = DNA_STRING[0:kmer_size]
        counters = {first_kmer: 1}
        counter_label = Text(first_kmer, font_size=50, t2c=t2c_map)
        counter = Integer(1, color=WHITE, font_size=50).next_to(counter_label, RIGHT*2)
        self.play(Write(counter_label), Write(counter))
        prev_line = counter_label
        count_integers = {first_kmer: counter}

        for i in range(1, len(DNA_STRING) - kmer_size + 1):
            kmer = DNA_STRING[i:i+kmer_size]

            # Slide the brace to the current window
            brace_target = Brace(dna_text[i:i+kmer_size], direction=DOWN)
            self.play(
                ReplacementTransform(br, brace_target),
                run_time=1
            )

            if kmer in counters:
                # Seen before: bump the number already on screen
                counters[kmer] += 1
                self.play(count_integers[kmer].animate.set_value(counters[kmer]))
            else:
                # New window: add a tally row under the previous one
                counters[kmer] = 1
                counter_label = Text(kmer, font_size=50, t2c=t2c_map).next_to(prev_line, DOWN)
                counter = Integer(1, color=WHITE, font_size=50).next_to(counter_label, RIGHT*2)
                prev_line = counter_label
                self.play(Write(counter_label), Write(counter))
                count_integers[kmer] = counter
            self.wait(1)
            br = brace_target

        yellow_square = SurroundingRectangle(VGroup(*list(count_integers.values())), color=YELLOW)
        feature_text = Text("Feature 2: tetranucleotide frequencies", font_size=36).to_corner(DR)

        self.play(FadeIn(yellow_square), FadeIn(feature_text))
        self.wait(1.5)
        self.play(FadeOut(yellow_square), FadeOut(feature_text))

In [607]:
%manim Abundances

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

  tangent_angle = arc.get_angle() * proportion + arc.start_angle


                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

                                                                                                                                                                                  

# Neural network

In [603]:
class NeuralNetworkMobject(VGroup):
    # Default settings for the drawing. __init__ sets one instance attribute
    # per key, and the helper methods read those attributes.
    CONFIG = {
        "neuron_radius": 0.15,  # radius of each neuron Circle; also passed as `buff` to every edge
        "neuron_to_neuron_buff": MED_SMALL_BUFF,  # buff used when stacking neurons inside a layer
        "layer_to_layer_buff": LARGE_BUFF,  # buff used when arranging layers left to right
        "output_neuron_color": WHITE,  # stroke color of last-layer neurons (via get_nn_fill_color)
        "input_neuron_color": WHITE,  # stroke color of first-layer neurons
        "hidden_layer_neuron_color": WHITE,  # stroke color of all other neurons
        "neuron_stroke_width": 2,
        "neuron_fill_color": GREEN,  # NOTE(review): not read by the methods shown here; get_layer fills with BLACK
        "edge_color": LIGHT_GREY,
        "edge_stroke_width": 0.5,
        "edge_propogation_color": YELLOW,  # NOTE(review): not read by the methods shown here
        "edge_propogation_time": 1,  # NOTE(review): not read by the methods shown here
        "max_shown_neurons": 16,  # layers larger than this are drawn truncated with vertical dots
        "brace_for_large_layers": True,  # add a brace showing the true size of a truncated layer
        "average_shown_activation_of_large_layer": True,  # NOTE(review): not read by the methods shown here
        "include_output_labels": False,  # when True, add_neurons also labels the output layer
        "arrow": False,  # draw edges as Arrow instead of Line
        "arrow_tip_size": 0.1,  # tip_length for Arrow edges
        "left_size": 1,  # NOTE(review): not read by the methods shown here
        "neuron_fill_opacity": 1
    }
    # Constructor with parameters of the neurons in a list
    def __init__(self, neural_network, *args, **kwargs):
        """
        Build the network drawing for the given layer sizes.

        Parameters:
        neural_network: sequence with one neuron count per layer; the first
            entry is treated as the input layer and the last as the output.
        *args: forwarded to ``VGroup.__init__``.
        **kwargs: a keyword whose name matches a ``CONFIG`` key overrides that
            default for this instance (e.g. ``arrow=True``); every other
            keyword is forwarded to ``VGroup.__init__``.
        """
        # Take per-instance overrides out of kwargs before VGroup sees them
        overrides = {key: kwargs.pop(key) for key in self.CONFIG if key in kwargs}
        VGroup.__init__(self, *args, **kwargs)

        # Expose every setting as an attribute, preferring the override
        for key, default in self.CONFIG.items():
            setattr(self, key, overrides.get(key, default))

        self.layer_sizes = neural_network
        self.add_neurons()
        self.add_edges()
        self.add_to_back(self.layers)

    # Helper method for constructor
    def add_neurons(self):
        """
        Create one layer group per entry of ``self.layer_sizes``, arrange the
        layers left to right and store them on ``self.layers``.

        When ``self.include_output_labels`` is set, the output neurons are
        additionally labelled with their 1-based position.
        """
        layers = VGroup(*[
            self.get_layer(size, index)
            for index, size in enumerate(self.layer_sizes)
        ])
        layers.arrange_submobjects(RIGHT, buff=self.layer_to_layer_buff)
        self.layers = layers
        if self.include_output_labels:
            # label_outputs_text needs one label per output neuron; calling it
            # with no argument raised TypeError. Use the 1-based index, like
            # the other labelling helpers of this class.
            default_labels = [
                str(n + 1) for n in range(len(layers[-1].neurons))
            ]
            self.label_outputs_text(default_labels)
    # Helper method for constructor
    def get_nn_fill_color(self, index):
        """
        Pick the neuron color for the layer at position ``index``.

        ``-1`` or the last valid position selects the output color, ``0``
        selects the input color, anything else the hidden-layer color. The
        output check comes first, so a single-layer network counts as output.
        """
        last_position = len(self.layer_sizes) - 1
        if index in (-1, last_position):
            return self.output_neuron_color
        return (
            self.input_neuron_color
            if index == 0
            else self.hidden_layer_neuron_color
        )
    # Helper method for constructor
    def get_layer(self, size, index=-1):
        """
        Build the group for one layer.

        At most ``self.max_shown_neurons`` circles are drawn, stacked
        downwards. Each circle gets empty ``edges_in`` / ``edges_out`` groups,
        and the circles are reachable as ``layer.neurons``.

        When ``size`` exceeds the number of drawn circles, vertical dots are
        placed in the middle of the column with the two halves of the circles
        above and below them (``layer.dots``). If
        ``self.brace_for_large_layers`` is set, a brace on the left labelled
        with ``size`` is added too (``layer.brace``, ``layer.brace_label``).

        Parameters:
        size: number of neurons the layer represents.
        index: layer position, only used to choose the stroke color.

        Returns:
        VGroup: the assembled layer.
        """
        shown = min(size, self.max_shown_neurons)
        stroke = self.get_nn_fill_color(index)

        # Draw the visible neurons and give each its edge bookkeeping groups
        neurons = VGroup()
        for _ in range(shown):
            circle = Circle(
                radius=self.neuron_radius,
                stroke_color=stroke,
                stroke_width=self.neuron_stroke_width,
                fill_color=BLACK,
                fill_opacity=self.neuron_fill_opacity,
            )
            circle.edges_in = VGroup()
            circle.edges_out = VGroup()
            neurons.add(circle)
        neurons.arrange_submobjects(DOWN, buff=self.neuron_to_neuron_buff)

        layer = VGroup()
        layer.neurons = neurons
        layer.add(neurons)

        # Nothing was cut off: the plain column is the whole layer
        if size <= shown:
            return layer

        # Truncated layer: open a gap in the middle and put the dots there
        dots = MathTex("\\vdots")
        dots.move_to(neurons)
        half = len(neurons) // 2
        VGroup(*neurons[:half]).next_to(dots, UP, MED_SMALL_BUFF)
        VGroup(*neurons[half:]).next_to(dots, DOWN, MED_SMALL_BUFF)
        layer.dots = dots
        layer.add(dots)

        if self.brace_for_large_layers:
            brace = Brace(layer, LEFT)
            brace_label = brace.get_tex(str(size))
            layer.brace = brace
            layer.brace_label = brace_label
            layer.add(brace, brace_label)

        return layer
    # Helper method for constructor
    def add_edges(self):
        """
        Fully connect every pair of neighbouring layers.

        One group of edges is created per layer pair and collected in
        ``self.edge_groups``, which is then added to the back of this mobject.
        Each edge is also registered in the ``edges_out`` group of its source
        neuron and the ``edges_in`` group of its target neuron.
        """
        self.edge_groups = VGroup()
        neighbours = zip(self.layers[:-1], self.layers[1:])
        for left_layer, right_layer in neighbours:
            bundle = VGroup()
            for source in left_layer.neurons:
                for target in right_layer.neurons:
                    edge = self.get_edge(source, target)
                    bundle.add(edge)
                    source.edges_out.add(edge)
                    target.edges_in.add(edge)
            self.edge_groups.add(bundle)
        self.add_to_back(self.edge_groups)
    # Helper method for constructor
    def get_edge(self, neuron1, neuron2):
        """
        Create the connection drawn from ``neuron1`` to ``neuron2``.

        The edge runs between the two neuron centers with ``buff`` equal to
        the neuron radius. It is an ``Arrow`` (tip length
        ``self.arrow_tip_size``) when ``self.arrow`` is truthy, otherwise a
        plain ``Line``.
        """
        endpoints = (neuron1.get_center(), neuron2.get_center())
        shared_style = dict(
            buff=self.neuron_radius,
            stroke_color=self.edge_color,
            stroke_width=self.edge_stroke_width,
        )
        if not self.arrow:
            return Line(*endpoints, **shared_style)
        return Arrow(*endpoints, tip_length=self.arrow_tip_size, **shared_style)
    
    # Labels each input neuron with a char l or a LaTeX character
    def label_inputs(self, l):
        """
        Put a label ``l_{k}`` (k counted from 1) on every first-layer neuron.

        Each label is scaled to 0.3 x the neuron height and centered on the
        neuron. The labels are stored in ``self.output_labels`` (replacing
        whatever that attribute held before) and added to this mobject.

        Parameters:
        l: symbol placed before the subscript; inserted into the TeX string
            as-is.
        """
        labels = self.output_labels = VGroup()
        for position, neuron in enumerate(self.layers[0].neurons, start=1):
            label = MathTex(f"{l}_{{{position}}}")
            label.set_height(0.3 * neuron.get_height())
            label.move_to(neuron)
            labels.add(label)
        self.add(labels)

    # Labels each output neuron with a char l or a LaTeX character
    def label_outputs(self, l):
        """
        Put a label ``l_{k}`` (k counted from 1) on every last-layer neuron.

        Each label is scaled to 0.4 x the neuron height and centered on the
        neuron. The labels are stored in ``self.output_labels`` and added to
        this mobject.

        Parameters:
        l: symbol placed before the subscript; inserted into the TeX string
            as-is.
        """
        labels = self.output_labels = VGroup()
        for position, neuron in enumerate(self.layers[-1].neurons, start=1):
            label = MathTex(f"{l}_{{{position}}}")
            label.set_height(0.4 * neuron.get_height())
            label.move_to(neuron)
            labels.add(label)
        self.add(labels)

    # Labels each neuron in the output layer with text according to an output list
    def label_outputs_text(self, outputs):
        """
        Write ``outputs[k]`` to the right of the k-th last-layer neuron.

        Each label is scaled to 0.75 x the neuron height, centered on the
        neuron and then shifted right by the neuron width plus half the label
        width. The labels are stored in ``self.output_labels`` and added to
        this mobject.

        Parameters:
        outputs: indexable collection of TeX strings, one per output neuron.
        """
        labels = self.output_labels = VGroup()
        for position, neuron in enumerate(self.layers[-1].neurons):
            label = MathTex(outputs[position])
            label.set_height(0.75*neuron.get_height())
            label.move_to(neuron)
            # Push the label clear of the neuron on its right-hand side
            sideways = neuron.get_width() + label.get_width()/2
            label.shift(sideways*RIGHT)
            labels.add(label)
        self.add(labels)

    # Labels the hidden layers with a char l or a LaTeX character
    def label_hidden_layers(self, l):
        """
        Put a label ``l_{k}`` on every neuron of every hidden layer.

        Hidden layers are all layers except the first and the last. ``k``
        counts from 1 and restarts in each layer. Each label is scaled to
        0.4 x the neuron height and centered on the neuron. The labels are
        stored in ``self.output_labels`` and added to this mobject.

        Parameters:
        l: symbol placed before the subscript; inserted into the TeX string
            as-is.
        """
        self.output_labels = VGroup()
        for layer in self.layers[1:-1]:
            for n, neuron in enumerate(layer.neurons):
                # Brace the index: without braces TeX subscripts only the
                # first digit, so "h_12" would render as h-sub-1 followed by 2
                label = MathTex(f"{l}_{{{n + 1}}}")
                label.set_height(0.4 * neuron.get_height())
                label.move_to(neuron)
                self.output_labels.add(label)
        self.add(self.output_labels)