# Grokking
Can we observe grokking on modular addition in a toy example?

## setup

In [185]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random

In [186]:
# hyperparameters
split = 0.3            # fraction of all (a, b) pairs used for training
n_vocab = 53           # modulus of the addition task; also the number of tokens/classes
# alternative modulus: n_vocab = 113
n_embed = 200          # width of the token embedding
n_hidden = 32          # width of the hidden MLP layers
epoch = 1000           # number of full-batch optimisation steps
learning_rate = 1e-3   # AdamW step size
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [187]:
# Build the dataset: every ordered pair (a, b) with 0 <= a, b < n_vocab,
# shuffled with a fixed seed so the train/test split is reproducible.
random.seed(0xdeadbeef)
# divmod(k, n_vocab) == (k // n_vocab, k % n_vocab), which enumerates the
# pairs in row-major order: (0, 0), (0, 1), ..., (n_vocab - 1, n_vocab - 1).
X = [divmod(k, n_vocab) for k in range(n_vocab * n_vocab)]
random.shuffle(X)

# The first `split` fraction of the shuffled pairs trains the model; the
# remainder is held out.
_cut = int(len(X) * split)
X_train, X_test = X[:_cut], X[_cut:]

# Labels are the modular sum of each pair.
Y_train = [sum(pair) % n_vocab for pair in X_train]
Y_test = [sum(pair) % n_vocab for pair in X_test]

# To peek at a few examples: list(zip(X_train, Y_train))[:10]

In [188]:
def get_batch(name='train'):
    """Return the full ``(inputs, targets)`` tensors for one data split.

    Args:
        name: ``'train'`` or ``'test'``.

    Returns:
        A tuple ``(X, Y)`` of tensors on ``device``, built from the
        corresponding module-level ``X_*`` / ``Y_*`` lists.

    Raises:
        KeyError: if ``name`` is not a known split.
    """
    splits = {
        'train': (X_train, Y_train),
        'test': (X_test, Y_test),
    }
    # Look the split up first so only the requested data is converted and
    # moved to the device; the other split is never touched.
    inputs, targets = splits[name]
    return (
        torch.tensor(inputs, device=device),
        torch.tensor(targets, device=device),
    )

# get_batch('test')

## model

In [189]:
class NN(nn.Module):
    """Small MLP mapping a pair of token ids to logits over the vocabulary.

    Both operands share a single embedding table. Their embeddings are summed
    and fed through a three-layer ReLU MLP that emits one logit per class.

    Args:
        vocab_size: number of distinct tokens and of output classes.
            Defaults to the module-level ``n_vocab``.
        embed_dim: embedding width. Defaults to the module-level ``n_embed``.
        hidden_dim: width of the two hidden layers. Defaults to the
            module-level ``n_hidden``.
    """

    def __init__(self, vocab_size=None, embed_dim=None, hidden_dim=None):
        super().__init__()

        # Resolve defaults at call time so a bare ``NN()`` keeps using the
        # notebook's hyperparameters.
        vocab_size = n_vocab if vocab_size is None else vocab_size
        embed_dim = n_embed if embed_dim is None else embed_dim
        hidden_dim = n_hidden if hidden_dim is None else hidden_dim

        self.embed = nn.Embedding(vocab_size, embed_dim)
        # nn.ReLU takes no size argument: its only parameter is the boolean
        # ``inplace`` flag, so passing a width there would silently turn on
        # in-place activation.
        self.layers = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, vocab_size))

    def forward(self, x):
        """Compute class logits for a batch of operand pairs.

        Args:
            x: integer tensor whose columns 0 and 1 hold the two operands,
                i.e. shape ``(batch, 2)``.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` with unnormalised logits.
        """
        a, b = x[:, 0], x[:, 1]
        # Both operands go through the same embedding table.
        a_embed = self.embed(a)
        b_embed = self.embed(b)
        # Merge by summation, which makes the model symmetric in (a, b).
        # Concatenating with torch.cat((a_embed, b_embed), dim=1) is the
        # alternative, but the first Linear layer would then need
        # 2 * embed_dim input features.
        embd = a_embed + b_embed
        return self.layers(embd)
    
# Instantiate the network, then move its parameters onto the configured device.
model = NN()
model = model.to(device)
# model(torch.tensor([[2, 3], [3, 4]]))

In [190]:
# 2 options for composition of the 2 inputs
# a = F.one_hot(torch.tensor(1), 3)
# b = F.one_hot(torch.tensor(2), 3)

# torch.cat((a, b))
# a + b

## train

In [191]:
# Optimiser for the whole model: AdamW with a weight-decay coefficient of 1.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=1,
)

# Number of full-batch steps for the first training run.
epoch = 1000

def evaluate(name):
    """Return the mean cross-entropy loss of ``model`` on one full data split.

    Args:
        name: split identifier passed through to ``get_batch``
            (``'train'`` or ``'test'``).

    Returns:
        A scalar loss tensor. It carries autograd history unless the caller
        runs this under ``torch.no_grad()``.
    """
    inputs, targets = get_batch(name)
    logits = model(inputs)
    # The targets are integer class indices, which cross_entropy accepts
    # directly. This gives the same mean loss as a one-hot float target
    # without allocating a (batch, n_vocab) tensor on every call.
    return F.cross_entropy(logits, targets)

def evaluate_test():
    """Return the loss on the held-out split, computed without autograd tracking."""
    with torch.no_grad():
        return evaluate('test')

def evaluate_train():
    """Return the loss on the training split, with gradients enabled for backprop."""
    return evaluate(name='train')

# First training run: full-batch gradient descent with no progress logging.
for i in range(epoch):
    # Clear the gradients left by the previous step, then do one
    # forward/backward pass over the whole training split.
    optimizer.zero_grad()
    loss = evaluate_train()
    loss.backward()
    optimizer.step()

# Report the final training loss, then the held-out loss.
for report in (evaluate_train, evaluate_test):
    print(report())


tensor(0.0596, device='cuda:0', grad_fn=<DivBackward1>)
tensor(12.4371, device='cuda:0')


In [192]:
# Second, much longer run that continues training the same model and optimizer,
# printing train and held-out loss every 50 steps.
epoch = 50000

for i in range(epoch):
    is_log_step = i % 50 == 49

    loss = evaluate_train()
    # Only evaluate the held-out split on steps that are actually printed.
    # It is still measured before the optimizer step, so both printed losses
    # refer to the same parameters.
    if is_log_step:
        test_loss = evaluate_test()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if is_log_step:
        print(f'{i=}, {loss=}, {test_loss=}')


i=49, loss=tensor(0.0569, device='cuda:0', grad_fn=<DivBackward1>), test_loss=tensor(12.4627, device='cuda:0')
i=99, loss=tensor(0.0543, device='cuda:0', grad_fn=<DivBackward1>), test_loss=tensor(12.4860, device='cuda:0')
i=149, loss=tensor(0.0519, device='cuda:0', grad_fn=<DivBackward1>), test_loss=tensor(12.4999, device='cuda:0')
i=199, loss=tensor(0.0500, device='cuda:0', grad_fn=<DivBackward1>), test_loss=tensor(12.5133, device='cuda:0')
i=249, loss=tensor(0.0481, device='cuda:0', grad_fn=<DivBackward1>), test_loss=tensor(12.5265, device='cuda:0')
i=299, loss=tensor(0.0462, device='cuda:0', grad_fn=<DivBackward1>), test_loss=tensor(12.5464, device='cuda:0')
i=349, loss=tensor(0.0446, device='cuda:0', grad_fn=<DivBackward1>), test_loss=tensor(12.5618, device='cuda:0')
i=399, loss=tensor(0.0431, device='cuda:0', grad_fn=<DivBackward1>), test_loss=tensor(12.5789, device='cuda:0')
i=449, loss=tensor(0.0415, device='cuda:0', grad_fn=<DivBackward1>), test_loss=tensor(12.5993, device='cud