In [1]:
# Setup and Imports
import torch
import torch.nn as nn
import numpy as np
import plotly.express as px
from transformer_lens import HookedTransformer, HookedTransformerConfig
import transformer_lens.utils as utils
import tqdm.auto as tqdm
import copy

In [2]:
# Use the GPU
# Pick CUDA when PyTorch can see a GPU; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")


Using device: cuda


In [3]:

# --- 1. Model and Task Configuration ---
# p is the modulus of the task: labels are computed as (a + b) % p, and the
# token id p itself is used as the third token of every example, which is why
# the model's input vocabulary is p + 1.
p = 113  # Modulus for our arithmetic task
# Share of the p * p examples that go into the training split; the remainder
# is held out as the test split.
frac_train = 0.3 # Fraction of the data to use for training


In [5]:
# Model Configuration
cfg = HookedTransformerConfig(
    # Architecture: a single block with 4 heads of width 32 (4 * 32 = d_model).
    n_layers=1,
    d_model=128,
    n_heads=4,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,
    # Sequence and vocabulary: every example is three tokens [a, b, p], so the
    # input vocabulary needs p + 1 ids while the output only covers 0..p-1.
    n_ctx=3,
    d_vocab=p + 1,
    d_vocab_out=p,
    # Initialization and placement.
    init_weights=True,
    seed=999,
    device=device,
)

model = HookedTransformer(cfg)
model  # trailing expression so the notebook renders the module tree

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): Identity()
      (ln2): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)

In [6]:
# Disable biases for simplicity
# Every parameter whose name contains "b_" is excluded from gradient updates.
# NOTE(review): this freezes the biases at their initial values; they are only
# truly "disabled" if the library initializes them to zero — confirm.
for name, param in model.named_parameters():
    if "b_" not in name:
        continue
    param.requires_grad_(False)


In [7]:

# --- 2. Dataset Generation ---
# Enumerate all p * p ordered pairs (a, b): `a` cycles through 0..p-1 fastest,
# while `b` stays fixed for p consecutive rows.
residues = torch.arange(p)
a_vector = residues.repeat(p)
b_vector = residues.repeat_interleave(p)
# The third token of every example is the id p (one past the largest residue).
equals_vector = torch.full_like(a_vector, p)

# One row per example, laid out as [a, b, p]; the target is (a + b) mod p.
dataset = torch.stack((a_vector, b_vector, equals_vector), dim=1).to(device)
labels = dataset[:, :2].sum(dim=1) % p


In [8]:
# Inspect the first operand column: 0..p-1 repeated p times.
a_vector

tensor([  0,   1,   2,  ..., 110, 111, 112])

In [9]:
# Inspect the second operand column: each value 0..p-1 held for p consecutive rows.
b_vector

tensor([  0,   0,   0,  ..., 112, 112, 112])

In [10]:
# Inspect the constant third-token column (every entry is p).
equals_vector

tensor([113, 113, 113,  ..., 113, 113, 113])

In [11]:
# Inspect the stacked examples: one [a, b, p] row per pair, on the selected device.
dataset

tensor([[  0,   0, 113],
        [  1,   0, 113],
        [  2,   0, 113],
        ...,
        [110, 112, 113],
        [111, 112, 113],
        [112, 112, 113]], device='cuda:0')

In [12]:
# Inspect the targets: (a + b) % p for each row of `dataset`.
labels

tensor([  0,   1,   2,  ..., 109, 110, 111], device='cuda:0')

In [13]:
# Train/Test Split
# Seed the global RNG so the shuffle (and therefore the split) is repeatable,
# then cut the permutation once: the first `cutoff` indices train, the rest test.
torch.manual_seed(42)
indices = torch.randperm(p * p)
cutoff = int(p * p * frac_train)
train_indices, test_indices = indices[:cutoff], indices[cutoff:]

# Gather inputs and targets with the same index tensors so rows stay aligned.
train_data, train_labels = dataset[train_indices], labels[train_indices]
test_data, test_labels = dataset[test_indices], labels[test_indices]

In [15]:
# Sanity-check the split point and the shuffled index order.
cutoff, indices


(3830, tensor([11691,  4308,  4501,  ...,  7630,  3981,  8634]))

In [17]:
# --- 3. Training ---
def loss_fn(logits, labels):
    """Return the mean negative log-likelihood of `labels` under `logits`.

    Args:
        logits: Either a 2-D tensor of per-example class scores, or a 3-D
            tensor whose second axis is indexed with -1 to pick the final
            position before scoring.
        labels: 1-D integer tensor of target class indices, one per example.

    Returns:
        A scalar float64 tensor: the negated mean log-probability assigned to
        the target classes.
    """
    final_logits = logits[:, -1] if logits.ndim == 3 else logits
    # Do the softmax in float64 so very small losses are not lost to rounding.
    log_probs = final_logits.to(torch.float64).log_softmax(dim=-1)
    # Select, for each row, the log-probability of its own target class.
    target_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -target_log_probs.mean()


# Optimizer
# AdamW is built over every parameter the model exposes, with weight decay 1.0.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0, betas=(0.9, 0.98))

num_epochs = 25000
# Per-epoch loss histories; the visualization cell plots these two lists.
train_losses, test_losses = [], []

# Full-batch training: each epoch is a single optimizer step on all of train_data.
for epoch in tqdm.tqdm(range(num_epochs)):
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    optimizer.step()
    # Clear gradients after the step so the next backward() starts from zero.
    optimizer.zero_grad()

    # Evaluate on the held-out split every epoch, without tracking gradients.
    with torch.no_grad():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits, test_labels)

    # Store plain Python floats rather than tensors.
    train_losses.append(train_loss.item())
    test_losses.append(test_loss.item())

    # Report progress every 100 epochs.
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: Train Loss {train_loss.item():.4f}, Test Loss {test_loss.item():.4f}")


  0%|          | 0/25000 [00:00<?, ?it/s]

Epoch 0: Train Loss 4.7367, Test Loss 4.7315
Epoch 100: Train Loss 2.8639, Test Loss 7.5241
Epoch 200: Train Loss 0.0270, Test Loss 19.5011
Epoch 300: Train Loss 0.0089, Test Loss 20.1920
Epoch 400: Train Loss 0.0029, Test Loss 21.4087
Epoch 500: Train Loss 0.0010, Test Loss 22.7182
Epoch 600: Train Loss 0.0003, Test Loss 24.0754
Epoch 700: Train Loss 0.0001, Test Loss 25.4503
Epoch 800: Train Loss 0.0000, Test Loss 26.8303
Epoch 900: Train Loss 0.0000, Test Loss 28.1919
Epoch 1000: Train Loss 0.0000, Test Loss 29.4916
Epoch 1100: Train Loss 0.0000, Test Loss 30.6748
Epoch 1200: Train Loss 0.0000, Test Loss 31.6300
Epoch 1300: Train Loss 0.0000, Test Loss 32.2312
Epoch 1400: Train Loss 0.0000, Test Loss 32.4686
Epoch 1500: Train Loss 0.0000, Test Loss 32.4278
Epoch 1600: Train Loss 0.0000, Test Loss 32.2515
Epoch 1700: Train Loss 0.0000, Test Loss 32.0463
Epoch 1800: Train Loss 0.0000, Test Loss 31.8388
Epoch 1900: Train Loss 0.0000, Test Loss 31.6481
Epoch 2000: Train Loss 0.0000, Tes

In [18]:
# --- 4. Visualization ---
def plot_loss_curves(train_losses, test_losses):
    """Plot per-epoch train and test loss on a shared log-scaled y-axis.

    Args:
        train_losses: Sequence of training-loss values, one per epoch.
        test_losses: Sequence of test-loss values, one per epoch.

    Returns:
        None. The figure is rendered via ``fig.show()``.
    """
    # Passing a list of series puts Plotly Express in wide-form mode, where the
    # generated columns are named "index", "value" and "variable". The labels
    # mapping has to use those names: "x"/"y" keys match nothing and would
    # leave the axes titled "index" and "value".
    fig = px.line(
        y=[train_losses, test_losses],
        log_y=True,
        labels={"index": "Epoch", "value": "Loss", "variable": "Dataset"},
        title="Training and Test Loss for Modular Addition",
    )
    # Wide-form traces carry auto-generated names; give the legend readable ones.
    fig.data[0].name = "Train"
    fig.data[1].name = "Test"
    fig.show()

plot_loss_curves(train_losses, test_losses)