# Interpreting Identity

# Setup

In [1]:
TRAIN_MODEL = True

In [2]:
DEVELOPMENT_MODE = True
# Upgrade pip
%pip install --upgrade pip

from IPython import get_ipython

ipython = get_ipython()
# Code to automatically update the HookedTransformer code as its edited without restarting the kernel
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")
ipython.run_line_magic("pip", "install ipympl")
ipython.run_line_magic("pip", "install scipy")
ipython.run_line_magic("pip", "install manim")
ipython.run_line_magic("pip", "install torch")
ipython.run_line_magic("pip", "install numpy<2")
ipython.run_line_magic("pip", "install einops")
ipython.run_line_magic("pip", "install transformer_lens")

Note: you may need to restart the kernel to use updated packages.


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


Collecting ipympl
  Downloading ipympl-0.9.7-py3-none-any.whl.metadata (8.7 kB)
Collecting ipywidgets<9,>=7.6.0 (from ipympl)
  Downloading ipywidgets-8.1.7-py3-none-any.whl.metadata (2.4 kB)
Collecting matplotlib<4,>=3.5.0 (from ipympl)
  Downloading matplotlib-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting numpy (from ipympl)
  Downloading numpy-2.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting pillow (from ipympl)
  Downloading pillow-11.2.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (8.9 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets<9,>=7.6.0->ipympl)
  Downloading widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets<9,>=7.6.0->ipympl)
  Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl.metadata (20 kB)
Collecting contourpy>=1.0.1 (from matplotlib<4,>=3.5.0->ipympl)
  Downloading contourpy-1.3.2-cp310-cp310-ma

In [3]:
# Import stuff
import torch
import numpy as np
import os
import tqdm.auto as tqdm
from pathlib import Path

import copy

from transformer_lens import HookedTransformer, HookedTransformerConfig, ActivationCache

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# Define the location to save the model, using a relative path
PTH_LOCATION = "workspace/_scratch/identity.pth"

# Create the directory if it does not exist
os.makedirs(Path(PTH_LOCATION).parent, exist_ok=True)

cuda


# Model Training

## Config

In [4]:
INPUT_DIM = 3
OUTPUT_DIM = 3
frac_train = 1

# Optimizer config
lr = 1e-3
wd = 1e-2
betas = (0.9, 0.999)

num_epochs = 10000
checkpoint_every = 500

DATA_SEED = 599

## Define Task
* Define random function
* Define the dataset & labels

Input format:
|a|b|

In [5]:
def get_training_data(input_dim = INPUT_DIM, output_dim = OUTPUT_DIM, data_seed = DATA_SEED):
    torch.manual_seed(data_seed)
    a_vector = torch.arange(input_dim)
    dataset = torch.cartesian_prod(a_vector, a_vector).to(device)

    labels = torch.randint(0, output_dim, (dataset.shape[0],), device=device)
    train_data = dataset
    train_labels = labels
    # For now no test data
    test_data = dataset[0:0]
    test_labels = labels[0:0]
    print(train_data)
    print(train_labels)
    print(train_data.shape)
    print(test_data[:5])
    print(test_labels[:5])
    print(test_data.shape)
    return train_data, train_labels, test_data, test_labels

train_data, train_labels, test_data, test_labels = get_training_data()

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2],
        [2, 0],
        [2, 1],
        [2, 2]], device='cuda:0')
tensor([2, 2, 1, 0, 2, 2, 0, 0, 2], device='cuda:0')
torch.Size([9, 2])
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', dtype=torch.int64)
torch.Size([0, 2])


## Define Model

In [114]:

def get_seeded_model(seed = 999, input_dim = INPUT_DIM, output_dim = OUTPUT_DIM):
    cfg = HookedTransformerConfig(
        n_layers = 1,
        n_heads = 1,
        d_model = 2,
        d_head = 2,
        d_mlp = 4,
        attn_only=False,
        act_fn = "relu",
        normalization_type=None,
        d_vocab=input_dim,
        d_vocab_out=output_dim,
        n_ctx=2,
        init_weights=True,
        device=device,
        seed = seed,
    )
    model = HookedTransformer(cfg)
    # Biases are enabled by default
    for name, param in model.named_parameters():
        if "mlp.b_out" in name:
            param.requires_grad = False
    return model

model = get_seeded_model(seed = 993)

## Define Optimizer + Loss

In [7]:
print(train_data)
print(model(train_data))

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2],
        [2, 0],
        [2, 1],
        [2, 2]], device='cuda:0')
tensor([[[ 0.1324,  0.1346,  0.1444],
         [-0.1453,  0.5954,  0.4067]],

        [[ 0.1324,  0.1346,  0.1444],
         [ 0.0021,  0.2391,  0.1825]],

        [[ 0.1324,  0.1346,  0.1444],
         [-0.3875,  0.3479,  0.1415]],

        [[ 0.2610,  0.0929,  0.1536],
         [-0.1282,  0.5130,  0.3495]],

        [[ 0.2610,  0.0929,  0.1536],
         [ 0.0193,  0.1562,  0.1249]],

        [[ 0.2610,  0.0929,  0.1536],
         [-0.3633,  0.2177,  0.0502]],

        [[-0.0580, -0.1071, -0.0999],
         [-0.1084,  0.7052,  0.5020]],

        [[-0.0580, -0.1071, -0.0999],
         [ 0.0397,  0.3555,  0.2830]],

        [[-0.0580, -0.1071, -0.0999],
         [-0.3493,  0.5255,  0.2887]]], device='cuda:0',
       grad_fn=<ViewBackward0>)


In [8]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)
def loss_fn(logits, labels):
    if len(logits.shape)==3:
        logits = logits[:, -1]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()
train_logits = model(train_data)
train_loss = loss_fn(train_logits, train_labels)
print(train_loss)
test_logits = model(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(test_loss)
print("Uniform loss:")
print(np.log(INPUT_DIM * INPUT_DIM))

tensor(1.1659, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(nan, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
Uniform loss:
2.1972245773362196


## Actually Train

**Weird Decision:** Training the model with full batch training rather than stochastic gradient descent. We do this so to make training smoother and reduce the number of slingshots.

In [9]:
def print_stats(model):
    # Extract the p 2-dimensional tensors, vector i is vec[:, i]
    vec = model.W_U.data

    # Function to compute the angle between two vectors
    def compute_angle(v1, v2):
        cos_theta = torch.dot(v1, v2) / (torch.norm(v1) * torch.norm(v2))
        angle = torch.acos(cos_theta) * (180.0 / np.pi)
        return angle

    # Compute pairwise angles
    # for i in range(vec.shape[1]):
    #     for j in range(i+1, vec.shape[1]):
    #         angle = compute_angle(vec[:, i], vec[:, j])
    #         print(f"Angle between {i} and {j}: {angle.item():.2f}°")
    #     print(f"Norm of vector {i}: {torch.norm(vec[:, i]):.2f}")

print_stats(model)

In [10]:
def train_model(model, train_data, train_labels, test_data, test_labels, num_epochs = num_epochs, loss_target = None):
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, weight_decay=wd, betas=betas
    )
    train_losses = []
    test_losses = []
    model_checkpoints = []
    checkpoint_epochs = []
    if TRAIN_MODEL:
        for epoch in tqdm.tqdm(range(num_epochs)):
            train_logits = model(train_data)
            train_loss = loss_fn(train_logits, train_labels)
            train_loss.backward()
            train_losses.append(train_loss.item())
            if loss_target is not None and train_loss.item() < loss_target:
                print(f"Loss target {loss_target} reached with loss {train_loss.item()} at epoch {epoch}")
                break

            optimizer.step()
            optimizer.zero_grad()

            with torch.inference_mode():
                test_logits = model(test_data)
                test_loss = loss_fn(test_logits, test_labels)
                test_losses.append(test_loss.item())

            if ((epoch+1)%checkpoint_every)==0:
                checkpoint_epochs.append(epoch)
                model_checkpoints.append(copy.deepcopy(model.state_dict()))
                print_stats(model)
                print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")
    if TRAIN_MODEL:
        torch.save(
            {
                "model":model.state_dict(),
                "config": model.cfg,
                "checkpoints": model_checkpoints,
                "checkpoint_epochs": checkpoint_epochs,
                "test_losses": test_losses,
                "train_losses": train_losses,
            },
            PTH_LOCATION)

train_model(
    model, train_data, train_labels, test_data, test_labels
)

  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch 499 Train Loss 0.4106060732104831 Test Loss nan
Epoch 999 Train Loss 0.1807328558316352 Test Loss nan
Epoch 1499 Train Loss 0.16043439349428762 Test Loss nan
Epoch 1999 Train Loss 0.15647240053487907 Test Loss nan
Epoch 2499 Train Loss 0.15523836767544205 Test Loss nan
Epoch 2999 Train Loss 0.15471887337163706 Test Loss nan
Epoch 3499 Train Loss 0.1544579820204664 Test Loss nan
Epoch 3999 Train Loss 0.15431172330198398 Test Loss nan
Epoch 4499 Train Loss 0.15422296296427562 Test Loss nan
Epoch 4999 Train Loss 0.15416575042857242 Test Loss nan
Epoch 5499 Train Loss 0.15412773918838138 Test Loss nan
Epoch 5999 Train Loss 0.15410142859328083 Test Loss nan
Epoch 6499 Train Loss 0.1540828876843727 Test Loss nan
Epoch 6999 Train Loss 0.15406962154348441 Test Loss nan
Epoch 7499 Train Loss 0.15406035555606504 Test Loss nan
Epoch 7999 Train Loss 0.15405372086043706 Test Loss nan
Epoch 8499 Train Loss 0.1540489450639284 Test Loss nan
Epoch 8999 Train Loss 0.15404571483857057 Test Loss nan

In [11]:
if not TRAIN_MODEL:
    cached_data = torch.load(PTH_LOCATION)
    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']

# Activations

In [12]:
def create_cache(model):
    input = train_data
    print("Transposed Input:\n", input.transpose(0, 1))
    logits, cache = model.run_with_cache(input)
    print("Labels: ", train_labels)
    print("Logits of last token:\n", logits[:, -1, :])
    print("Unembed:\n", model.W_U.data)
    print("Last layer before unembed:\n", cache.cache_dict["blocks.0.hook_resid_post"][:, -1, :])
    return cache

cache = create_cache(model)

Transposed Input:
 tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]], device='cuda:0')
Labels:  tensor([2, 2, 1, 0, 2, 2, 0, 0, 2], device='cuda:0')
Logits of last token:
 tensor([[ 28.1620, -19.2949,  41.1054],
        [ 31.0093, -21.3608,  52.8604],
        [-17.7300,  11.5763,  11.5747],
        [ 21.5631, -14.4316,   8.8870],
        [ 30.3038, -20.8193,  47.9917],
        [-17.7300,  11.5763,  11.5747],
        [ 19.9547, -13.3544,   8.1595],
        [ 20.4123, -13.6608,   8.3665],
        [ -0.4468,   0.1360,  10.4339]], device='cuda:0',
       grad_fn=<SliceBackward0>)
Unembed:
 tensor([[-1.0749,  0.7303, -1.1674],
        [ 1.5622, -1.0085, -1.7878]], device='cuda:0')
Last layer before unembed:
 tensor([[-30.2508,  -3.1116],
        [-36.5129,  -5.5977],
        [  3.9719,  -8.9408],
        [-13.6627,   4.0780],
        [-34.1454,  -4.4203],
        [  3.9719,  -8.9408],
        [-12.5916,   3.7855],
        [-12.8963,   3.8687],
        [ -3.8020,  -3.

In [13]:
# Get a list of all activations stored in the cache, especially their names
cache.cache_dict

{'hook_embed': tensor([[[-0.7916,  0.5167],
          [-0.7916,  0.5167]],
 
         [[-0.7916,  0.5167],
          [-1.0963,  0.5999]],
 
         [[-0.7916,  0.5167],
          [ 0.4704, -0.9617]],
 
         [[-1.0963,  0.5999],
          [-0.7916,  0.5167]],
 
         [[-1.0963,  0.5999],
          [-1.0963,  0.5999]],
 
         [[-1.0963,  0.5999],
          [ 0.4704, -0.9617]],
 
         [[ 0.4704, -0.9617],
          [-0.7916,  0.5167]],
 
         [[ 0.4704, -0.9617],
          [-1.0963,  0.5999]],
 
         [[ 0.4704, -0.9617],
          [ 0.4704, -0.9617]]], device='cuda:0'),
 'hook_pos_embed': tensor([[[-0.8347,  1.5906],
          [-1.1737, -1.0081]],
 
         [[-0.8347,  1.5906],
          [-1.1737, -1.0081]],
 
         [[-0.8347,  1.5906],
          [-1.1737, -1.0081]],
 
         [[-0.8347,  1.5906],
          [-1.1737, -1.0081]],
 
         [[-0.8347,  1.5906],
          [-1.1737, -1.0081]],
 
         [[-0.8347,  1.5906],
          [-1.1737, -1.0081]],
 
      

# Animations

In [14]:
from manim import *

config.media_width = "80%"
config.verbosity = "WARNING"
config.preview = False

In [50]:
for a, b in model.named_parameters():
    if a.startswith("blocks.0.mlp."):
        print(a, b)
normal = model.get_parameter("blocks.0.mlp.W_in")
bias = model.get_parameter("blocks.0.mlp.b_in")

blocks.0.mlp.W_in Parameter containing:
tensor([[-1.0536],
        [-1.0663]], device='cuda:0', requires_grad=True)
blocks.0.mlp.b_in Parameter containing:
tensor([-0.0978], device='cuda:0', requires_grad=True)
blocks.0.mlp.W_out Parameter containing:
tensor([[-1.0305,  1.1713]], device='cuda:0', requires_grad=True)
blocks.0.mlp.b_out Parameter containing:
tensor([0., 0.], device='cuda:0')


In [72]:
normal_numpy = normal
b = [2, 3]
b_numpy = np.array(b)
np.append(b, 0)

array([2, 3, 0])

## Gathering Data

In [108]:
class VectorParams:
    def __init__(self, values = [], color = WHITE, label = ""):
        self.values = values
        self.color = color
        self.label = label
    def __repr__(self) -> str:
        return str(self.values) + "(" + str(self.color) + ")"

class Data:
    def __init__(self):
        # for each step, store a list of all vectors (one per input)
        self.vectors: list[list[VectorParams]] = [[]]
        # for each step, store some supplementory lines that should be drawn, e.g. zero-lines for MLPs
        self.supp_lines: list[list[tuple[list[float], list[float]]]] = [[]]
        self.steps = 0
        self.current_labels = set()

    def add_vector(self, vector, color = WHITE, label = ""):
        if label not in self.current_labels:
            self.current_labels.add(label)
            self.vectors[self.steps].append(VectorParams(values = vector, color = color, label = label))

    def next_step(self):
        self.steps += 1
        self.vectors.append([])
        self.current_labels = set()
        self.supp_lines.append([])

    def add_vectors_at_hook(self, c: ActivationCache, hook: str, color0 = WHITE, color1 = WHITE, input_labels = None, input_colors = None):
        if input_labels is None:
            input_labels = ["" for i in range(c.cache_dict[hook].shape[0])]
        for i in range(c.cache_dict[hook].shape[0]):
            self.add_vector(c.cache_dict[hook][i][0].cpu(), color = color0, label = input_labels[i][:1])
            self.add_vector(
                c.cache_dict[hook][i][1].cpu(), color=color1 if input_colors is None else input_colors[i], label=input_labels[i][:2]
            )

    def add_mlp_lines(self, model, W_in, b_in):
        normals = model.get_parameter(W_in).transpose(0, 1)
        biases = model.get_parameter(b_in)
        for i, normal in enumerate(normals):
            bias = biases[i]
            # To draw the line, we need to compute the start and end points A, B such that
            # A * normal + bias = 0 and B * normal + bias = 0, and A and B are far away in different directions
            if normal[0].item() == 0:
                # If the first component is zero, we can draw a horizontal line
                A = [-10, - bias.item() / normal[1].item()]
                B = [+10, - bias.item() / normal[1].item()]
            else:
                # Pick Y-coordinates -10 and 10 and compute the corresponding X-coordinates:
                A = [(- bias.item() + 10 * normal[1].item()) / normal[0].item(), -10]
                B = [(- bias.item() - 10 * normal[1].item()) / normal[0].item(), +10]
            self.supp_lines[self.steps].append((A, B))


def compile_data_vectors(model, cache, input_labels=None, input_colors=None):
    # Set default value as list of empty strings
    vectors = Data()
    vectors.add_vectors_at_hook(cache, "hook_embed", color1 = GRAY, input_labels=input_labels, input_colors=input_colors)
    vectors.next_step()
    vectors.add_vectors_at_hook(cache, "blocks.0.hook_resid_pre", input_labels=input_labels, input_colors=input_colors)
    vectors.next_step()
    vectors.add_vectors_at_hook(cache, "blocks.0.hook_resid_mid", color0 = GRAY, input_labels=input_labels, input_colors=input_colors)
    vectors.next_step()
    vectors.add_vectors_at_hook(cache, "blocks.0.hook_resid_post", color0 = GRAY, input_labels=input_labels, input_colors=input_colors)
    vectors.add_mlp_lines(model, "blocks.0.mlp.W_in", "blocks.0.mlp.b_in")

    print(vectors.vectors)
    return vectors


vectors = compile_data_vectors(model, cache)

[[tensor([-0.5821,  0.9190])(#FFFFFF)], [tensor([-1.2485,  0.7991])(#FFFFFF)], [tensor([ 0.2676, -1.6296])(#888888)], [tensor([ 0.2676, -1.6296])(#888888)]]


In [73]:
def change_font_size(labeled_arrow: LabeledArrow, new_size):
    # print(labeled_arrow, labeled_arrow.submobjects)
    # print(labeled_arrow.submobjects[-1].font_size)
    if not isinstance(labeled_arrow, LabeledArrow):
        return
    label = labeled_arrow.submobjects[-1]
    box = labeled_arrow.submobjects[-2]
    if not isinstance(box, BackgroundRectangle):
        box = labeled_arrow.submobjects[-3]
    coords = label.get_center()
    # print(new_size)
    labeled_arrow.submobjects[-1] = MathTex(
        label.get_tex_string(), color=label.color, font_size=new_size
    )
    # print("size=", labeled_arrow.submobjects[-1].font_size)
    label = labeled_arrow.submobjects[-1]
    label.move_to(coords)
    box.width = label.width + 2 * box.buff
    box.height = label.height + 2 * box.buff

## Scene Definition

In [100]:
DOT_SCALE = 0.01
class VisualizeTransformer(MovingCameraScene):
    def construct(self):
        print("v=", vectors.vectors)
        axes = Axes(
            x_range = [-20, 20, 1],
            y_range = [-20, 20, 1],
            x_axis_config={
                "numbers_to_include": np.arange(-18, 18.1, 3),
                "font_size": 24
            },
            y_axis_config={
                "numbers_to_include": np.arange(-18, 18.1, 3), 
                "font_size": 24            
            },
            x_length = 40,
            y_length = 40,
            axis_config={"color": GREEN}
        )

        scale = ValueTracker(2)

        dots = VGroup()
        def update_scale(self):
            return
            # TODO: Make the scaling nicer
            self.stroke_width = 6 * scale.get_value()
            change_font_size(self, 48 * scale.get_value())
            # print("New font size: ", self.font_size)

        # Embedding arrows
        for i, t in enumerate(vectors.vectors[0]):
            # print(t, t.numpy())
            # arrow = LabeledArrow(
            #     start=ORIGIN,
            #     end=np.append(t.values.numpy(), 0),
            #     buff = 0,
            #     label = t.label,
            #     label_frame = False,
            #     label_color=YELLOW,
            #     color = t.color,
            #     max_stroke_width_to_length_ratio = 100,
            # )

            # arrow.add_updater(update_scale)
            # arrows.add(arrow)
            dot = LabeledDot(
                point=np.append(t.values.numpy(), 0),
                label=t.label,
                color=t.color,
                radius=DOT_SCALE * (len(vectors.vectors[0]) - i) + 0.2,
            )

            dot.set_opacity(0.5)

            dot.add_updater(update_scale)
            dots.add(dot)

        # Transitioning the arrows through the model
        self.add(axes, axes.get_axis_labels(), dots)
        for step in range(1, len(vectors.vectors)):
            new_dots = VGroup()
            transition_arrows = VGroup()
            for i, t in enumerate(vectors.vectors[step]):
                # print(t, t.numpy())
                # new_arrow = LabeledArrow(
                #     start=ORIGIN,
                #     end=np.append(t.values.numpy(), 0),
                #     buff=0,
                #     label=t.label,
                #     label_frame=False,
                #     label_color=YELLOW,
                #     color=t.color,
                #     max_stroke_width_to_length_ratio=100,
                # )
                # new_arrow.add_updater(update_scale)
                # new_arrows.add(new_arrow)
                new_dot = LabeledDot(
                    point=np.append(t.values.numpy(), 0),
                    label=t.label,
                    color=t.color,
                    radius = DOT_SCALE * (len(vectors.vectors[step]) - i) + 0.2,
                )
                new_dot.set_opacity(0.5)
                new_dot.add_updater(update_scale)
                new_dots.add(new_dot)

                transition_arrow = Arrow(
                    start=dots[i].arc_center,
                    end=new_dots[i].arc_center,
                    buff=0,
                    color=RED,
                )
                transition_arrow.add_updater(update_scale)
                transition_arrows.add(transition_arrow)
            
            mlp_lines = VGroup()
            for i, t in enumerate(vectors.supp_lines[step]):
                # print(t)
                mlp_line = Line(
                    start=np.append(t[0],0),
                    end=np.append(t[1],0),
                    buff=0,
                    color=BLUE,
                )
                mlp_line.add_updater(update_scale)
                mlp_lines.add(mlp_line)

            view = SurroundingRectangle(new_dots)
            factor = max(
                view.width / self.camera.frame_width,
                view.height / self.camera.frame_height,
            )
            print(
                factor,
                self.camera.frame_width, view.width,
                self.camera.frame_height, view.height,
            )
            self.wait()
            self.play(FadeIn(transition_arrows), FadeIn(mlp_lines), self.camera.auto_zoom(view, margin = 2), scale.animate.set_value(scale.get_value() * factor))
            self.wait()
            self.play(
                ReplacementTransform(dots, new_dots)
            )
            self.wait()
            self.play(FadeOut(transition_arrows), FadeOut(mlp_lines))
            self.wait()
            dots = new_dots

        # Unembedding Arrows
        embedding_arrows = VGroup()
        data = model.W_U.data
        print("unembed: ", data)
        for i in range(data.size()[1]):
            embedding_arrow = LabeledArrow(
                start=ORIGIN,
                end=[data[0, i].item(), data[1, i].item(), 0],
                label=str(i),
                color=LIGHT_PINK,
                buff=0,
                max_stroke_width_to_length_ratio=100,
            )
            embedding_arrows.add(embedding_arrow)
        self.play(FadeIn(embedding_arrows))
        self.wait()

# v = VisualizeTransformer()
# v.construct()

In [18]:
%env CUDA_LAUNCH_BLOCKING=1
%env TORCH_USE_CUDA_DSA=1

env: CUDA_LAUNCH_BLOCKING=1
env: TORCH_USE_CUDA_DSA=1


In [98]:
A, B = vectors.supp_lines[3][0]
print(A, B)
W, bias = model.get_parameter("blocks.0.mlp.W_in"), model.get_parameter("blocks.0.mlp.b_in")
print(W.data, bias.item())
def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]
print(dot(A, W), dot(B, W))

[-65.82968603798591, -10] [63.45951549715765, 10]
tensor([[-0.3886],
        [ 2.5119]], device='cuda:0') 0.46049827337265015
tensor([0.4605], device='cuda:0', grad_fn=<AddBackward0>) tensor([0.4605], device='cuda:0', grad_fn=<AddBackward0>)


## Summary and Video Generation

In [None]:
%%manim -qh Video

import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'


input_dim = 5
output_dim = 2
train_data, train_labels, test_data, test_labels = get_training_data(input_dim, output_dim, data_seed=997)
model = get_seeded_model(997, input_dim, output_dim)
train_model(model, train_data, train_labels, test_data, test_labels, num_epochs = 10000, loss_target = 1/(input_dim ** 2 * 4))
cache = create_cache(model)
arrow_labels = ["".join([str(d.item()) for d in v]) for v in train_data]
colors = [BLUE, YELLOW, GREEN, RED]
arrow_colors = [colors[l] for l in train_labels]
vectors = compile_data_vectors(model, cache, input_labels=arrow_labels, input_colors = arrow_colors)
print("Labels: ", arrow_labels)

class Video(VisualizeTransformer):
    def construct(self):
        VisualizeTransformer.construct(self)

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [1, 4],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [2, 4],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3],
        [3, 4],
        [4, 0],
        [4, 1],
        [4, 2],
        [4, 3],
        [4, 4]], device='cuda:0')
tensor([1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
        1], device='cuda:0')
torch.Size([25, 2])
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
tensor([], device='cuda:0', dtype=torch.int64)
torch.Size([0, 2])


  0%|          | 0/10000 [00:00<?, ?it/s]