In [1]:
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path.cwd().parent))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

from model import AttentionOnlyTransformer
from generate_data import generate_rrt

# Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Seed torch's RNG for reproducibility.
# NOTE(review): if generate_rrt draws from Python's `random` or numpy, those
# generators are not seeded here — confirm in generate_data.
torch.manual_seed(67)

# Experiment configuration
NUM_LAYERS = 1            # number of attention blocks in the model
SEQ_LEN = (15, 35)        # passed as `seq_len` to generate_rrt during training;
                          # presumably a (low, high) prefix-length range — confirm bounds
MAX_CONTEXT_LEN = 80      # tokens per generated sequence / model context length

cuda


In [2]:
# Sanity-check the data generator on one batch with a fixed prefix length of 5,
# short enough that the repetition of the prefix is visible in the printout.
seq_len, data = generate_rrt(
    num_samples=32,
    length=MAX_CONTEXT_LEN,
    vocab_size=256,
    seq_len=5,
)

print(f"Dataset Shape: {data.shape}, Prefix Length: {seq_len}")
print(f"First sequence sample (first 10 tokens):\n{data[0, :10]}")

Dataset Shape: torch.Size([32, 80]), Prefix Length: 5
First sequence sample (first 10 tokens):
tensor([120, 167, 140,  36, 246, 120, 167, 140,  36, 246])


In [3]:
# Build the attention-only transformer, then place it on the selected device.
model = AttentionOnlyTransformer(
    vocab_size=256,
    d_model=64,
    n_layers=NUM_LAYERS,
    n_heads=4,
    max_context_len=MAX_CONTEXT_LEN,
)
model = model.to(device)  # nn.Module.to moves in place and returns the module

# Report the total number of parameter elements and the module tree.
print(f"Model parameters: {sum(map(torch.Tensor.numel, model.parameters())):,}")
print(f"\nModel architecture:\n{model}")

Model parameters: 49,152

Model architecture:
AttentionOnlyTransformer(
  (token_embedding): Embedding(256, 64)
  (attention_blocks): ModuleList(
    (0): AttentionBlock(
      (W_qkv): Linear(in_features=64, out_features=192, bias=False)
      (W_o): Linear(in_features=64, out_features=64, bias=False)
    )
  )
  (unembed): Linear(in_features=64, out_features=256, bias=False)
)


In [4]:
# AdamW over all model parameters.
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,            # relatively high learning rate, typical for small models
    betas=(0.9, 0.98),  # (beta1, beta2) as used in the original Transformer setup
    weight_decay=0.01,  # mild decoupled weight decay
)

epochs = 20000      # number of optimizer steps; every step draws a fresh batch
log_every = 100     # print metrics every `log_every` steps
losses = []
accuracies = []
induction_accuracies = []

# Training
for epoch in range(epochs):
    # Draw a fresh batch. SEQ_LEN = (15, 35) is presumably a range from which
    # generate_rrt picks the prefix length (confirm in generate_data); the
    # returned `seq_len` is used below as a single per-batch slice offset.
    seq_len, data = generate_rrt(
        num_samples=32,
        length=MAX_CONTEXT_LEN,
        vocab_size=model.vocab_size,  # keep data and model vocabularies in sync
        seq_len=SEQ_LEN,
    )
    data = data.to(device)

    # Next-token prediction: input position i is trained to predict token i + 1.
    input_data = data[:, :-1]
    target_data = data[:, 1:]

    # Forward pass through the module (runs any registered hooks, unlike .forward).
    logits = model(input_data)

    # Flatten [b, s, vocab_size] -> [b*s, vocab_size] and [b, s] -> [b*s].
    # `reshape` is used instead of `view` so non-contiguous logits also work.
    input_to_ce = logits.reshape(-1, model.vocab_size)
    target_flat = target_data.reshape(-1)

    loss = F.cross_entropy(input_to_ce, target_flat)

    # Backward pass and parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Metrics are for monitoring only, so no autograd tracking is needed.
    with torch.no_grad():
        loss_value = loss.item()
        predictions = logits.argmax(dim=-1)
        accuracy = (predictions == target_data).float().mean().item()
        # Induction accuracy: only input positions i >= seq_len. Assuming the
        # prefix repeats with period seq_len, the token at i has occurred at
        # i - seq_len, so the next token can be copied from what followed it.
        # (Position seq_len - 1 is excluded: the first token of the repeat is
        # not predictable by copying.)
        induction_acc = (predictions[:, seq_len:] == target_data[:, seq_len:]).float().mean().item()

    losses.append(loss_value)
    accuracies.append(accuracy)
    induction_accuracies.append(induction_acc)

    if epoch % log_every == 0:
        print(f"Epoch {epoch}: Loss={loss_value:.4f}, Acc={accuracy:.2%}, Induction Acc={induction_acc:.2%}")

Epoch 0: Loss=5.8013, Acc=0.40%, Induction Acc=0.41%
Epoch 100: Loss=5.6815, Acc=0.16%, Induction Acc=0.20%
Epoch 200: Loss=5.5915, Acc=0.24%, Induction Acc=0.26%
Epoch 300: Loss=5.5660, Acc=0.87%, Induction Acc=1.06%
Epoch 400: Loss=5.5711, Acc=2.25%, Induction Acc=3.59%
Epoch 500: Loss=5.5011, Acc=2.10%, Induction Acc=2.78%
Epoch 600: Loss=5.4214, Acc=3.72%, Induction Acc=4.41%
Epoch 700: Loss=5.4383, Acc=3.24%, Induction Acc=4.36%
Epoch 800: Loss=5.4044, Acc=3.44%, Induction Acc=4.60%
Epoch 900: Loss=5.4400, Acc=3.20%, Induction Acc=4.62%
Epoch 1000: Loss=5.3835, Acc=3.72%, Induction Acc=4.89%
Epoch 1100: Loss=5.2066, Acc=5.74%, Induction Acc=6.80%
Epoch 1200: Loss=5.3661, Acc=3.96%, Induction Acc=5.64%
Epoch 1300: Loss=5.2560, Acc=4.83%, Induction Acc=5.95%
Epoch 1400: Loss=5.2703, Acc=4.35%, Induction Acc=5.31%
Epoch 1500: Loss=5.3002, Acc=4.55%, Induction Acc=6.80%
Epoch 1600: Loss=5.2740, Acc=4.91%, Induction Acc=6.73%
Epoch 1700: Loss=5.2610, Acc=4.47%, Induction Acc=7.02%
Epoc

In [None]:
# Save the trained weights. Only the state_dict is stored (no constructor
# arguments), so reloading requires building an identically configured
# AttentionOnlyTransformer first and then calling load_state_dict.
results_dir = Path.cwd().parent / "results"
results_dir.mkdir(parents=True, exist_ok=True)

# Derive the layer count in the filename from NUM_LAYERS so it cannot go stale
# (with NUM_LAYERS = 1 this is still "1L_varied_model.pt").
model_path = results_dir / f"{NUM_LAYERS}L_varied_model.pt"
torch.save(model.state_dict(), model_path)
print(f"Saved model parameters to {model_path}")