In [1]:
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path.cwd().parent))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

from model import AttentionOnlyTransformer
from generate_data import generate_rrt

# Select the compute device: prefer Apple MPS, then CUDA, and fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Seed torch's global RNG for this run.
torch.manual_seed(67)

# Experiment configuration
NUM_LAYERS = 2  # passed as n_layers when the model is built
SEQ_LEN = (20, 30)  # seq_len argument for training batches (presumably a range of prefix lengths; confirm in generate_rrt)
MAX_CONTEXT_LEN = 80  # `length` of every generated sample and the model's max_context_len
OOD_SEQ_LENS = (10, 15, 35, 40)  # OOD sequence lengths for generalization eval

cuda


In [2]:
# Sanity-check the generator: draw a small batch with a fixed seq_len of 5.
seq_len, data = generate_rrt(
    num_samples=32,
    length=MAX_CONTEXT_LEN,
    vocab_size=256,
    seq_len=5,
)

# Report the batch shape and the returned prefix length, then show the first
# ten tokens of the first row.
print(
    f"Dataset Shape: {data.shape}, Prefix Length: {seq_len}",
    f"First sequence sample (first 10 tokens):\n{data[0, :10]}",
    sep="\n",
)

Dataset Shape: torch.Size([32, 80]), Prefix Length: 5
First sequence sample (first 10 tokens):
tensor([120, 167, 140,  36, 246, 120, 167, 140,  36, 246])


In [None]:
# Instantiate the attention-only transformer, then move it to the chosen device.
model = AttentionOnlyTransformer(
    vocab_size=256,
    d_model=64,
    n_layers=NUM_LAYERS,
    n_heads=4,
    max_context_len=MAX_CONTEXT_LEN,
)
model = model.to(device)

# Summarize the model: total parameter count first, then the module tree.
print(
    f"Model parameters: {sum(weight.numel() for weight in model.parameters()):,}",
    f"\nModel architecture:\n{model}",
    sep="\n",
)

Model parameters: 65,536

Model architecture:
AttentionOnlyTransformer(
  (token_embedding): Embedding(256, 64)
  (attention_blocks): ModuleList(
    (0-1): 2 x AttentionBlock(
      (W_qkv): Linear(in_features=64, out_features=192, bias=False)
      (W_o): Linear(in_features=64, out_features=64, bias=False)
    )
  )
  (unembed): Linear(in_features=64, out_features=256, bias=False)
)


In [4]:
# AdamW over every model parameter.
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,           # Relatively high for small models
    betas=(0.9, 0.98), # Standard 'Transformer' betas
    weight_decay=0.01  # Helps prevent the 'junk drawer' neuron problem
)

# Number of optimization steps; each "epoch" below trains on one freshly generated batch.
epochs = 20000

# Per-step metrics (one entry appended on every training step)
losses = []
accuracies = []
induction_accuracies = []

# Training history for visualization (saved every 100 epochs, includes OOD evals)
training_history = {
    "epochs": [],
    "losses": [],
    "accuracies": [],
    "induction_accs": [],
    "ood_induction_accs": {ood_len: [] for ood_len in OOD_SEQ_LENS},  # keyed by OOD seq_len
    "attention_snapshots": {},  # epoch -> attention patterns from a dedicated seq_len=20 forward pass
}

# Training
for epoch in range(epochs):

    # Generate data using a varied prefix length
    seq_len, data = generate_rrt(num_samples=32, length=MAX_CONTEXT_LEN, vocab_size=256, seq_len=SEQ_LEN)

    # Move data to the specified device
    data = data.to(device)

    # Next-token setup: inputs drop the last token, targets drop the first
    input_data = data[:, :-1]
    target_data = data[:, 1:]

    # Forward pass. Call the module itself rather than .forward() so that any
    # registered hooks run, matching the evaluation passes further down.
    logits = model(input_data)

    # [b, s, vocab_size] -> [b*s, vocab_size]
    # reshape (unlike view) also works if the logits are not contiguous.
    input_to_ce = logits.reshape(-1, model.vocab_size)

    # [b, s] -> [b*s], i.e. one target token id per logit row
    target_flat = target_data.reshape(-1)

    # Calculate loss
    loss = F.cross_entropy(input_to_ce, target_flat)
    losses.append(loss.item())

    # Calculate overall accuracy
    predictions = logits.argmax(dim=-1)
    accuracy = (predictions == target_data).float().mean().item()
    accuracies.append(accuracy)

    # Induction accuracy (only positions from seq_len onward, where the pattern repeats)
    induction_acc = (predictions[:, seq_len:] == target_data[:, seq_len:]).float().mean().item()
    induction_accuracies.append(induction_acc)

    # Backward pass
    loss.backward()

    # Update parameters, then clear gradients for the next step
    optimizer.step()
    optimizer.zero_grad()

    # Save metrics every 100 epochs (includes OOD evals at every length in OOD_SEQ_LENS)
    if epoch % 100 == 0:
        training_history["epochs"].append(epoch)
        training_history["losses"].append(loss.item())
        training_history["accuracies"].append(accuracy)
        training_history["induction_accs"].append(induction_acc)

        # Evaluate out-of-distribution induction accuracy at multiple seq_lens
        ood_accs = {}
        with torch.no_grad():
            for ood_len in OOD_SEQ_LENS:
                _, ood_data = generate_rrt(num_samples=32, length=MAX_CONTEXT_LEN, vocab_size=256, seq_len=ood_len)
                ood_data = ood_data.to(device)
                ood_logits = model(ood_data[:, :-1])
                ood_predictions = ood_logits.argmax(dim=-1)
                ood_targets = ood_data[:, 1:]
                ood_acc = (ood_predictions[:, ood_len:] == ood_targets[:, ood_len:]).float().mean().item()
                ood_accs[ood_len] = ood_acc
                training_history["ood_induction_accs"][ood_len].append(ood_acc)

        # Save attention snapshots by running a dedicated forward pass with seq_len=20
        with torch.no_grad():
            _, snap_data = generate_rrt(num_samples=1, length=MAX_CONTEXT_LEN, vocab_size=256, seq_len=20)
            snap_data = snap_data.to(device)
            _ = model(snap_data[:, :-1])  # run forward to populate attn_weights
            attention_patterns = []
            for block in model.attention_blocks:
                # attn_weights shape: [batch, n_heads, seq, seq] -> take first batch, detach
                attention_patterns.append(block.attn_weights[0].detach().cpu())
            # Stack to [n_layers, n_heads, seq, seq]
            training_history["attention_snapshots"][epoch] = torch.stack(attention_patterns)

        ood_str = ", ".join([f"OOD_{k}={v:.2%}" for k, v in ood_accs.items()])
        print(f"Epoch {epoch}: Loss={loss.item():.4f}, Acc={accuracy:.2%}, Induction={induction_acc:.2%}, {ood_str}")

Epoch 0: Loss=5.8080, Acc=0.67%, Induction=0.60%, OOD_10=0.27%, OOD_15=0.15%, OOD_35=0.57%, OOD_40=0.16%
Epoch 100: Loss=5.6703, Acc=0.24%, Induction=0.29%, OOD_10=0.41%, OOD_15=0.59%, OOD_35=0.36%, OOD_40=0.32%
Epoch 200: Loss=5.5801, Acc=0.75%, Induction=0.77%, OOD_10=0.91%, OOD_15=0.29%, OOD_35=0.43%, OOD_40=0.48%
Epoch 300: Loss=5.5701, Acc=2.10%, Induction=2.78%, OOD_10=4.30%, OOD_15=3.08%, OOD_35=2.06%, OOD_40=1.28%
Epoch 400: Loss=5.4718, Acc=2.85%, Induction=3.92%, OOD_10=7.16%, OOD_15=5.47%, OOD_35=3.12%, OOD_40=2.16%
Epoch 500: Loss=5.4473, Acc=3.28%, Induction=4.87%, OOD_10=8.11%, OOD_15=6.59%, OOD_35=2.91%, OOD_40=2.64%
Epoch 600: Loss=5.4178, Acc=2.93%, Induction=4.27%, OOD_10=9.19%, OOD_15=6.10%, OOD_35=3.48%, OOD_40=2.64%
Epoch 700: Loss=5.3590, Acc=4.11%, Induction=5.50%, OOD_10=8.06%, OOD_15=6.69%, OOD_35=4.40%, OOD_40=3.37%
Epoch 800: Loss=5.2196, Acc=4.67%, Induction=5.83%, OOD_10=10.24%, OOD_15=7.71%, OOD_35=3.05%, OOD_40=3.04%
Epoch 900: Loss=5.2197, Acc=4.91%, Ind

In [None]:
# Checkpoint the model weights after the first training run.
# (Only the state_dict is written here; the training history is not saved by this cell.)
results_dir = Path.cwd().parent.joinpath("results")
results_dir.mkdir(exist_ok=True)

# The file name carries the number of epochs trained so far.
model_path = results_dir.joinpath(f"2L_varied_model_{epochs}.pt")
torch.save(model.state_dict(), model_path)

In [6]:
# Training (continued from where the previous training loop stopped; epoch 20000 on a fresh run)
# `losses` receives one entry per completed step, so its length is the number of
# steps already trained. Deriving the offset from it, rather than from the
# previous value of `epochs`, keeps the epoch numbering correct if this cell is
# re-run (the old `epoch_offset = epochs` would yield 40000 on a second run).
epoch_offset = len(losses)
epochs = 40000  # number of ADDITIONAL steps trained by this cell
for epoch in range(epoch_offset, epoch_offset + epochs):

    # Generate data using a varied prefix length
    seq_len, data = generate_rrt(num_samples=32, length=MAX_CONTEXT_LEN, vocab_size=256, seq_len=SEQ_LEN)

    # Move data to the specified device
    data = data.to(device)

    # Next-token setup: inputs drop the last token, targets drop the first
    input_data = data[:, :-1]
    target_data = data[:, 1:]

    # Forward pass. Call the module itself rather than .forward() so that any
    # registered hooks run, matching the evaluation passes further down.
    logits = model(input_data)

    # [b, s, vocab_size] -> [b*s, vocab_size]
    # reshape (unlike view) also works if the logits are not contiguous.
    input_to_ce = logits.reshape(-1, model.vocab_size)

    # [b, s] -> [b*s], i.e. one target token id per logit row
    target_flat = target_data.reshape(-1)

    # Calculate loss
    loss = F.cross_entropy(input_to_ce, target_flat)
    losses.append(loss.item())

    # Calculate overall accuracy
    predictions = logits.argmax(dim=-1)
    accuracy = (predictions == target_data).float().mean().item()
    accuracies.append(accuracy)

    # Induction accuracy (only positions from seq_len onward, where the pattern repeats)
    induction_acc = (predictions[:, seq_len:] == target_data[:, seq_len:]).float().mean().item()
    induction_accuracies.append(induction_acc)

    # Backward pass
    loss.backward()

    # Update parameters, then clear gradients for the next step
    optimizer.step()
    optimizer.zero_grad()

    # Save metrics every 100 epochs (includes OOD evals at every length in OOD_SEQ_LENS)
    if epoch % 100 == 0:
        training_history["epochs"].append(epoch)
        training_history["losses"].append(loss.item())
        training_history["accuracies"].append(accuracy)
        training_history["induction_accs"].append(induction_acc)

        # Evaluate out-of-distribution induction accuracy at multiple seq_lens
        ood_accs = {}
        with torch.no_grad():
            for ood_len in OOD_SEQ_LENS:
                _, ood_data = generate_rrt(num_samples=32, length=MAX_CONTEXT_LEN, vocab_size=256, seq_len=ood_len)
                ood_data = ood_data.to(device)
                ood_logits = model(ood_data[:, :-1])
                ood_predictions = ood_logits.argmax(dim=-1)
                ood_targets = ood_data[:, 1:]
                ood_acc = (ood_predictions[:, ood_len:] == ood_targets[:, ood_len:]).float().mean().item()
                ood_accs[ood_len] = ood_acc
                training_history["ood_induction_accs"][ood_len].append(ood_acc)

        # Save attention snapshots by running a dedicated forward pass with seq_len=20
        with torch.no_grad():
            _, snap_data = generate_rrt(num_samples=1, length=MAX_CONTEXT_LEN, vocab_size=256, seq_len=20)
            snap_data = snap_data.to(device)
            _ = model(snap_data[:, :-1])  # run forward to populate attn_weights
            attention_patterns = []
            for block in model.attention_blocks:
                # attn_weights shape: [batch, n_heads, seq, seq] -> take first batch, detach
                attention_patterns.append(block.attn_weights[0].detach().cpu())
            # Stack to [n_layers, n_heads, seq, seq]
            training_history["attention_snapshots"][epoch] = torch.stack(attention_patterns)

        ood_str = ", ".join([f"OOD{k}={v:.2%}" for k, v in ood_accs.items()])
        print(f"Epoch {epoch}: Loss={loss.item():.4f}, Acc={accuracy:.2%}, Induction={induction_acc:.2%}, {ood_str}")

Epoch 20000: Loss=1.6728, Acc=70.85%, Induction=98.38%, OOD10=95.88%, OOD15=91.85%, OOD35=88.28%, OOD40=82.61%
Epoch 20100: Loss=1.9788, Acc=65.31%, Induction=97.72%, OOD10=94.20%, OOD15=91.31%, OOD35=88.85%, OOD40=81.17%
Epoch 20200: Loss=1.8109, Acc=68.75%, Induction=99.07%, OOD10=93.25%, OOD15=89.01%, OOD35=90.48%, OOD40=82.61%
Epoch 20300: Loss=2.0822, Acc=63.92%, Induction=97.24%, OOD10=95.15%, OOD15=91.89%, OOD35=88.57%, OOD40=81.81%
Epoch 20400: Loss=1.9072, Acc=66.93%, Induction=97.82%, OOD10=94.97%, OOD15=89.36%, OOD35=88.49%, OOD40=81.97%
Epoch 20500: Loss=1.8916, Acc=67.25%, Induction=98.70%, OOD10=95.97%, OOD15=91.94%, OOD35=89.06%, OOD40=82.93%
Epoch 20600: Loss=1.5650, Acc=72.23%, Induction=95.39%, OOD10=96.01%, OOD15=91.16%, OOD35=88.42%, OOD40=82.85%
Epoch 20700: Loss=2.0975, Acc=63.65%, Induction=97.00%, OOD10=94.20%, OOD15=87.99%, OOD35=88.28%, OOD40=82.13%
Epoch 20800: Loss=2.1645, Acc=62.18%, Induction=96.94%, OOD10=93.89%, OOD15=90.43%, OOD35=87.78%, OOD40=81.97%
E

In [None]:
# Persist the final model weights and the collected training history under ../results.
results_dir = Path.cwd().parent.joinpath("results")
results_dir.mkdir(exist_ok=True)

# The checkpoint name carries the total epoch count: the continuation's starting
# offset plus the number of epochs it ran.
model_path = results_dir.joinpath(f"2L_varied_model_{epoch_offset + epochs}.pt")
torch.save(model.state_dict(), model_path)

# The whole history dict (metrics, OOD accuracies, attention snapshots) goes into one file.
history_path = results_dir.joinpath("2L_varied_training_history.pt")
torch.save(training_history, history_path)