## Install & import relevant libraries

In [1]:
%pip -q install transformer_lens
%pip -q install circuitsvis
%pip -q install git+https://github.com/neelnanda-io/neel-plotly.git

In [2]:
# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader
from typing import List, Union, Optional
from functools import partial
import copy
from neel_plotly.plot import line
import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [3]:
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [27]:
# Relative path of the file the trained model and its training history are saved to.
PTH_LOCATION = "workspace/grokking_xor_7bits.pth"

# Make sure the parent directory exists before anything is written into it.
Path(PTH_LOCATION).parent.mkdir(parents=True, exist_ok=True)

In [5]:
# Define function to generate the dataset
from itertools import product
def generate_xor_dataset(bit_length=1, n_samples=None, device='cpu', seed=None):
    """Build (inputs, targets) tensors for the bitwise-XOR task.

    Each input row is the concatenation of two `bit_length`-bit operands
    (2 * bit_length zeros/ones); the matching target row is their
    element-wise XOR (bit_length zeros/ones).

    Args:
        bit_length: number of bits in each operand.
        n_samples: if None, every possible input row is enumerated in
            lexicographic order; otherwise that many rows are drawn
            uniformly at random (duplicates possible).
        device: device the tensors are created on.
        seed: if not None, seeds torch, numpy and Python's `random`
            (global RNG state) before any data is generated.

    Returns:
        Tuple `(inputs, targets)` of int64 tensors with shapes
        `(N, 2 * bit_length)` and `(N, bit_length)`.
    """
    if seed is not None:
        # Same order as before: torch first, then numpy, then the stdlib RNG.
        for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
            seed_rng(seed)

    row_width = 2 * bit_length
    if n_samples is not None:
        # Random sampling of operand pairs.
        inputs = torch.randint(
            0, 2, size=(n_samples, row_width), dtype=torch.long, device=device
        )
    else:
        # Exhaustive enumeration of every operand pair.
        every_row = list(product([0, 1], repeat=row_width))
        inputs = torch.tensor(every_row, dtype=torch.long, device=device)

    # First half of each row is operand A, second half is operand B.
    operand_a, operand_b = inputs[:, :bit_length], inputs[:, bit_length:]
    return inputs, torch.bitwise_xor(operand_a, operand_b)

# One seed shared by dataset generation and the later train/test split.
seed = 598
# n_samples is left at its default (None), so every 14-bit input is enumerated.
dataset, labels = generate_xor_dataset(seed=seed, device=device, bit_length=7)

In [7]:
# Quick look at the generated tensors: shapes first, then a truncated view of the values.
print(dataset.shape, labels.shape, dataset, labels, sep="\n")

torch.Size([16384, 14])
torch.Size([16384, 7])
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 1, 0],
        ...,
        [1, 1, 1,  ..., 1, 0, 1],
        [1, 1, 1,  ..., 1, 1, 0],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 1, 0],
        ...,
        [0, 0, 0,  ..., 0, 1, 0],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')


## Config

In [8]:
# Problem size
bit_length = 7                     # bits per XOR operand
# Total number of distinct inputs: every pair of bit_length-bit operands.
# Derived from bit_length so the two values cannot drift apart (2^14 = 16384).
n_samples = 2 ** (2 * bit_length)

# Training schedule
num_epochs = 30000
checkpoint_every = 100             # record accuracies / a model snapshot every N epochs
frac_train = 0.008                 # fraction of all inputs used for training

# Optimizer config
lr = 1e-3
wd = 1                             # weight decay
betas = (0.9, 0.98)

In [9]:
# Shuffle all sample indices reproducibly, then cut them into train / test parts.
torch.manual_seed(seed)
indices = torch.randperm(n_samples)
cutoff = int(n_samples * frac_train)  # number of training samples
train_indices, test_indices = indices[:cutoff], indices[cutoff:]

# Materialise both splits of inputs and targets.
train_data, train_labels = dataset[train_indices], labels[train_indices]
test_data, test_labels = dataset[test_indices], labels[test_indices]

# Show the first few rows and the size of each split.
print(train_data[:5], train_labels[:5], train_data.shape, sep="\n")
print(test_data[:5], test_labels[:5], test_data.shape, sep="\n")

tensor([[1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1],
        [0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0]], device='cuda:0')
tensor([[1, 0, 0, 0, 0, 0, 1],
        [1, 0, 1, 0, 1, 0, 1],
        [0, 1, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 1, 0]], device='cuda:0')
torch.Size([131, 14])
tensor([[0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1],
        [0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1]], device='cuda:0')
tensor([[0, 0, 1, 0, 0, 1, 0],
        [0, 0, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 1, 1, 0],
        [0, 1, 0, 1, 1, 1, 0]], device='cuda:0')
torch.Size([16253, 14])


## Model

In [12]:
# Architecture of the small transformer trained on the XOR task.
cfg = HookedTransformerConfig(
    # --- size ---
    n_layers=1,
    n_heads=2,
    d_head=32,
    d_model=64,               # n_heads * d_head
    d_mlp=256,
    # --- non-linearity / normalisation ---
    act_fn="relu",
    normalization_type=None,
    # --- token and position space ---
    d_vocab=2,                # every input token is 0 or 1
    d_vocab_out=2,            # every output token is 0 or 1
    n_ctx=14,                 # 7 bits of input A followed by 7 bits of input B
    # --- initialisation / placement ---
    init_weights=True,
    seed=999,
    device=device,
)

model = HookedTransformer(cfg)

In [13]:
# Freeze every parameter whose name contains "b_" (the bias terms) so that no
# gradient is computed for them; this makes the model easier to interpret.
for name, param in model.named_parameters():
    if "b_" not in name:
        continue
    param.requires_grad_(False)

In [14]:
# AdamW over all model parameters, using the hyper-parameters from the config cell.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=wd,
)

In [17]:
# ------------------------------------------------------------------
# Loss function that is consistent with the bit 7-to-13 slice
# ------------------------------------------------------------------
def loss_fn(logits, labels):
    """Mean cross-entropy over the XOR answer positions.

    The answer bits are read from the trailing positions of the sequence:
    the last `labels.shape[-1]` positions of `logits` are scored against
    `labels` (positions 7..13 for the 14-token inputs and 7-bit targets used
    in this notebook).

    Accepts batched `(batch, seq_len, vocab)` and un-batched
    `(seq_len, vocab)` logits, as well as logits that were already cut down
    to exactly the answer positions. Previously only 3-D logits were sliced,
    so un-batched full-length logits were silently scored at the *first*
    positions instead of the answer positions.

    Args:
        logits: float tensor whose last dimension is the vocabulary and whose
            second-to-last dimension is the sequence position.
        labels: integer tensor of target token ids; its last dimension is the
            number of answer positions.

    Returns:
        Scalar float64 tensor: the mean negative log-likelihood of the labels.
    """
    if labels.dim() >= 1 and logits.dim() >= 2:
        # Keep only the trailing positions that line up with the labels.
        # Computing an explicit start index (instead of a negative slice)
        # keeps the slice correct even when there are zero label positions.
        first_answer_pos = logits.shape[-2] - labels.shape[-1]
        if first_answer_pos > 0:
            logits = logits[..., first_answer_pos:, :]

    # Do the softmax in float64 to avoid precision loss once the loss gets tiny.
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    # Pick out the log-probability assigned to the correct token at each position.
    correct_log_probs = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return -correct_log_probs.mean()

In [18]:
# ------------------------------------------------------------------
# Accuracy function that is consistent with the bit 7-to-13 slice
# ------------------------------------------------------------------
def accuracy_fn(logits, labels):
    """
    Fraction of XOR answer bits that are predicted correctly.

    Only sequence positions 7..13 (the 7 answer positions) are scored.
    `logits` may be batched, shape (batch, seq_len, vocab), or un-batched,
    shape (seq_len, vocab); `labels` holds the matching target bits.
    Returns a scalar float32 tensor.
    """
    answer_positions = slice(7, 14)
    if logits.dim() == 3:
        answer_logits = logits[:, answer_positions, :]   # (batch, 7, vocab)
    else:
        answer_logits = logits[answer_positions, :]      # (7, vocab)

    # Greedy prediction per position, compared element-wise with the targets.
    predicted_bits = torch.argmax(answer_logits, dim=-1)
    hits = predicted_bits.eq(labels).to(torch.float32)
    return hits.mean()

## Model Training

In [19]:
# ------------------------------------------------------------------
# Training loop with accuracy tracking at every checkpoint
# ------------------------------------------------------------------
# Loss histories get one entry per epoch; accuracy histories, snapshots and
# their epoch numbers get one entry per checkpoint.
train_losses, test_losses = [], []
train_accs, test_accs = [], []
model_checkpoints, checkpoint_epochs = [], []

for epoch in tqdm.tqdm(range(num_epochs)):

    # Full-batch optimisation step on the training split.
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    train_losses.append(train_loss.item())

    # Test loss is measured after the parameter update, without autograd.
    with torch.inference_mode():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits, test_labels)
    test_losses.append(test_loss.item())

    # Everything below only happens on checkpoint epochs.
    if (epoch + 1) % checkpoint_every != 0:
        continue

    with torch.inference_mode():
        train_acc = accuracy_fn(train_logits, train_labels).item()
        test_acc = accuracy_fn(test_logits, test_labels).item()
    train_accs.append(train_acc)
    test_accs.append(test_acc)

    # Deep-copy the weights so later updates do not alter the stored snapshot.
    checkpoint_epochs.append(epoch)
    model_checkpoints.append(copy.deepcopy(model.state_dict()))

    print(
        f"Epoch {epoch:4d} | "
        f"Train Loss {train_loss.item()} | "
        f"Test Loss {test_loss.item()} | "
        f"Train Acc {train_acc} | "
        f"Test Acc {test_acc}"
    )

  0%|          | 0/30000 [00:00<?, ?it/s]

Epoch   99 | Train Loss 0.18306178814700577 | Test Loss 0.24012129789610476 | Train Acc 0.9182115197181702 | Test Acc 0.8943843245506287
Epoch  199 | Train Loss 0.0005480773068867601 | Test Loss 0.009368215122606709 | Train Acc 0.9999999403953552 | Test Acc 0.9963698983192444
Epoch  299 | Train Loss 0.00018905175646199556 | Test Loss 0.008018478706503636 | Train Acc 0.9999999403953552 | Test Acc 0.9962908029556274
Epoch  399 | Train Loss 5.648530365646469e-05 | Test Loss 0.007138387216611768 | Train Acc 0.9999999403953552 | Test Acc 0.9969412088394165
Epoch  499 | Train Loss 1.8314097808615913e-05 | Test Loss 0.00604788142019239 | Train Acc 0.9999999403953552 | Test Acc 0.9975565075874329
Epoch  599 | Train Loss 6.105455986420531e-06 | Test Loss 0.004952785776131844 | Train Acc 0.9999999403953552 | Test Acc 0.9983211755752563
Epoch  699 | Train Loss 2.077537797763248e-06 | Test Loss 0.0039571269896723806 | Train Acc 0.9999999403953552 | Test Acc 0.9987342953681946
Epoch  799 | Train Lo

In [24]:
# Hyper-parameters that are embedded in the figure title.
subtitle = f"(frac_train={frac_train}, lr={lr}, wd={wd})"

# Plot every 100th recorded loss value on a log-scaled y axis.
fig = line(
    [train_losses[0::100], test_losses[0::100]],
    x=np.arange(0, len(train_losses), 100),
    line_labels=["train", "test"],
    title=f"Training Curve for XOR {subtitle}",
    xaxis="Epoch",
    yaxis="Loss",
    log_y=True,
    toggle_x=True,
    toggle_y=True,
)

In [28]:
# Persist the final weights together with everything needed to revisit the run:
# the config, the periodic snapshots, the loss curves and the train/test split.
torch.save(
    obj={
        "model": model.state_dict(),
        "config": model.cfg,
        "checkpoints": model_checkpoints,
        "checkpoint_epochs": checkpoint_epochs,
        "test_losses": test_losses,
        "train_losses": train_losses,
        "train_indices": train_indices,
        "test_indices": test_indices,
    },
    f=PTH_LOCATION,
)