In [None]:
# Install if needed
!pip install manim
!pip install transformers torch matplotlib seaborn

from manim import *
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
def extract_attention(text, model, tokenizer):
    """
    Tokenize `text`, run it through `model`, and return its attention weights.

    Args:
        text: input string handed to `tokenizer`.
        model: callable that accepts the tokenizer output as keyword
            arguments and returns an object with an `attentions` attribute.
        tokenizer: callable tokenizer that also provides
            `convert_ids_to_tokens`.

    Returns:
        tokens: token strings of the first sequence in the batch
        attentions: the model's `outputs.attentions`, unchanged
        input_ids: the `input_ids` entry of the tokenizer output

    Raises:
        ValueError: if the model returned no attentions. Failing here gives a
            clearer message than the later `None[...]` indexing error.
    """
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    attentions = outputs.attentions
    if attentions is None:
        raise ValueError(
            "Model returned no attention weights; load it with "
            "output_attentions=True in its config."
        )
    return tokens, attentions, inputs["input_ids"]


In [None]:
# Load the tokenizer and model with attention outputs enabled in the config,
# so that `outputs.attentions` is populated by `extract_attention`.
# NOTE(review): `from manim import *` in the first cell presumably also binds a
# name `config`; this assignment rebinds it to the transformers config —
# confirm nothing later in the kernel expects Manim's object under that name.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, output_attentions=True)
model = AutoModel.from_pretrained(model_name, config=config)
# Switch the module to evaluation mode for inference.
model.eval()


In [None]:
# Run the example sentence through the model and collect its attention weights.
text = "The cat that I love sat on the mat."
tokens, attentions, input_ids = extract_attention(text, model, tokenizer)

# Choose layer/head you want to animate
layer_idx = 2
head_idx = 5
# Index the chosen layer, then batch item 0 and the chosen head.
# Presumably this leaves a (seq_len, seq_len) matrix — the cells below label
# both axes with `tokens`; confirm against the model's attention layout.
matrix = attentions[layer_idx][0, head_idx].cpu().numpy()

# Save for Manim: the %%manim cells below read this file by name.
np.savez("attention_data.npz", tokens=np.array(tokens), matrix=matrix,
         layer_idx=layer_idx, head_idx=head_idx, text=text)
print("Saved data for Manim.")


In [None]:
# Static reference heatmap of the selected head, with every cell annotated
# to two decimals; both axes are labelled with the token strings.
# NOTE(review): rows are presumably the attending tokens and columns the
# attended-to tokens — confirm against the model's attention layout.
plt.figure(figsize=(8,6))
sns.heatmap(matrix, xticklabels=tokens, yticklabels=tokens,
            cmap='Blues', annot=True, fmt='.2f')
# Slant the column labels so neighbouring tokens do not overlap.
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.title(f"Layer {layer_idx}, Head {head_idx}")
plt.tight_layout()
plt.show()


In [None]:
%%manim -ql -v WARNING AttentionHeadScene

from manim import *
import numpy as np

class AttentionHeadScene(Scene):
    """
    Draw one attention head as an annotated heatmap with token labels and a colorbar.

    Reads ``attention_data.npz`` from the working directory. Keys used:
    ``tokens`` and ``matrix`` (required); ``layer_idx``, ``head_idx`` and
    ``text`` (optional, defaulting to -1, -1 and "").
    """

    def construct(self):
        # Imported under an alias so that a notebook-level variable named
        # `config` cannot shadow Manim's settings object.
        from manim import config as manim_config

        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it once the arrays have been read.
        # NOTE(review): allow_pickle=True is kept from the original; only load
        # archives you created yourself.
        with np.load("attention_data.npz", allow_pickle=True) as data:
            stored = set(data.files)
            # Convert NumPy strings -> Python strings
            tokens = [str(t) for t in data["tokens"].tolist()]
            matrix = data["matrix"]
            # Membership test on `.files` instead of NpzFile.get(), which is
            # not available on every NumPy version.
            layer_idx = int(data["layer_idx"]) if "layer_idx" in stored else -1
            head_idx = int(data["head_idx"]) if "head_idx" in stored else -1
            text = str(data["text"]) if "text" in stored else ""

        seq_len = len(tokens)
        cell_size = 0.6
        pitch = cell_size + 0.05            # centre-to-centre distance between cells
        half_extent = seq_len * pitch / 2   # half of the grid's width and height

        def cell_center(row, col):
            """Centre of cell (row, col); the grid is first built around the origin."""
            x = (col - (seq_len - 1) / 2) * pitch
            y = -(row - (seq_len - 1) / 2) * pitch
            return np.array([x, y, 0.0])

        # --- Title / subtitle ----------------------------------------------
        title = Text(f"Layer {layer_idx}, Head {head_idx}", font_size=36)
        title.to_edge(UP)
        self.add(title)

        # The subtitle is positioned now (so space can be reserved for it)
        # but only written at the end of the scene.
        subtitle = None
        if text:
            subtitle = Text(text, font_size=24).next_to(title, DOWN, buff=0.3)
        header = title if subtitle is None else subtitle
        header_bottom = header.get_bottom()[1]

        # --- Token labels (top & left) -------------------------------------
        col_labels = VGroup()
        row_labels = VGroup()

        for j, tok in enumerate(tokens):
            lbl = Text(tok, font_size=20)
            # Anchor above the grid's top edge so the cells cannot cover it.
            top_of_column = np.array([cell_center(0, j)[0], half_extent, 0.0])
            lbl.next_to(top_of_column, UP, buff=0.15)
            col_labels.add(lbl)

        for i, tok in enumerate(tokens):
            lbl = Text(tok, font_size=20)
            # Left of the grid, vertically aligned with row i.
            lbl.move_to(np.array([-(half_extent + 1.2), cell_center(i, 0)[1], 0.0]))
            row_labels.add(lbl)

        # --- Heatmap cells with numbers ------------------------------------
        cells = VGroup()
        min_val = 0.0
        max_val = 1.0  # attention is already in [0,1]

        def value_to_color(v):
            # 0 -> white, 1 -> blue; clipped so stray values stay in range
            alpha = (v - min_val) / (max_val - min_val + 1e-8)
            return interpolate_color(WHITE, BLUE, float(np.clip(alpha, 0.0, 1.0)))

        for i in range(seq_len):
            for j in range(seq_len):
                v = float(matrix[i, j])
                pos = cell_center(i, j)

                sq = Square(side_length=cell_size)
                sq.set_fill(value_to_color(v), opacity=0.9)
                sq.set_stroke(color=BLACK, width=0.5)
                sq.move_to(pos)

                # numeric label, coloured for contrast against the fill
                num = Text(f"{v:.2f}", font_size=18)
                num.move_to(pos)
                num.set_color(WHITE if v > 0.6 else BLACK)

                cells.add(VGroup(sq, num))

        # --- Colorbar legend on the right ----------------------------------
        bar_height = seq_len * pitch
        bar_width = 0.3
        bar_center = np.array([half_extent + 1.0, 0.0, 0.0])

        # The bar is a stack of thin strips. Each strip is exactly one step
        # tall and centred on its own slot, so the strips tile the bar
        # without gaps and the end ticks sit on the bar's ends.
        n_steps = 20
        step_height = bar_height / n_steps
        bar_rects = VGroup()
        for k in range(n_steps):
            r = Rectangle(height=step_height, width=bar_width)
            r.set_fill(interpolate_color(WHITE, BLUE, k / (n_steps - 1)), opacity=0.9)
            r.set_stroke(width=0)
            y = (k + 0.5) * step_height - bar_height / 2
            r.move_to(bar_center + np.array([0.0, y, 0.0]))
            bar_rects.add(r)

        # ticks & labels at 0, 0.5 and 1
        ticks = VGroup()
        labels = VGroup()
        for val, frac in [(0.0, 0.0), (0.5, 0.5), (1.0, 1.0)]:
            y = (frac - 0.5) * bar_height
            tick = Line(
                bar_center + np.array([-bar_width / 2, y, 0]),
                bar_center + np.array([-bar_width / 2 - 0.1, y, 0]),
                stroke_width=2,
            )
            lbl = Text(f"{val:.1f}", font_size=20)
            lbl.next_to(tick, LEFT, buff=0.1)
            ticks.add(tick)
            labels.add(lbl)

        colorbar_group = VGroup(bar_rects, ticks, labels)

        # --- Fit everything into the area below the header -----------------
        # The grid grows with the number of tokens, so shrink the whole
        # figure (never enlarge it) until it fits, then centre it in the
        # free area under the title/subtitle.
        heatmap = VGroup(col_labels, row_labels, cells, colorbar_group)
        margin = 0.3
        top_limit = header_bottom - margin
        bottom_limit = -manim_config.frame_height / 2 + margin
        fit = min(
            1.0,
            (top_limit - bottom_limit) / heatmap.height,
            (manim_config.frame_width - 2 * margin) / heatmap.width,
        )
        heatmap.scale(fit)
        heatmap.move_to(np.array([0.0, (top_limit + bottom_limit) / 2, 0.0]))

        # --- Show -----------------------------------------------------------
        self.add(col_labels, row_labels)
        self.play(FadeIn(cells), run_time=1.0)
        self.play(FadeIn(colorbar_group), run_time=0.8)

        # --- Optional: subtitle with sentence ------------------------------
        if subtitle is not None:
            self.play(Write(subtitle), run_time=0.8)

        self.wait(2)


The Manim Jupyter extension (`manim.utils.ipython_magic`) takes the rendered video and embeds it directly in the cell output using an HTML `<video>` tag.

In [None]:
%%manim -ql -v WARNING TokenGraphScene

from manim import *
import numpy as np

class TokenGraphScene(Scene):
    """
    Draw the tokens on a circle and animate attention edges between them.

    Reads ``attention_data.npz`` from the working directory. Keys used:
    ``tokens`` and ``matrix`` (required); ``layer_idx``, ``head_idx`` and
    ``text`` (optional, defaulting to -1, -1 and "").
    Edges i -> j are drawn for off-diagonal weights above 0.05, strongest first.
    """

    def construct(self):
        # Imported under an alias so that a notebook-level variable named
        # `config` cannot shadow Manim's settings object.
        from manim import config as manim_config

        # Load single-layer/head attention data. The context manager closes
        # the archive that np.load would otherwise keep open.
        # NOTE(review): allow_pickle=True is kept from the original; only load
        # archives you created yourself.
        with np.load("attention_data.npz", allow_pickle=True) as data:
            stored = set(data.files)
            tokens = [str(t) for t in data["tokens"].tolist()]
            matrix = data["matrix"]
            # Membership test on `.files` instead of NpzFile.get(), which is
            # not available on every NumPy version.
            layer_idx = int(data["layer_idx"]) if "layer_idx" in stored else -1
            head_idx = int(data["head_idx"]) if "head_idx" in stored else -1
            text = str(data["text"]) if "text" in stored else ""

        seq_len = len(tokens)

        # Title
        title = Text(f"Token Graph — Layer {layer_idx}, Head {head_idx}", font_size=36)
        title.to_edge(UP)
        self.add(title)

        header = title
        if text:
            subtitle = Text(text, font_size=24).next_to(title, DOWN, buff=0.3)
            self.play(Write(subtitle), run_time=0.8)
            header = subtitle

        # Place tokens on a circle that fits in the free area below the
        # header, so the topmost tokens do not collide with title/subtitle.
        margin = 0.5
        top_limit = header.get_bottom()[1] - margin
        bottom_limit = -manim_config.frame_height / 2 + margin
        radius = min(3.0, (top_limit - bottom_limit) / 2)
        center = np.array([0.0, (top_limit + bottom_limit) / 2, 0.0])

        angle_step = TAU / max(seq_len, 1)  # guard against an empty token list
        token_nodes = []
        for i, tok in enumerate(tokens):
            angle = i * angle_step
            pos = center + np.array([radius * np.cos(angle), radius * np.sin(angle), 0])
            label = Text(tok, font_size=24).move_to(pos)
            token_nodes.append(label)

        if token_nodes:
            self.play(*[FadeIn(node, shift=0.2*OUT) for node in token_nodes], run_time=1.2)

        # Build list of edges (i -> j) sorted by attention weight
        edges = []
        for i in range(seq_len):
            for j in range(seq_len):
                if i == j:
                    continue
                w = float(matrix[i, j])
                if w > 0.05:  # small threshold to avoid clutter
                    edges.append((i, j, w))
        edges.sort(key=lambda x: x[2], reverse=True)

        # Animate edges from strongest to weakest
        edge_mobs = []
        max_w = max([e[2] for e in edges]) if edges else 1.0

        for i, j, w in edges:
            start = token_nodes[i].get_center()
            end = token_nodes[j].get_center()
            # Weight relative to the strongest edge drives width and colour.
            alpha = w / (max_w + 1e-8)

            # curved arrow for aesthetics
            arrow = CurvedArrow(
                start,
                end,
                angle=0.3,
                stroke_width=2 + 5 * alpha,
                color=interpolate_color(GRAY, BLUE, alpha),
                tip_length=0.2,
            )
            edge_mobs.append(arrow)
            self.play(Create(arrow), run_time=0.1)

        self.wait(2)
        leftovers = edge_mobs + token_nodes
        if leftovers:
            self.play(*[FadeOut(m) for m in leftovers], run_time=1.0)
        self.wait(0.5)


BertViz already provides great token graphs, but with Manim we can animate them and customize their appearance.

# Embedding Visualizations


In [None]:
import sys, subprocess

# Packages needed by the embedding-visualization cells below.
packages = [
    "manim",
    "transformers",
    "torch",
    "matplotlib",
    "seaborn",
    "umap-learn",
    "scikit-learn",
]

# Invoke pip through the running interpreter (`sys.executable -m pip`) so the
# packages land in the environment this kernel uses. check_call raises if pip
# exits with a non-zero status.
subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])
print("Dependencies installed successfully!")


In [None]:
import torch
import numpy as np

from transformers import AutoTokenizer, AutoModel, AutoConfig

from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import umap

import matplotlib.pyplot as plt
import seaborn as sns

# Jupyter config: default matplotlib style, seaborn "husl" palette,
# figures rendered inline in the cell output.
plt.style.use("default")
sns.set_palette("husl")
%matplotlib inline

# Load manim IPython magic
# NOTE(review): earlier cells already use %%manim before this line runs in a
# top-to-bottom execution — confirm the magic is registered by then
# (presumably via `from manim import *` in the first cell).
%load_ext manim

print("Libraries imported and manim extension loaded!")


In [None]:
# Reload the same checkpoint, this time with hidden-state outputs enabled.
# This rebinds `tokenizer`, `config` and `model` from the attention section.
# NOTE(review): the new config does not request attentions, so re-running the
# attention cells after this one presumably yields `attentions=None` — confirm.
model_name = "distilbert-base-uncased"

print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
model = AutoModel.from_pretrained(model_name, config=config)

# Switch the module to evaluation mode for inference.
model.eval()

# Summary of the loaded architecture, read from the config object.
print(f"Model loaded: {model_name}")
print(f"Number of layers: {config.n_layers}")
print(f"Hidden size: {config.dim}")
print(f"Number of attention heads per layer: {config.n_heads}")
print(f"Vocabulary size: {config.vocab_size}")


In [None]:
def extract_layerwise_embeddings(text, model, tokenizer):
    """
    Extract embeddings from all layers for a given text.

    Args:
        text: input string handed to `tokenizer`.
        model: callable that accepts the tokenizer output as keyword
            arguments and returns an object with a `hidden_states` attribute.
        tokenizer: callable tokenizer that also provides
            `convert_ids_to_tokens`.

    Returns:
        tokens: list of token strings
        layer_embeddings: list of NumPy arrays, one per entry of
            `outputs.hidden_states`, each with the batch dimension dropped
        input_ids: token ids tensor

    Raises:
        ValueError: if the model returned no hidden states.
    """
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # outputs.hidden_states: (embedding_output, layer_1, ..., layer_n)
    hidden_states = outputs.hidden_states
    if hidden_states is None:
        raise ValueError(
            "Model returned no hidden states; load it with "
            "output_hidden_states=True in its config."
        )

    # Keep batch item 0 only and move each layer to a NumPy array.
    layer_embeddings = [h[0].cpu().numpy() for h in hidden_states]

    return tokens, layer_embeddings, inputs["input_ids"]

# Quick sanity check: run one sentence through the extractor and print the
# tokens, the number of returned layers, and the shape of the first layer.
sample_text = "I will present the present at the meeting."
tokens_test, layer_emb_test, _ = extract_layerwise_embeddings(sample_text, model, tokenizer)
print("Tokens:", tokens_test)
print("Num layers (incl. embedding layer 0):", len(layer_emb_test))
print("Shape per layer:", layer_emb_test[0].shape)


In [None]:
def reduce_dimensions(embeddings, method='umap', n_components=2, random_state=42):
    """
    Reduce embedding dimensionality using UMAP, t-SNE, or PCA.

    Args:
        embeddings: 2-D array-like with one row per sample.
        method: 'umap', 'tsne' or 'pca'.
        n_components: number of output dimensions.
        random_state: seed forwarded to the chosen reducer.

    Returns:
        The result of the reducer's `fit_transform(embeddings)`.

    Raises:
        ValueError: if `method` is not one of the supported names.
    """
    if method == 'umap':
        reducer = umap.UMAP(n_components=n_components, random_state=random_state, n_neighbors=15)
    elif method == 'tsne':
        # t-SNE requires perplexity < n_samples, so cap the default of 30
        # for short inputs; larger inputs keep perplexity=30 as before.
        n_samples = len(embeddings)
        perplexity = min(30, max(1, n_samples - 1))
        reducer = TSNE(n_components=n_components, random_state=random_state, perplexity=perplexity)
    elif method == 'pca':
        reducer = PCA(n_components=n_components, random_state=random_state)
    else:
        raise ValueError(f"Unknown method: {method}")

    return reducer.fit_transform(embeddings)


def prepare_embeddings_for_reduction(layer_embeddings, exclude_special_tokens=True):
    """
    Flatten per-layer token embeddings into one stack for dimensionality reduction.

    Rows are emitted layer by layer, and within a layer in token order. With
    `exclude_special_tokens` set, the first and last position of every layer
    are left out.

    Returns:
        all_embeddings: (N, hidden_dim) stacked vectors
        layer_labels:   (N,) index of the layer each row came from
        token_positions:(N,) index of the token each row came from
    """
    # Number of positions dropped at each end of the sequence.
    trim = 1 if exclude_special_tokens else 0

    vectors, layers, positions = [], [], []
    for layer_no, layer_matrix in enumerate(layer_embeddings):
        kept = range(trim, layer_matrix.shape[0] - trim)
        vectors.extend(layer_matrix[p] for p in kept)
        layers.extend([layer_no] * len(kept))
        positions.extend(kept)

    return np.array(vectors), np.array(layers), np.array(positions)

print("Helper functions ready: reduce_dimensions, prepare_embeddings_for_reduction")


In [None]:
import numpy as np
from sklearn.decomposition import PCA

def save_two_token_evolution_3d_for_manim(
    sentence,
    token_of_interest,
    filename="two_token_evolution_3d.npz"
):
    """
    Save the layer-by-layer 3D trajectories of two occurrences of one token.

    The sentence is embedded at every layer, the vectors of the first two
    positions whose token string equals `token_of_interest` are collected,
    and both trajectories are projected with a single PCA fit so that they
    live in the same 3D space. The result is written to `filename` for the
    Manim scene to read.

    - sentence: the input text
    - token_of_interest: token string to look for, e.g. "present"
    - filename: npz file to save

    Raises ValueError when the token occurs fewer than two times.
    """
    tokens, layer_embeddings, _ = extract_layerwise_embeddings(sentence, model, tokenizer)

    # Positions at which the tokenized sequence contains the requested token.
    indices = [pos for pos, tok in enumerate(tokens) if tok == token_of_interest]
    if len(indices) < 2:
        raise ValueError(
            f"Need at least 2 occurrences of '{token_of_interest}' in tokens, "
            f"but found {len(indices)}.\nTokens: {tokens}"
        )
    idx1, idx2 = indices[:2]

    # One (num_layers, hidden_dim) trajectory per occurrence.
    trajectories = [
        np.array([layer[position] for layer in layer_embeddings])
        for position in (idx1, idx2)
    ]
    num_layers, _hidden_dim = trajectories[0].shape

    # Joint PCA on the stacked trajectories: first occurrence's rows first,
    # then the second occurrence's rows.
    projected = PCA(n_components=3).fit_transform(np.vstack(trajectories))
    first_path = projected[:num_layers]
    second_path = projected[num_layers:]

    np.savez(
        filename,
        coords1=first_path,
        coords2=second_path,
        sentence=sentence,
        token=token_of_interest,
        idx1=idx1,
        idx2=idx2,
        num_layers=num_layers,
        tokens=np.array(tokens),
    )
    print(f"Saved 3D evolution for two '{token_of_interest}' occurrences to {filename}")
    print("Tokens:", tokens)
    print(f"Occurrences used: indices {idx1} and {idx2}")

# EXAMPLE: a sentence that contains the token 'present' twice; writes the
# default file two_token_evolution_3d.npz read by the 3D scene below.
sentence = "I will present the present at the meeting."
save_two_token_evolution_3d_for_manim(sentence, "present")


In [None]:
%%manim -ql -v WARNING TokenEvolutionScene

from manim import *
import numpy as np

class TokenEvolutionScene(Scene):
    """
    Animate one token's 2D position layer by layer, leaving a yellow trail.

    Reads ``token_evolution.npz`` from the working directory. Keys used:
    ``points`` (indexed as [layer, 0] and [layer, 1]), ``sentence``,
    ``token``, ``occurrence`` and ``num_layers``.

    NOTE(review): no cell visible in this notebook writes
    ``token_evolution.npz`` (the visible saver writes
    ``two_token_evolution_3d.npz`` with different keys) — confirm where this
    file is produced before relying on Restart & Run All.
    """

    def construct(self):
        # The context manager closes the archive np.load would keep open.
        # NOTE(review): allow_pickle=True is kept from the original; only load
        # archives you created yourself.
        with np.load("token_evolution.npz", allow_pickle=True) as data:
            points = data["points"]               # (num_layers, 2)
            sentence = str(data["sentence"])
            token = str(data["token"])
            occurrence = int(data["occurrence"])
            num_layers = int(data["num_layers"])

        # Never step past the points that are actually stored.
        num_layers = min(num_layers, len(points))

        # Normalize coordinates for nicer layout. Centring on the midpoint of
        # the range (rather than the mean) bounds x to [-4, 4] and y to
        # [-2, 2], so every point stays inside the axes drawn below.
        xs = points[:, 0]
        ys = points[:, 1]
        max_range = max(xs.max() - xs.min(), ys.max() - ys.min()) + 1e-6
        points_norm = np.column_stack([
            (xs - (xs.max() + xs.min()) / 2) / max_range * 8,
            (ys - (ys.max() + ys.min()) / 2) / max_range * 4,
        ])

        # Title & subtitle
        title = Text(f"Token evolution across layers: '{token}'", font_size=36)
        title.to_edge(UP)
        self.add(title)

        subtitle = Text(
            f"Sentence: {sentence}  (occurrence {occurrence})",
            font_size=24
        ).next_to(title, DOWN, buff=0.3)
        self.play(Write(subtitle), run_time=0.8)

        # Axes
        axes = Axes(
            x_range=[-5, 5, 1],
            y_range=[-3, 3, 1],
            x_length=10,
            y_length=6,
            tips=False
        )
        axes.move_to(ORIGIN)
        self.play(Create(axes), run_time=0.7)

        # All layer points as faint background
        all_dots = VGroup()
        for i, (x, y) in enumerate(points_norm):
            dot = Dot(axes.c2p(x, y), radius=0.06, color=GRAY)
            label = Text(f"L{i}", font_size=18).next_to(dot, UP, buff=0.1)
            all_dots.add(VGroup(dot, label))
        self.play(FadeIn(all_dots, run_time=0.5))

        # Moving highlighted dot, kept above the trail segments added later.
        current_dot = Dot(axes.c2p(*points_norm[0]), radius=0.12, color=YELLOW)
        current_dot.set_z_index(1)
        layer_label = Text("Layer 0", font_size=28).to_corner(UL).shift(DOWN*0.3)
        self.play(FadeIn(current_dot), FadeIn(layer_label), run_time=0.7)

        # Trail: one segment per layer transition, drawn while the dot moves
        # so the path visibly follows the dot.
        previous_point = axes.c2p(*points_norm[0])

        for layer_idx in range(1, num_layers):
            new_point = axes.c2p(*points_norm[layer_idx])
            new_layer_label = Text(f"Layer {layer_idx}", font_size=28).to_corner(UL).shift(DOWN*0.3)
            segment = Line(previous_point, new_point, color=YELLOW, stroke_width=4)

            self.play(
                Create(segment),
                current_dot.animate.move_to(new_point),
                ReplacementTransform(layer_label, new_layer_label),
                run_time=0.8,
            )
            layer_label = new_layer_label
            previous_point = new_point

        self.wait(2)


In [None]:
%%manim -qk -v WARNING TwoTokenEvolution3DVectorsScene

from manim import *
import numpy as np

class TwoTokenEvolution3DVectorsScene(ThreeDScene):
    """
    Animate two 3D vectors (one per token occurrence) through the layers.

    Reads ``two_token_evolution_3d.npz`` from the working directory. Keys
    used: ``coords1``, ``coords2`` (one 3D point per layer each),
    ``num_layers``, ``token`` and ``sentence``.
    """

    def construct(self):
        # The context manager closes the archive np.load would keep open.
        # NOTE(review): allow_pickle=True is kept from the original; only load
        # archives you created yourself.
        with np.load("two_token_evolution_3d.npz", allow_pickle=True) as data:
            coords1 = data["coords1"]    # (num_layers, 3)
            coords2 = data["coords2"]    # (num_layers, 3)
            num_layers = int(data["num_layers"])
            token = str(data["token"])
            sentence = str(data["sentence"])

        # Never step past the points that are actually stored.
        num_layers = min(num_layers, len(coords1), len(coords2))
        last_layer = max(num_layers - 1, 0)

        # -------- Normalize coords for nicer view --------
        # One shared scale keeps both trajectories inside the [-4, 4] axes.
        all_coords = np.vstack([coords1, coords2])
        max_abs = np.max(np.abs(all_coords)) + 1e-6
        scale = 4 / max_abs
        coords1 = coords1 * scale
        coords2 = coords2 * scale

        # -------- Titles (fixed in frame) --------
        title = Text(
            f"3D Vector Evolution of Two '{token}' Occurrences",
            font_size=32
        ).to_edge(UP)
        subtitle = Text(
            f"Sentence: {sentence}",
            font_size=20
        ).next_to(title, DOWN, buff=0.2)

        self.add_fixed_in_frame_mobjects(title, subtitle)

        # -------- Camera + axes --------
        self.set_camera_orientation(phi=65 * DEGREES, theta=30 * DEGREES)
        self.begin_ambient_camera_rotation(rate=0.10)

        axes = ThreeDAxes(
            x_range=[-4, 4, 2],
            y_range=[-4, 4, 2],
            z_range=[-4, 4, 2],
            x_length=8,
            y_length=8,
            z_length=8,
        )
        self.play(Create(axes), run_time=1.5)

        origin = axes.c2p(0, 0, 0)

        # -------- Layer tracker --------
        layer_tracker = ValueTracker(0)

        def point_at(coords):
            """
            Scene point for the tracker's current (possibly fractional) layer.

            Interpolates linearly between neighbouring layers so the vectors
            glide during each transition; truncating the tracker to an int
            would hold them still and make them jump at the end.
            """
            t = float(np.clip(layer_tracker.get_value(), 0, last_layer))
            lower = int(np.floor(t))
            upper = min(lower + 1, last_layer)
            frac = t - lower
            return axes.c2p(*((1 - frac) * coords[lower] + frac * coords[upper]))

        # -------- Arrows = token vectors from origin --------
        vector1 = always_redraw(
            lambda: Arrow3D(
                start=origin,
                end=point_at(coords1),
                color=BLUE,
                thickness=0.03,
            )
        )
        vector2 = always_redraw(
            lambda: Arrow3D(
                start=origin,
                end=point_at(coords2),
                color=RED,
                thickness=0.03,
            )
        )

        # Optional: spheres at the tips
        tip1 = always_redraw(
            lambda: Dot3D(
                point_at(coords1),
                radius=0.08,
                color=BLUE,
            )
        )
        tip2 = always_redraw(
            lambda: Dot3D(
                point_at(coords2),
                radius=0.08,
                color=RED,
            )
        )

        self.play(FadeIn(vector1), FadeIn(vector2), FadeIn(tip1), FadeIn(tip2), run_time=1.0)

        # -------- Layer label (screen-fixed) --------
        # Shows the last fully reached layer while a transition is running.
        layer_label = always_redraw(
            lambda: Text(
                f"Layer {int(layer_tracker.get_value())}",
                font_size=26
            ).to_corner(UL).shift(DOWN * 0.3)
        )
        self.add_fixed_in_frame_mobjects(layer_label)

        # -------- Animate layer-by-layer --------
        for i in range(1, num_layers):
            self.play(
                layer_tracker.animate.set_value(i),
                run_time=0.8,
            )

        self.wait(2)


In [None]:
def save_layer_umap_for_manim(
    sentence,
    filename="layer_umap_for_manim.npz",
    exclude_special_tokens=True,
):
    """
    Project every (layer, token) embedding of a sentence to 2D and save it.

    All layers are embedded, stacked into one matrix, reduced with UMAP in a
    single fit, and then folded back into an array indexed as
    [layer, token, xy] that the Manim scene animates. With
    `exclude_special_tokens` set, the first and last token are left out.
    """
    tokens, layer_embeddings, _ = extract_layerwise_embeddings(sentence, model, tokenizer)

    stacked, _layer_ids, _positions = prepare_embeddings_for_reduction(
        layer_embeddings,
        exclude_special_tokens=exclude_special_tokens,
    )
    projected = reduce_dimensions(stacked, method="umap", n_components=2)

    num_layers = len(layer_embeddings)
    seq_len = layer_embeddings[0].shape[0]

    # Number of positions dropped at each end of the sequence.
    trim = 1 if exclude_special_tokens else 0
    first, last = trim, seq_len - trim
    num_tokens = last - first
    tokens_no_special = tokens[first:last]

    # The stacked rows are ordered layer by layer, so a plain reshape
    # recovers the [num_layers, num_tokens, 2] layout.
    np.savez(
        filename,
        coords=projected.reshape(num_layers, num_tokens, 2),
        tokens=tokens_no_special,
        num_layers=num_layers,
        num_tokens=num_tokens,
        sentence=sentence,
    )
    print(f"Saved layer-wise UMAP coords to {filename}")
    print(f"Num layers: {num_layers}, num tokens (no specials): {num_tokens}")
    print("Tokens (no specials):", tokens_no_special)

# EXAMPLE: same sentence as above; writes the default file
# layer_umap_for_manim.npz read by the layer-wise scene below.
sentence = "I will present the present at the meeting."
save_layer_umap_for_manim(sentence)


In [None]:
%%manim -ql -v WARNING LayerwiseEmbeddingsScene

from manim import *
import numpy as np

class LayerwiseEmbeddingsScene(Scene):
    """
    Animate every token's 2D position from layer to layer.

    Reads ``layer_umap_for_manim.npz`` from the working directory. Keys used:
    ``coords`` (indexed as [layer, token, xy]), ``tokens``, ``num_layers``,
    ``num_tokens`` and ``sentence``.
    """

    def construct(self):
        # The context manager closes the archive np.load would keep open.
        # NOTE(review): allow_pickle=True is kept from the original; only load
        # archives you created yourself.
        with np.load("layer_umap_for_manim.npz", allow_pickle=True) as data:
            coords = data["coords"]          # (num_layers, num_tokens, 2)
            tokens = [str(t) for t in data["tokens"].tolist()]
            num_layers = int(data["num_layers"])
            num_tokens = int(data["num_tokens"])
            sentence = str(data["sentence"])

        # Normalize coords. Centring on the midpoint of the range (rather
        # than the mean) bounds x to [-4, 4] and y to [-2, 2], so every point
        # stays inside the axes drawn below.
        xs = coords[..., 0].flatten()
        ys = coords[..., 1].flatten()
        max_range = max(xs.max() - xs.min(), ys.max() - ys.min()) + 1e-6
        coords_norm = coords.copy()
        coords_norm[..., 0] = (coords[..., 0] - (xs.max() + xs.min()) / 2) / max_range * 8
        coords_norm[..., 1] = (coords[..., 1] - (ys.max() + ys.min()) / 2) / max_range * 4

        # Title + subtitle
        title = Text("Layer-wise token embeddings (UMAP)", font_size=36)
        title.to_edge(UP)
        self.add(title)

        subtitle = Text(
            f"Sentence: {sentence}",
            font_size=24
        ).next_to(title, DOWN, buff=0.3)
        self.play(Write(subtitle), run_time=0.8)

        # Axes
        axes = Axes(
            x_range=[-5, 5, 1],
            y_range=[-3, 3, 1],
            x_length=10,
            y_length=6,
            tips=False
        )
        axes.move_to(ORIGIN)
        self.play(Create(axes), run_time=0.7)

        # Initial positions (layer 0)
        dot_radius = 0.09
        label_gap = 0.1   # space between the top of a dot and its label
        dots = VGroup()
        labels = VGroup()
        for i in range(num_tokens):
            x, y = coords_norm[0, i]
            dot = Dot(axes.c2p(x, y), radius=dot_radius, color=BLUE)
            label = Text(tokens[i], font_size=18).next_to(dot, UP, buff=label_gap)
            dots.add(dot)
            labels.add(label)

        layer_label = Text("Layer 0", font_size=28).to_corner(UL).shift(DOWN*0.3)
        self.play(FadeIn(dots), FadeIn(labels), FadeIn(layer_label), run_time=0.8)

        # Animate movement through layers
        for layer_idx in range(1, num_layers):
            new_positions = coords_norm[layer_idx]
            new_layer_label = Text(f"Layer {layer_idx}", font_size=28).to_corner(UL).shift(DOWN*0.3)

            moves = []
            for i, (dot, label) in enumerate(zip(dots, labels)):
                target = axes.c2p(*new_positions[i])
                moves.append(dot.animate.move_to(target))
                # Move the label with its dot. It is placed relative to the
                # dot's destination, keeping the same gap above the dot as
                # in the initial layout.
                moves.append(label.animate.next_to(target, UP, buff=label_gap + dot_radius))

            self.play(
                *moves,
                ReplacementTransform(layer_label, new_layer_label),
                run_time=0.9,
            )
            layer_label = new_layer_label
            self.wait(0.3)

        self.wait(2)
