In [None]:
%pip -q install transformer_lens
%pip -q install circuitsvis

## Import Libraries

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd

import einops
#from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Render `tensor` as a heatmap on a red/blue scale whose midpoint is 0.

    Extra keyword arguments are forwarded to `px.imshow`; `renderer` is passed
    to `Figure.show`.
    """
    axis_labels = dict(x=xaxis, y=yaxis)
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Plot `tensor` with `px.line`, labelling the x and y axes.

    NOTE: a later cell runs `from neel_plotly.plot import line`, which rebinds
    this name; calls after that point use neel_plotly's version instead.
    """
    figure = px.line(utils.to_numpy(tensor), labels=dict(x=xaxis, y=yaxis), **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter `y` against `x`; `caxis` labels the colour channel if kwargs supply one."""
    x_np, y_np = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    px.scatter(y=y_np, x=x_np, labels=axis_labels, **kwargs).show(renderer)

In [None]:
# Where the trained model is saved. The path is relative, so it resolves
# against the kernel's current working directory.
PTH_LOCATION = "workspace/_scratch/grokking_xor1.pth"

# Make sure the parent directory (and any missing ancestors) exists; no error
# if it is already there.
Path(PTH_LOCATION).parent.mkdir(parents=True, exist_ok=True)

## Dataset

In [None]:
from itertools import product

def generate_xor_dataset(bit_length=1, n_samples=None, device='cpu', seed=None):
    """Build bitwise-XOR examples as (inputs, targets) long tensors.

    Args:
        bit_length: number of bits in each operand.
        n_samples: if None, return every one of the 2**(2*bit_length) input
            bit-vectors in `itertools.product` (lexicographic) order; otherwise
            draw this many uniformly random bit-vectors (duplicates possible).
        device: torch device for the returned tensors.
        seed: if given, seeds torch, numpy and `random` globally first.

    Returns:
        inputs: (N, 2*bit_length) tensor of 0/1; the first `bit_length` columns
            are operand a, the remaining columns operand b.
        targets: (N, bit_length) tensor equal to a XOR b, bit by bit.
    """
    if seed is not None:
        # Same order as before: torch, then numpy, then Python's random.
        for seeder in (torch.manual_seed, np.random.seed, random.seed):
            seeder(seed)

    width = 2 * bit_length
    if n_samples is not None:
        inputs = torch.randint(
            0, 2, size=(n_samples, width), dtype=torch.long, device=device
        )
    else:
        every_pair = list(product([0, 1], repeat=width))
        inputs = torch.tensor(every_pair, dtype=torch.long, device=device)

    first_operand, second_operand = inputs[:, :bit_length], inputs[:, bit_length:]
    return inputs, first_operand ^ second_operand
# Exhaustive 4-bit XOR data: n_samples is left as None, so every input pair is generated.
seed = 598
dataset, labels = generate_xor_dataset(seed=seed, device=device, bit_length=4)

In [None]:
# Sanity-check the generated data: shapes, a few rows, and the XOR invariant.
print(dataset.shape)
print(labels.shape)
print(dataset[:8])
print(labels[:8])
# Each label bit must equal the XOR of the matching a-bit and b-bit
# (inputs are laid out as [a bits | b bits]).
half = dataset.shape[1] // 2
assert torch.equal(labels, dataset[:, :half] ^ dataset[:, half:]), "labels are not a XOR b"

## Config

In [None]:
# Fraction of the dataset used for training; with 256 examples this gives
# int(256 * 0.115) = 29 training examples.
frac_train = 0.115
# Optimizer config (passed to AdamW)
lr = 1e-3
wd = 1
betas = (0.9, 0.98)

num_epochs = 30000
checkpoint_every = 100  # snapshot model weights every this many epochs
bit_length = 4
# Size of the exhaustive dataset: one example per pair of bit_length-bit
# operands. Derived rather than hard-coded so it cannot drift from bit_length.
n_samples = 2 ** (2 * bit_length)

In [None]:
# Fail loudly rather than train on a silently wrong split: n_samples must
# describe the dataset actually generated, and both splits must be non-empty.
assert n_samples == len(dataset), f"n_samples={n_samples} but dataset has {len(dataset)} rows"
cutoff = int(n_samples*frac_train)
assert 0 < cutoff < n_samples, f"degenerate split: {cutoff} of {n_samples} examples in train"

# Re-seed so the random permutation (and hence the split) is reproducible.
torch.manual_seed(seed)
indices = torch.randperm(n_samples)
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

train_data = dataset[train_indices]
train_labels = labels[train_indices]
test_data = dataset[test_indices]
test_labels = labels[test_indices]
print(train_data[:5])
print(train_labels[:5])
print(train_data.shape)
print(test_data[:5])
print(test_labels[:5])
print(test_data.shape)

## Model

In [None]:
# One-layer, two-head transformer with an MLP, over the binary token vocabulary.
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 2,
    d_model = 64,
    d_head = 32,                # n_heads * d_head == d_model
    d_mlp = 256,
    act_fn = "relu",
    normalization_type=None,
    d_vocab=2,                  # input tokens are the bits 0 and 1
    d_vocab_out=2,              # each scored position predicts one bit
    n_ctx=2 * bit_length,       # a-bits followed by b-bits (8 tokens when bit_length=4)
    init_weights=True,
    device=device,
    seed = 999,                 # config seed, separate from the data/split `seed`
)

In [None]:
# Instantiate the transformer from the config above.
model = HookedTransformer(cfg=cfg)

In [None]:
# Disable the biases, as we don't need them for this task and it makes things easier to interpret.
# Frozen parameters (requires_grad=False) never receive gradients, so they keep
# whatever value they hold. Zero them explicitly so "disabled" does not depend
# on how the weights were initialised.
for name, param in model.named_parameters():
    # Match on the final component of the dotted name (e.g. "...attn.b_Q") so a
    # "b_" occurring elsewhere in a module path cannot freeze a weight matrix.
    if name.rsplit(".", 1)[-1].startswith("b_"):
        param.requires_grad = False
        with torch.no_grad():
            param.zero_()

In [None]:
# Give AdamW only the trainable parameters. The frozen biases are left out
# entirely, so the optimizer keeps no state for them and weight decay cannot
# touch them.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=lr,
    weight_decay=wd,
    betas=betas,
)

In [None]:
# Shape sanity check on an untrained forward pass (no gradients needed).
with torch.no_grad():
    logits = model(train_data)
print(logits.shape)
# loss_fn scores the last `bit_length` positions (one per output bit), so look
# at the log-probs at those positions rather than only the final one.
log_probs = logits[:, -bit_length:, :].log_softmax(dim=-1)
print(log_probs.shape)
print(train_labels.shape)
print(train_labels.unsqueeze(-1).shape)

In [None]:
def loss_fn(logits, labels):
    """Mean cross-entropy, computed in float64, of the predicted output bits.

    Args:
        logits: either a 3-D ``(batch, seq, vocab)`` model output, or logits
            already aligned with ``labels`` (leading dims match, last dim vocab).
        labels: integer class targets. With 3-D ``logits`` this is
            ``(batch, n_out)``, and the last ``n_out`` sequence positions are
            scored (positions 4..7 for the 8-token 4-bit XOR inputs).

    Returns:
        Scalar float64 tensor: the negative mean log-probability assigned to
        the correct class.
    """
    if logits.ndim == 3:
        # Score exactly as many trailing positions as there are label bits,
        # instead of a hard-coded 4:8 window that only fits bit_length=4.
        start = logits.shape[1] - labels.shape[-1]
        logits = logits[:, start:, :]
    # Presumably float64 keeps log_softmax precise once the loss gets very small.
    log_probs = logits.to(torch.float64).log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return -correct_log_probs.mean()
# Loss of the untrained model on both splits. Evaluated without autograd:
# nothing here is back-propagated, so building a graph would only waste memory.
with torch.no_grad():
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    test_logits = model(test_data)
    test_loss = loss_fn(test_logits, test_labels)
print(train_loss)
print(test_loss)

## Training

In [None]:
# Full-batch training: every epoch runs one forward/backward pass over all of
# train_data and takes a single AdamW step.
# NOTE: re-running this cell resets the history lists below but NOT `model` or
# `optimizer`, so training resumes from the current weights.
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
for epoch in tqdm.tqdm(range(num_epochs)):
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    # Recorded before optimizer.step(), i.e. the loss of the pre-update weights.
    train_losses.append(train_loss.item())

    optimizer.step()
    optimizer.zero_grad()

    # Test loss is measured after the step (post-update weights), without autograd.
    with torch.inference_mode():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits, test_labels)
        test_losses.append(test_loss.item())

    # Snapshot every `checkpoint_every` epochs. `epoch` is 0-indexed, so the
    # stored epochs are checkpoint_every-1, 2*checkpoint_every-1, ...
    if ((epoch+1)%checkpoint_every)==0:
        checkpoint_epochs.append(epoch)
        # deepcopy: state_dict() tensors share storage with the live parameters,
        # which later optimizer steps would overwrite.
        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")

In [None]:
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
from neel_plotly.plot import line

In [None]:
# Hyper-parameters go straight into the title so the figure is self-describing.
subtitle = f"(frac_train={frac_train}, lr={lr}, wd={wd})"

# Train and test loss for every 100th epoch, log-scaled y axis. `line` here is
# the neel_plotly version imported in the previous cell.
fig = line(
    [train_losses[::100], test_losses[::100]],
    x=np.arange(0, len(train_losses), 100),
    line_labels=["train", "test"],
    title="Training Curve for XOR " + subtitle,
    xaxis="Epoch",
    yaxis="Loss",
    log_y=True,
    toggle_x=True,
    toggle_y=True,
)

In [None]:
# save the trained model
# Besides the weights, history and split, also record the hyper-parameters and
# seed, so the run can be reproduced from the file alone.
# NOTE(review): "config" is a HookedTransformerConfig object; recent PyTorch
# versions may need torch.load(..., weights_only=False) to read it. Confirm.
torch.save(
    {
        "model":model.state_dict(),
        "config": model.cfg,
        "checkpoints": model_checkpoints,
        "checkpoint_epochs": checkpoint_epochs,
        "test_losses": test_losses,
        "train_losses": train_losses,
        "train_indices": train_indices,
        "test_indices": test_indices,
        "hyperparams": {
            "lr": lr,
            "wd": wd,
            "betas": betas,
            "frac_train": frac_train,
            "num_epochs": num_epochs,
            "checkpoint_every": checkpoint_every,
            "bit_length": bit_length,
            "seed": seed,
        },
    },
    PTH_LOCATION)

In [None]:
# Manual reset utility. It is opt-in because under "Restart & Run All" it would
# otherwise delete the checkpoint that the previous cell just saved.
# Note: it clears only the saved file and the history lists. `model` and
# `optimizer` are untouched, so re-running training continues from them.
RESET_RUN = False

if RESET_RUN:
    # 1) Delete the old checkpoint file, if you saved one:
    if os.path.exists(PTH_LOCATION):
        os.remove(PTH_LOCATION)
        print(f"Removed checkpoint file at {PTH_LOCATION}")

    # 2) Clear out your history lists:
    train_losses = []
    test_losses = []
    model_checkpoints = []
    checkpoint_epochs = []
    print("Cleared loss and checkpoint history.")
else:
    print("RESET_RUN is False; nothing was deleted or cleared.")