# Grokking on modular addition with noise
The idea is to take the standard grokking setup, i.e. modular addition, and introduce some noise into the labels.

In [6]:
# Import stuff
import numpy as np
import torch as t
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import einops

In [7]:
# Setup
p = 97            # modulus: labels are (a + b) mod p
train_frac = 0.3  # fraction of the p**2 input pairs used for training

# The loss is evaluated in float64, which (per the original note) MPS apparently
# cannot deal with, so the experiment is pinned to the CPU. Set FORCE_CPU = False
# to fall back to automatic accelerator detection.
FORCE_CPU = True
if FORCE_CPU or not t.accelerator.is_available():
    device = "cpu"
else:
    device = t.accelerator.current_accelerator().type
print(f"Using {device} device")

Using cpu device


In [8]:
# Label-noise model: every label is shifted by an independent draw of
# X = Binomial(n, p_bin) - n/2. With the values below X lies in {-2, ..., 2}.
p_bin = 0.5  # success probability of the binomial
n = 4        # number of binomial trials

## Generating a noisy Dataset


In [22]:
# Seed for the label noise and the train/test shuffle, so that Restart & Run All
# reproduces the same dataset. The same value is assigned again in the
# training-parameters cell further down; it is defined here because this cell
# runs first.
DATA_SEED = 346
noise_rng = np.random.default_rng(DATA_SEED)
split_gen = t.Generator().manual_seed(DATA_SEED)

# Enumerate all p**2 input pairs (a, b): a changes every p rows, b cycles 0..p-1.
a_vec = einops.repeat(t.arange(p), "i -> (i j)", j=p)
b_vec = einops.repeat(t.arange(p), "j -> (i j)", i=p)

# The dataset consists of pairs (x, y) with x = (a, b) and y = a + b mod p
dataset = t.stack([a_vec, b_vec], dim=1).to(device=device)
n_samples = dataset.shape[0]

# Introduce noise: one independent draw of Binomial(n, p_bin) - n // 2 per example.
# The size comes from the dataset itself; the previous version read `labels.shape`
# before `labels` was defined, which only worked with leftover kernel state.
noise = t.from_numpy(noise_rng.binomial(n=n, p=p_bin, size=n_samples) - n // 2).to(device)
# `%` on tensors takes the sign of the divisor, so negative noise still wraps into [0, p).
labels = (dataset[:, 0] + dataset[:, 1] + noise) % p  # labels with noise

# Randomly permute the examples and split them into train and test sets.
# Both splits use the noisy labels.
indices = t.randperm(n_samples, generator=split_gen)
n_train = int(train_frac * n_samples)
train_indices = indices[:n_train]
test_indices = indices[n_train:]
assert len(train_indices) + len(test_indices) == n_samples

train_dataset = dataset[train_indices]
train_labels = labels[train_indices]

test_dataset = dataset[test_indices]
test_labels = labels[test_indices]
print(train_dataset)
print(train_labels)

tensor([[55,  9],
        [51, 84],
        [45,  6],
        ...,
        [52, 10],
        [69, 30],
        [22, 12]])
tensor([65, 36, 52,  ..., 62,  0, 35])


## Model architecture

We take a simple one-layer MLP with a learned embedding and unembedding.
The architecture is:

0. Token: the tokens $t_0,t_1$ are one hot encoded d_vocab dimensional vectors and the input sequence is $$t = (t_0,t_1)^T$$
1. Embedding: The tokens are embedded in the d_model dimensional space by a learnable matrix W_E $$x_0 = Embed(t) = t @W_E$$
2. Concat: We concatenate the n_ctx(=2) vectors to form a d_model * n_ctx dimensional vector $$x_1 = flat(x_0)$$
3. MLP layer: A simple mlp with one hidden layer that is d_mlp = 4*d_model*n_ctx dimensional with ReLU activation function and bias:
$$x_2 = ReLU(x_1 @ W_{in} + b_{in}) @ W_{out}$$
4. Unembedding: a learned unembedding matrix W_U that maps back to the vocab.
$$x_3 = x_2 @ W_U$$

In [23]:
# Hyperparameters of the MLP model
act_type = "ReLU"  # activation function ("GeLU" is the alternative noted originally)
n_ctx = 2          # context length: the two operands (a, b)
d_model = 128      # embedding dimension
d_vocab = p        # one token per residue 0, ..., p-1
# NOTE(review): the architecture markdown states d_mlp = 4*d_model*n_ctx, but the
# factor here is 2 - confirm which one is intended.
d_mlp = 2 * d_model * n_ctx

In [24]:
from models.simple_models import SimpleMLP
import tqdm

In [25]:
# Build the model from the config-cell hyperparameters (act_type instead of a
# repeated string literal) and place it on the same device as the data.
# NOTE(review): d_mlp is not passed here, so the hidden width presumably comes from
# SimpleMLP's own default - confirm that this is intended.
simpleMLP = SimpleMLP(d_model=d_model, d_vocab=d_vocab, n_ctx=n_ctx, act_type=act_type).to(device)

## Training


In [26]:
# Optimisation hyperparameters
betas = (0.9, 0.98)  # AdamW moment coefficients
wd = 1.0             # weight decay
lr = 1e-3            # learning rate
n_epoch = 20000      # number of passes of the training loop
# Seed for the data generation - presumably meant for the noise / train-test
# shuffle; TODO confirm where it is consumed.
DATA_SEED = 346

In [27]:
# AdamW over every model parameter, configured from the training-parameter cell.
optimizer = t.optim.AdamW(
    params=simpleMLP.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=wd,
)
def loss_fn(logits, labels):
    """Mean cross-entropy of `logits` against integer class `labels`.

    Args:
        logits: 2-D scores indexed as [example, class], or a 3-D tensor, in which
            case only the last index along dim 1 is scored.
        labels: 1-D int64 tensor with one target class index per example.

    Returns:
        Scalar float64 tensor: the negated mean log-probability of the targets.
    """
    # Sequence-shaped output: keep only the final position.
    if logits.ndim == 3:
        logits = logits[:, -1, :]
    # Evaluate in float64 (original note: this prevents slingshots).
    log_probs = t.log_softmax(logits.double(), dim=-1)
    # Pick out the log-probability assigned to each example's target class.
    target_log_probs = t.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1)
    return target_log_probs.mean().neg()

In [None]:
# Loss history: one entry per pass of the loop below.
train_losses = []
test_losses = []

pbar = tqdm.tqdm(
    range(1, n_epoch + 1),
    desc="Training",
    unit="epoch",
    dynamic_ncols=True,
    leave=True,
)

for epoch in pbar:
    # --- train: a single optimizer step on the whole training set ---
    simpleMLP.train()
    optimizer.zero_grad()
    train_logits = simpleMLP(train_dataset)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    optimizer.step()

    # --- eval: test loss with the freshly updated weights, no autograd ---
    simpleMLP.eval()
    with t.inference_mode():
        test_logits = simpleMLP(test_dataset)
        test_loss = loss_fn(test_logits, test_labels)

    # Pull the scalars out once and reuse them for logging and display.
    train_value = train_loss.item()
    test_value = test_loss.item()
    train_losses.append(train_value)
    test_losses.append(test_value)

    # Show the current losses inline on the single moving bar.
    pbar.set_postfix(train=f"{train_value}", test=f"{test_value}")

# Final summary line
print(f"Final — Train: {train_losses[-1]:.4f} | Test: {test_losses[-1]:.4f}")

Training:  50%|█████     | 10083/20000 [05:41<05:42, 28.94epoch/s, test=14.384149091870109, train=7.701885993978802e-07]