In [None]:
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path.cwd().parent))

import torch
import torch.nn.functional as F
import torch.optim as optim

from model import AttentionOnlyTransformer
from generate_data import generate_rrt

# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Seed PyTorch's global random number generator.
torch.manual_seed(67)

# Experiment configuration shared by the cells below.
NUM_LAYERS = 1        # passed to the model as n_layers
SEQ_LEN = 20          # passed to generate_rrt as seq_len in the training loop
MAX_CONTEXT_LEN = 80  # passed as generate_rrt's length and the model's max_context_len

cuda


In [7]:
# Draw one small demo batch to eyeball the generator's output before training.
# generate_rrt is unpacked into two values: the first is reported below as the
# prefix length, the second is the token batch.
seq_len, data = generate_rrt(
    num_samples=32,
    length=MAX_CONTEXT_LEN,
    vocab_size=256,
    seq_len=5,
)

# Report the batch shape, then the opening tokens of the first row.
print(f"Dataset Shape: {data.shape}, Prefix Length: {seq_len}")
print(f"First sequence sample (first 10 tokens):\n{data[0, :10]}")

Dataset Shape: torch.Size([32, 80]), Prefix Length: 5
First sequence sample (first 10 tokens):
tensor([120, 167, 140,  36, 246, 120, 167, 140,  36, 246])


In [None]:
# Build the attention-only transformer from the notebook's configuration ...
model = AttentionOnlyTransformer(
    vocab_size=256,
    d_model=64,
    n_layers=NUM_LAYERS,
    n_heads=4,
    max_context_len=MAX_CONTEXT_LEN,
)
# ... then move its parameters onto the selected device
# (nn.Module.to returns the module itself).
model = model.to(device)

# Summarise the model: total parameter count followed by the module tree.
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"\nModel architecture:\n{model}")

Model parameters: 49,152

Model architecture:
AttentionOnlyTransformer(
  (token_embedding): Embedding(256, 64)
  (attention_blocks): ModuleList(
    (0): AttentionBlock(
      (W_qkv): Linear(in_features=64, out_features=192, bias=False)
      (W_o): Linear(in_features=64, out_features=64, bias=False)
    )
  )
  (unembed): Linear(in_features=64, out_features=256, bias=False)
)


In [9]:
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,            # Relatively high for small models
    betas=(0.9, 0.98),  # Standard 'Transformer' betas
    weight_decay=0.01,  # Helps prevent the 'junk drawer' neuron problem
)

epochs = 5000  # one freshly generated batch per "epoch", i.e. one optimizer step each

# Per-step metric histories (plain Python floats).
losses = []
accuracies = []
induction_accuracies = []

# Training
for epoch in range(epochs):

    # Generate data using a fixed prefix length; generate_rrt returns the
    # prefix length together with the token batch.
    seq_len, data = generate_rrt(num_samples=32, length=MAX_CONTEXT_LEN, vocab_size=256, seq_len=SEQ_LEN)

    # Move data to the specified device
    data = data.to(device)

    # Next-token prediction: the input at position t is scored against token t+1.
    input_data = data[:, :-1]
    target_data = data[:, 1:]

    # Forward pass. Call the module itself rather than .forward() so that any
    # registered forward hooks are executed.
    logits = model(input_data)

    # [b, s, vocab_size] -> [b*s, vocab_size]
    # reshape (unlike view) also works if the model returns a non-contiguous tensor.
    input_to_ce = logits.reshape(-1, logits.size(-1))

    # [b, s] -> [b*s], i.e. flatten the 2D target tensor into a 1D tensor
    target_flat = target_data.reshape(-1)

    # Calculate loss; read the scalar back once and reuse it for history and logging.
    loss = F.cross_entropy(input_to_ce, target_flat)
    loss_value = loss.item()
    losses.append(loss_value)

    # Metrics are bookkeeping only, so keep them out of autograd.
    with torch.no_grad():
        # Calculate overall accuracy
        predictions = logits.argmax(dim=-1)
        accuracy = (predictions == target_data).float().mean().item()

        # Induction accuracy: only positions from seq_len onward, i.e. after the
        # first copy of the prefix, where the current token has appeared before.
        induction_acc = (predictions[:, seq_len:] == target_data[:, seq_len:]).float().mean().item()

    accuracies.append(accuracy)
    induction_accuracies.append(induction_acc)

    # Clear gradients *before* backward so that stale gradients left behind by an
    # interrupted earlier run of this cell cannot leak into the first update.
    optimizer.zero_grad()

    # Backward pass
    loss.backward()

    # Update parameters
    optimizer.step()

    # Print every N epochs
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss={loss_value:.4f}, Acc={accuracy:.2%}, Induction Acc={induction_acc:.2%}")

Epoch 0: Loss=5.8025, Acc=0.28%, Induction Acc=0.21%
Epoch 10: Loss=5.7272, Acc=0.51%, Induction Acc=0.42%
Epoch 20: Loss=5.7244, Acc=0.36%, Induction Acc=0.37%
Epoch 30: Loss=5.7267, Acc=0.32%, Induction Acc=0.32%
Epoch 40: Loss=5.7141, Acc=0.36%, Induction Acc=0.32%
Epoch 50: Loss=5.6760, Acc=0.16%, Induction Acc=0.11%
Epoch 60: Loss=5.6841, Acc=0.40%, Induction Acc=0.48%
Epoch 70: Loss=5.6981, Acc=0.36%, Induction Acc=0.37%
Epoch 80: Loss=5.6974, Acc=0.28%, Induction Acc=0.21%
Epoch 90: Loss=5.6770, Acc=0.36%, Induction Acc=0.42%
Epoch 100: Loss=5.6830, Acc=0.47%, Induction Acc=0.58%
Epoch 110: Loss=5.6762, Acc=0.47%, Induction Acc=0.37%
Epoch 120: Loss=5.6440, Acc=0.20%, Induction Acc=0.16%
Epoch 130: Loss=5.6755, Acc=0.28%, Induction Acc=0.37%
Epoch 140: Loss=5.6516, Acc=0.28%, Induction Acc=0.32%
Epoch 150: Loss=5.6224, Acc=0.36%, Induction Acc=0.32%
Epoch 160: Loss=5.6384, Acc=0.24%, Induction Acc=0.16%
Epoch 170: Loss=5.6412, Acc=0.24%, Induction Acc=0.16%
Epoch 180: Loss=5.620

In [None]:
# Persist the trained weights (state_dict only) into a "results" directory that
# sits next to the current working directory, creating it on first use.
results_dir = Path.cwd().parent.joinpath("results")
results_dir.mkdir(exist_ok=True)

model_path = results_dir.joinpath("1L_fixed_model.pt")
torch.save(model.state_dict(), model_path)