# Grokking at the edge of numerical stability
In this notebook we explore grokking from a viewpoint inspired by the paper [Grokking at the edge of numerical stability](https://arxiv.org/abs/2501.04697).

The initial setup is the same as in the simple grokking demo: training a network to do modular addition.

In [1]:
# Import stuff
import numpy as np
import torch as t
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import einops

In [2]:
# Setup
p = 97
train_frac  = 0.3
device = t.accelerator.current_accelerator().type if t.accelerator.is_available() else "cpu"
device = "cpu"  # override: this experiment runs on CPU, since MPS does not support float64
print(f"Using {device} device")

Using cpu device


## Dataset
The dataset consists of tuples (a, b) with labels c, where c = (a + b) mod p.

In [3]:
a_vec = einops.repeat(t.arange(p),"i -> (i j)",j=p)
b_vec = einops.repeat(t.arange(p),"j -> (i j)",i=p)

# The dataset consists of pairs (x,y) with x = (a,b) and y = a+b mod p
# we randomly permute the dataset and split it into train and test dataset
dataset = t.stack([a_vec,b_vec],dim=1).to(device=device)
labels = (dataset[:,0] + dataset[:,1]) % p
indices = t.randperm(p**2)
train_indices = indices[:int(train_frac*p**2)]
test_indices = indices[int(train_frac*p**2):]

train_dataset = dataset[train_indices]
train_labels = labels[train_indices]

test_dataset = dataset[test_indices]
test_labels = labels[test_indices]
print(train_dataset)
print(train_labels)

tensor([[27, 44],
        [73, 24],
        [31, 57],
        ...,
        [49, 67],
        [34, 81],
        [74, 67]])
tensor([71,  0, 88,  ..., 19, 18, 44])


## Model architecture

We take a simple one-layer MLP with a learned embedding and unembedding.
The architecture is:

0. Tokens: the tokens $t_0, t_1$ are one-hot encoded, d_vocab-dimensional vectors, and the input sequence is $$t = (t_0, t_1)^T$$
1. Embedding: the tokens are embedded in the d_model-dimensional space by a learnable matrix W_E $$x_0 = \mathrm{Embed}(t) = t @ W_E$$
2. Concat: we concatenate the n_ctx (= 2) embedded vectors into a single (d_model * n_ctx)-dimensional vector $$x_1 = \mathrm{flat}(x_0)$$
3. MLP layer: a simple MLP with one hidden layer of dimension d_mlp = 2 * d_model * n_ctx (the value set in the config cell), with ReLU activation and an input bias:
$$x_2 = \mathrm{ReLU}(x_1 @ W_{in} + b_{in}) @ W_{out}$$
4. Unembedding: a learned unembedding matrix W_U maps back to the vocabulary:
$$x_3 = x_2 @ W_U$$
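
The steps above can be sketched in PyTorch roughly as follows. This is a minimal illustration, not the notebook's actual `SimpleMLP` (whose source is not shown here): the class name `SketchMLP` is hypothetical, and it assumes that W_out maps the hidden layer back to d_model dimensions, which the formulas above do not state explicitly. The embedding lookup is equivalent to multiplying the one-hot tokens by W_E.

```python
import torch as t
import torch.nn as nn


class SketchMLP(nn.Module):
    """Hypothetical stand-in for the architecture described above."""

    def __init__(self, d_vocab: int, d_model: int, n_ctx: int, d_mlp: int):
        super().__init__()
        # W_E: an index lookup, equivalent to one_hot(tokens) @ W_E
        self.embed = nn.Embedding(d_vocab, d_model)
        # W_in and b_in
        self.w_in = nn.Linear(d_model * n_ctx, d_mlp, bias=True)
        # W_out, without bias as in the formula for x_2
        # (assumption: the output dimension is d_model)
        self.w_out = nn.Linear(d_mlp, d_model, bias=False)
        # W_U
        self.unembed = nn.Linear(d_model, d_vocab, bias=False)

    def forward(self, tokens: t.Tensor) -> t.Tensor:
        x0 = self.embed(tokens)                  # (batch, n_ctx, d_model)
        x1 = x0.flatten(start_dim=1)             # (batch, n_ctx * d_model)
        x2 = self.w_out(t.relu(self.w_in(x1)))   # (batch, d_model)
        return self.unembed(x2)                  # (batch, d_vocab)


# Usage with p = 97 and d_mlp = 2 * d_model * n_ctx = 512
model = SketchMLP(d_vocab=97, d_model=128, n_ctx=2, d_mlp=512)
logits = model(t.tensor([[27, 44], [73, 24]]))
print(logits.shape)  # torch.Size([2, 97])
```

Each row of `logits` holds one score per residue class mod p, so the predicted sum is the argmax over the last dimension.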

In [4]:
# Config for the MLP architecture
d_vocab = p     # the inputs are the numbers from 0 to p-1
d_model = 128   # dimension of the model
n_ctx = 2       # context length: the two input tokens (a, b)
d_mlp = 2 * d_model * n_ctx   # hidden dimension of the MLP
act_type = "ReLU"   # or GeLU

In [None]:
from models.simple_models import SimpleMLP
