## 0. Notebook Setup

In [51]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset, random_split
import numpy as np
import einops
from transformer_lens import HookedTransformerConfig, HookedTransformer
from tqdm import tqdm
import copy

# Run everything on the CPU.
# NOTE(review): the MPS backend is left disabled, presumably because loss_fn
# upcasts logits to float64, which MPS does not support - confirm before
# re-enabling it.
device = torch.device("cpu")

# Checkpointing configuration: a snapshot is intended every `checkpoint_every`
# epochs; snapshots and the epoch each one was taken at are kept in the two
# parallel lists below.
checkpoint_every = 1000
checkpoint_path = "/Users/williamyang/Documents/local_projects/arithmetic/checkpoints/"
checkpoint_epochs = []
model_checkpoints = []

## 1. Model Training

#### 1.1 Task

In [52]:
# Modulus of the addition task; token id P is reserved for the "=" symbol.
P = 113

# Seed for the train/test split so that the same examples land in the training
# set after every kernel restart.
SPLIT_SEED = 0

# Enumerate every (a, b) pair with 0 <= a, b < P: `a_vector` repeats each value
# P times in a row, `b_vector` cycles through 0..P-1 P times.
a_vector = einops.repeat(torch.arange(P), 'i -> (i j)', j=P)
b_vector = einops.repeat(torch.arange(P), 'j -> (i j)', i=P)
# Constant third token (value P) appended to every example.
equals_vector = einops.repeat(torch.tensor(P), '-> (i j)', i=P, j=P)

# Each row of X is the token sequence [a, b, P]; the label is (a + b) mod P.
X = torch.stack([a_vector, b_vector, equals_vector], dim=1)
Y = (X[:, 0] + X[:, 1]) % P
assert X.shape == (P * P, 3) and Y.shape == (P * P,)

X, Y = X.to(device), Y.to(device)

dataset = TensorDataset(X, Y)

# 30% of all pairs are used for training, the remainder is held out.
train_size = int(0.3 * len(dataset))
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, len(dataset) - train_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
assert len(train_dataset) + len(test_dataset) == len(dataset)

#### 1.2 Model

In [53]:
# One-layer transformer for the 3-token input [a, b, =].
cfg = HookedTransformerConfig(
    # Architecture sizes.
    n_layers=1,
    n_heads=4,
    d_head=32,
    d_model=128,
    d_mlp=512,
    # Sequence / vocabulary: P number tokens plus the extra token with id P on
    # the input side; only the P possible residues on the output side.
    n_ctx=3,
    d_vocab=P + 1,
    d_vocab_out=P,
    # ReLU MLP, no normalisation layer.
    act_fn='relu',
    normalization_type=None,
    init_weights=True,
)

model = HookedTransformer(cfg)

In [54]:
# Exclude from training every parameter whose name contains 'b_'
# (presumably the bias terms - verify against model.named_parameters()).
for param_name, tensor in model.named_parameters():
    if 'b_' not in param_name:
        continue
    tensor.requires_grad = False

#### 1.3 Optimizer

In [55]:
# AdamW applies weight decay directly to the weights (decoupled), whereas
# optim.Adam folds `weight_decay` into the gradient as an L2 penalty that is
# then rescaled by Adam's per-parameter normaliser. With a coefficient as large
# as 1 the two behave very differently; the decoupled form is the one intended.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1, betas=(0.9, 0.98))

def loss_fn(logits, labels):
    """Mean negative log-likelihood of `labels` under `logits`.

    Args:
        logits: Either a 2-D tensor of per-example scores, or a 3-D tensor
            from which only the last index along dim 1 is scored
            (presumably (batch, position, vocab) - confirm against the model).
        labels: Integer tensor of target class indices, one per example.

    Returns:
        A scalar float64 tensor: the negated mean log-probability assigned to
        the correct class.
    """
    if logits.dim() == 3:
        # Only the prediction at the final sequence position is scored.
        final_logits = logits[:, -1].to(device)
    else:
        final_logits = logits

    # The log-softmax is computed in float64.
    # NOTE(review): presumably to keep very small losses from underflowing - confirm.
    log_probs = torch.log_softmax(final_logits.to(torch.float64), dim=-1)

    # Pick, for each row, the log-probability of its label.
    label_log_probs = torch.gather(log_probs, -1, labels.unsqueeze(1))[:, 0]
    return -label_log_probs.mean()

#### 1.4 Train

In [56]:
num_epochs = 25000

train_losses, test_losses = [], []

# Training is full-batch, so materialise each split once instead of
# re-indexing the Subset objects on every epoch.
train_inputs, train_labels = train_dataset[:]
test_inputs, test_labels = test_dataset[:]

for epoch in tqdm(range(num_epochs)):
    # Clear gradients left over from the previous step; without this they
    # accumulate across epochs.
    optimizer.zero_grad()
    train_logits = model(train_inputs)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    # Apply the update - without this call the weights never change.
    optimizer.step()
    train_losses.append(train_loss.item())

    # Held-out loss, measured after this epoch's update; no autograd tracking.
    with torch.inference_mode():
        test_logits = model(test_inputs)
        test_loss = loss_fn(test_logits, test_labels)
        test_losses.append(test_loss.item())

    # Snapshot the weights once every `checkpoint_every` epochs. The state
    # dict is deep-copied because state_dict() returns references to the live
    # parameters, which keep changing as training continues.
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint_epochs.append(epoch)
        model_checkpoints.append(copy.deepcopy(model.state_dict()))

  1%|          | 236/25000 [00:31<55:28,  7.44it/s]  


KeyboardInterrupt: 