# Simple grokking demo with superposition

In [3]:
# Import stuff
import numpy as np
import torch as t
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import einops

## Generate Dataset

In [6]:
# Setup: global constants for the modular-addition task
p = 97             # modulus; labels are (a + b) % p
train_frac = 0.3   # fraction of the p**2 (a, b) pairs used for training
device = "cpu"     # device the dataset and the model are placed on

In [7]:
# Enumerate every pair (a, b) with 0 <= a, b < p: a_vec repeats each value p
# times in a row, b_vec cycles through 0..p-1, and eq_vec holds the index p
# (used for the "=" token) in every row.
a_vec = einops.repeat(t.arange(p),"i -> (i j)",j=p)
b_vec = einops.repeat(t.arange(p),"j -> (i j)",i=p)
eq_vec = einops.repeat(t.tensor(p)," -> i",i= p**2)
# The dataset consists of pairs (x,y) with x = (a,b,eq) and y = a+b mod p
dataset = t.stack([a_vec,b_vec,eq_vec],dim=1).to(device=device)
labels = (dataset[:,0] + dataset[:,1]) % p

# Randomly permute the dataset and split it into train and test sets.
# The permutation is drawn from a dedicated, explicitly seeded generator so the
# split is identical on every Restart & Run All. The seed has the same value as
# DATA_SEED in the training-parameter cell, which is only defined further down
# the notebook and therefore cannot be referenced here.
split_seed = 346
split_generator = t.Generator().manual_seed(split_seed)
indices = t.randperm(p**2, generator=split_generator)
n_train = int(train_frac*p**2)
train_indices = indices[:n_train]
test_indices = indices[n_train:]

train_dataset = dataset[train_indices]
train_labels = labels[train_indices]

test_dataset = dataset[test_indices]
test_labels = labels[test_indices]

# Sanity check: the two splits together cover every (a, b) pair exactly once.
assert len(train_indices) + len(test_indices) == p**2
print(train_dataset)
print(train_labels)

tensor([[71, 77, 97],
        [91,  9, 97],
        [33, 73, 97],
        ...,
        [ 6, 51, 97],
        [13, 39, 97],
        [75, 27, 97]])
tensor([51,  3,  9,  ..., 57, 52,  5])


## Create a simple one layer transformer model

This is about the simplest possible transformer model. It takes as input the one-hot encoded token sequence (a, b, =), where the `=` token is assigned the index p for convenience. The architecture is:

0. Token: the tokens $t_0,t_1,t_2$ are one hot encoded d_vocab dimensional vectors and the input sequence is $$t = (t_0,t_1,t_2)^T$$
1. Embedding: The tokens are embedded into the d_model-dimensional space by a learnable matrix $W_E$: $$x_0 = Embed(t) = t @ W_E$$
2. Positional embedding: A learned d_model-dimensional vector, selected by token position, is added to each embedded token. $$x_1 = x_0 + W_{pos}$$
3. Attention layer: A simple attention layer with num_heads = 4 attention heads
$$x_2 = x_1 + Attention(x_1)$$
4. MLP layer: A simple MLP with one hidden layer, a ReLU activation function, and no bias:
$$x_3 = x_2 + MLP(x_2)$$
5. Unembedding: a learned unembedding matrix W_U that maps back to the vocab.
$$x_4 = x_3 @ W_U$$

In [8]:
# Transformer architecture configuration
d_vocab = p + 1                   # tokens 0..p-1 for the numbers, plus index p for the equal sign
d_model = 128                     # dimension of the model
n_ctx = 3                         # context length: the sequence is (a, b, =)
num_heads = 4                     # number of attention heads
d_heads = d_model // num_heads    # d_model divided evenly across the heads
d_mlp = 4 * d_model               # MLP width, four times d_model
act_type = "ReLU"                 # activation-type string handed to the model constructor

In [4]:
from models.one_layer_transformer import Transformer

In [9]:
# Build the imported Transformer from the architecture config and move it to `device`.
# Arguments are passed positionally; the order must match the constructor in
# models.one_layer_transformer (not visible from this notebook).
model = Transformer(
    d_model, d_mlp, d_heads, d_vocab, num_heads, n_ctx, act_type
).to(device=device)


## Training

In [15]:
from torch.utils.data import DataLoader, TensorDataset

import tqdm.auto as tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Training hyperparameters
n_epoch = 30_000       # number of full-batch optimisation steps
lr = 1e-3              # AdamW learning rate
wd = 1.0               # AdamW weight decay
DATA_SEED = 346        # seed constant; NOTE(review): not referenced by any other visible cell
betas = (0.9, 0.98)    # AdamW beta coefficients

In [12]:
# AdamW over all model parameters, configured from the training hyperparameters
optimizer = t.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=wd,
)

In [13]:
# Cross entropy loss
def loss_fn(logits,labels):
    """Return the mean negative log-likelihood of `labels` under `logits`.

    Args:
        logits: 2-D tensor of per-class scores with classes on the last axis,
            or a 3-D tensor, in which case only the last index along axis 1
            is scored.
        labels: 1-D integer tensor holding the target class index per row.

    Returns:
        A scalar float64 tensor: minus the mean log-probability assigned to
        the target classes.
    """
    # For 3-D input keep only the final position along axis 1.
    scored = logits[:, -1] if len(logits.shape) == 3 else logits
    # Do the softmax in float64 regardless of the input dtype.
    log_probs = t.log_softmax(scored.to(t.float64), dim=-1)
    # Pick, for every row, the log-probability of its target class.
    target_log_probs = t.gather(log_probs, -1, labels.unsqueeze(1)).squeeze(1)
    return -target_log_probs.mean()

In [None]:
# Full-batch training: every epoch is one optimizer step on the whole training
# set, followed by an evaluation on the whole test set. Both losses are recorded
# every epoch so the curves can be plotted afterwards.
train_losses = []
test_losses = []
for epoch in tqdm.tqdm(range(n_epoch)):
    train_logits = model(train_dataset)
    train_loss = loss_fn(train_logits,train_labels)
    train_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Read the scalar once and reuse it for both the history and the log line.
    train_loss_value = train_loss.item()
    train_losses.append(train_loss_value)

    # Evaluation only: no gradients are needed for the test loss.
    with t.inference_mode():
        test_logits = model(test_dataset)
        test_loss = loss_fn(test_logits,test_labels)
    test_loss_value = test_loss.item()
    test_losses.append(test_loss_value)

    if ((epoch+1)%100)==0:
        # Log through tqdm so the message does not break the progress bar;
        # a plain print() makes the text-mode bar restart on a new line.
        tqdm.tqdm.write(f"Epoch {epoch} Train Loss {train_loss_value} Test loss {test_loss_value}")

  0%|          | 102/30000 [00:09<45:53, 10.86it/s]

Epoch 99 Train Loss 1.2547569312723175 Test loss 8.574755202910724


  1%|          | 202/30000 [00:18<43:30, 11.41it/s]

Epoch 199 Train Loss 0.013822463495749922 Test loss 14.162979331560935


  1%|          | 302/30000 [00:27<43:19, 11.43it/s]

Epoch 299 Train Loss 0.004671921987955941 Test loss 14.622254524056357


  1%|▏         | 402/30000 [00:37<42:56, 11.49it/s]

Epoch 399 Train Loss 0.0015216687051994378 Test loss 15.364470201620046


  2%|▏         | 502/30000 [00:46<46:49, 10.50it/s]

Epoch 499 Train Loss 0.0005052167135453146 Test loss 16.17842971119118


  2%|▏         | 600/30000 [00:55<47:02, 10.41it/s]

Epoch 599 Train Loss 0.00017019457442162453 Test loss 17.01893632245188


  2%|▏         | 701/30000 [01:04<44:31, 10.97it/s]  

Epoch 699 Train Loss 5.809391899440574e-05 Test loss 17.864302541335043


  3%|▎         | 800/30000 [01:14<48:44,  9.99it/s]

Epoch 799 Train Loss 2.017687304490978e-05 Test loss 18.683733216788905


  3%|▎         | 901/30000 [01:24<47:50, 10.14it/s]

Epoch 899 Train Loss 7.221343434953348e-06 Test loss 19.46701321674217


  3%|▎         | 1002/30000 [01:34<48:55,  9.88it/s]

Epoch 999 Train Loss 2.73839857305549e-06 Test loss 20.17284289283619


  4%|▎         | 1102/30000 [01:43<40:14, 11.97it/s]

Epoch 1099 Train Loss 1.1617990829398628e-06 Test loss 20.753033357796085


  4%|▍         | 1201/30000 [01:53<46:16, 10.37it/s]

Epoch 1199 Train Loss 5.91560957787267e-07 Test loss 21.145291375407506


  4%|▍         | 1302/30000 [02:02<48:05,  9.95it/s]

Epoch 1299 Train Loss 3.794066796278684e-07 Test loss 21.312957057784033


  5%|▍         | 1402/30000 [02:11<45:07, 10.56it/s]

Epoch 1399 Train Loss 3.0340388641008003e-07 Test loss 21.25736111161941


  5%|▌         | 1501/30000 [02:20<41:50, 11.35it/s]

Epoch 1499 Train Loss 2.808090669862008e-07 Test loss 21.053494335543952


  5%|▌         | 1601/30000 [02:29<40:00, 11.83it/s]

Epoch 1599 Train Loss 2.7546344729555363e-07 Test loss 20.796509646680743


  6%|▌         | 1701/30000 [02:38<42:17, 11.15it/s]

Epoch 1699 Train Loss 2.7377221906698215e-07 Test loss 20.51771328110459


  6%|▌         | 1801/30000 [02:47<42:24, 11.08it/s]

Epoch 1799 Train Loss 2.726842892929827e-07 Test loss 20.227458235382713


  6%|▋         | 1901/30000 [02:56<46:10, 10.14it/s]

Epoch 1899 Train Loss 2.714062681167369e-07 Test loss 19.929136759903816


  7%|▋         | 2002/30000 [03:06<43:50, 10.64it/s]

Epoch 1999 Train Loss 2.6985457253635636e-07 Test loss 19.616734280234798


  7%|▋         | 2030/30000 [03:08<39:49, 11.71it/s]

## Plot results

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Plot the per-epoch training and test loss on a logarithmic y-axis.
fig, ax = plt.subplots()
ax.plot(train_losses,label="Training Loss")
ax.plot(test_losses,label="Test Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Log Loss")
ax.set_yscale("log")
ax.set_title("Training and Test loss")
ax.legend()
plt.show()