# Grokking

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timaeus-research/devinterp/blob/main/examples/grokking.ipynb)

In [1]:
%pip install devinterp seaborn 

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


from devinterp.slt.sampler import estimate_learning_coeff_with_summary
from devinterp.utils import make_evaluate

from devinterp.optim.sgld import SGLD


# Global seaborn look used by every plot in this notebook.
sns.set_palette("deep")
sns.set_style("whitegrid")

# Directory that train() writes model checkpoints into.
CHECKPOINTS_PATH = Path("./checkpoints/grokking")
# Create the directory (and any missing parents) in one race-free call
# instead of a separate "exists?" check followed by makedirs.
CHECKPOINTS_PATH.mkdir(parents=True, exist_ok=True)

# Named colours: first three entries of the "deep" palette and the matching
# first three entries of the "muted" palette.
PRIMARY, SECONDARY, TERTIARY = sns.color_palette("deep")[:3]
PRIMARY_LIGHT, SECONDARY_LIGHT, TERTIARY_LIGHT = sns.color_palette("muted")[:3]

import os

# The environment variable is set before torch_xla is imported below; keep
# this ordering. NOTE(review): presumably torch_xla reads it at import time
# -- confirm against the torch_xla version in use.
os.environ["USE_TPU_BACKEND"] = "1"
import torch_xla.core.xla_model as xm

# Device handle returned by torch_xla.
# NOTE(review): the training code visible in this notebook uses
# ExperimentParams.device ("cuda"/"cpu") rather than DEVICE -- confirm
# whether running on the XLA device is actually intended.
DEVICE = xm.xla_device()

  from .autonotebook import tqdm as notebook_tqdm


In [90]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import random
import hashlib
from tqdm import tqdm
from dataclasses import dataclass


@dataclass
class ExperimentParams:
    """Hyperparameters shared by the dataset builders, the MLP and train()."""

    # Number of distinct input tokens; also the number of output classes.
    p: int = 101
    # Total number of gradient steps performed by train().
    n_batches: int = 1500
    # How many checkpoints train() saves over the run (<= 0 disables saving).
    n_save_model_checkpoints: int = 25
    # How many times train() evaluates on the test set and logs losses.
    print_times: int = 25
    # Learning rate passed to Adam.
    lr: float = 0.005
    # Batch size of the training DataLoader.
    batch_size: int = 16
    # Output width of the two first-layer linear maps in MLP.
    hidden_size: int = 16
    # Embedding dimension of MLP.
    embed_dim: int = 16
    # Fraction of the full dataset used for training by train_test_split().
    train_frac: float = 0.2
    # Seed used for torch.manual_seed and for the train/test split.
    random_seed: int = 0
    # Evaluated once, when the class body runs: CUDA if available, else CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Weight decay passed to Adam.
    weight_decay: float = 0.0002


class MLP(nn.Module):
    """Two-operand MLP producing ``p`` logits.

    Both operands share one embedding table. Each embedded operand goes
    through its own linear map (``linear1l`` for the first operand,
    ``linear1r`` for the second), the two results are summed, passed through
    GELU and projected to ``p`` logits by a bias-free linear layer.

    Args:
        params: Object exposing ``p``, ``embed_dim`` and ``hidden_size``.
    """

    def __init__(self, params):
        super().__init__()
        vocab, d_embed, d_hidden = params.p, params.embed_dim, params.hidden_size
        # Submodules are created in this exact order so that seeded parameter
        # initialisation and state_dict key order stay the same.
        self.embedding = nn.Embedding(vocab, d_embed)
        self.linear1r = nn.Linear(d_embed, d_hidden, bias=True)
        self.linear1l = nn.Linear(d_embed, d_hidden, bias=True)
        self.linear2 = nn.Linear(d_hidden, vocab, bias=False)
        self.act = nn.GELU()
        self.vocab_size = vocab

    def forward(self, a, b):
        """Return logits for the operand pair(s) ``(a, b)`` of token indices."""
        left = self.linear1l(self.embedding(a))
        right = self.linear1r(self.embedding(b))
        return self.linear2(self.act(left + right))


def test(model, dataset, device):
    n_correct = 0
    total_loss = 0
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    with torch.no_grad():
        for (x1, x2), y in dataset:
            x1, x2, y = x1.to(device), x2.to(device), y.to(device)
            out = model(x1, x2)
            loss = loss_fn(out, y)
            total_loss += loss.item()
            pred = torch.argmax(out)
            if pred == y:
                n_correct += 1
    return n_correct / len(dataset), total_loss / len(dataset)


def train(train_dataset, test_dataset, params):
    """Train a fresh ``MLP`` and return the logged loss curves.

    Runs ``params.n_batches`` Adam steps. Roughly ``params.print_times`` times
    during the run the model is evaluated on ``test_dataset`` and a row with
    the windowed mean training loss, validation loss and validation accuracy
    is recorded. Roughly ``params.n_save_model_checkpoints`` times the model's
    ``state_dict`` is saved under ``CHECKPOINTS_PATH``.

    Args:
        train_dataset: Dataset of ``((x1, x2), y)`` samples used for training.
        test_dataset: Dataset of ``((x1, x2), y)`` samples used for evaluation.
        params: ``ExperimentParams``-like object with the hyperparameters.

    Returns:
        ``pandas.DataFrame`` with columns ``batch``, ``train_loss``,
        ``val_loss`` and ``val_acc`` (empty if ``params.print_times <= 0``).
        The trained model itself is not returned; use the saved checkpoints.
    """
    model = MLP(params).to(params.device)
    optimizer = torch.optim.Adam(
        model.parameters(), weight_decay=params.weight_decay, lr=params.lr
    )
    loss_fn = torch.nn.CrossEntropyLoss()

    train_loader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True)

    # Clamp both intervals to at least 1: asking for more log points or
    # checkpoints than there are batches would otherwise give an interval of
    # 0 (modulo-by-zero for logging, silently no checkpoints for saving).
    print_every = None
    if params.print_times > 0:
        print_every = max(1, params.n_batches // params.print_times)
    checkpoint_every = None
    if params.n_save_model_checkpoints > 0:
        checkpoint_every = max(1, params.n_batches // params.n_save_model_checkpoints)

    avg_loss = 0
    loss_data = []
    # The context manager closes the bar even if training is interrupted.
    with tqdm(total=params.n_batches, desc="Training") as pbar:
        for i in range(params.n_batches):
            # Sample random batch of data: a fresh iterator over the shuffling
            # loader is created every step, so each step sees the first batch
            # of a new random permutation rather than walking through epochs.
            batch = next(iter(train_loader))
            (X_1, X_2), Y = batch
            X_1, X_2, Y = (
                X_1.to(params.device),
                X_2.to(params.device),
                Y.to(params.device),
            )

            # Gradient update
            optimizer.zero_grad()
            out = model(X_1, X_2)
            loss = loss_fn(out, Y)
            avg_loss += loss.item()
            loss.backward()
            optimizer.step()

            if checkpoint_every and (i + 1) % checkpoint_every == 0:
                torch.save(
                    model.state_dict(),
                    f"{CHECKPOINTS_PATH}/checkpoint_batch_{i + 1}.pt",
                )

            if print_every and (i + 1) % print_every == 0:
                avg_loss /= print_every
                val_acc, val_loss = test(model, test_dataset, params.device)
                # Evaluation may leave the model in eval mode; switch back
                # before the next gradient step.
                model.train()
                loss_data.append(
                    {
                        "batch": i + 1,
                        "train_loss": avg_loss,
                        "val_loss": val_loss,
                        "val_acc": val_acc,
                    }
                )
                pbar.set_postfix(
                    {
                        "batch": f"{i + 1}/{params.n_batches}",
                        "train_loss": f"{avg_loss:.4f}",
                        "val_loss": f"{val_loss:.4f}",
                        "val_acc": f"{val_acc:.4f}",
                    }
                )
                avg_loss = 0

            # Advance once per step so the bar reaches 100% even when
            # n_batches is not a multiple of the logging interval.
            pbar.update(1)

    df = pd.DataFrame(loss_data)
    val_acc, val_loss = test(model, test_dataset, params.device)
    print(f"Final Val Acc: {val_acc:.4f} | Final Val Loss: {val_loss:.4f}")
    return df


def deterministic_shuffle(lst, seed):
    """Shuffle ``lst`` in place using ``seed`` and return the same list.

    A private ``random.Random(seed)`` generator is used, so the resulting
    permutation depends only on ``seed`` and the process-wide ``random``
    state is left untouched (seeding the global generator here would silently
    reset randomness for every other user of the ``random`` module).

    Args:
        lst: Mutable sequence to shuffle; it is modified in place.
        seed: Seed for the private generator.

    Returns:
        ``lst`` itself, now shuffled.
    """
    random.Random(seed).shuffle(lst)
    return lst


def get_all_pairs(p):
    """Return the set of all ordered pairs ``(i, j)`` with ``0 <= i, j < p``."""
    # Pairs are inserted in row-major order (i outer, j inner).
    return {(i, j) for i in range(p) for j in range(p)}


# def make_random_dataset(p, seed):
#     data = []
#     pairs = get_all_pairs(p)
#     for a, b in pairs:
#         out = 2 * a * p + b
#         out = hash_with_seed(out, seed) % p
#         data.append(((torch.tensor(a), torch.tensor(b)), torch.tensor(out)))
#     return data


def make_dataset(p):
    """Build the full dataset over all operand pairs from ``get_all_pairs(p)``.

    Each item is ``((tensor(a), tensor(b)), tensor(label))`` with
    ``label = (2 * a * p + b) % p``. Items follow the iteration order of the
    set returned by ``get_all_pairs``.

    NOTE(review): ``2 * a * p`` is a multiple of ``p``, so for ``0 <= b < p``
    the label is simply ``b`` -- confirm this is the intended task.
    """
    return [
        ((torch.tensor(a), torch.tensor(b)), torch.tensor((2 * a * p + b) % p))
        for a, b in get_all_pairs(p)
    ]


def hash_with_seed(value, seed):
    """Return a deterministic integer hash of ``value`` salted with ``seed``.

    The SHA-256 digest of ``str(seed)`` followed by ``str(value)`` (UTF-8
    encoded) is interpreted as one big-endian unsigned integer.
    """
    payload = (str(seed) + str(value)).encode("utf-8")
    return int.from_bytes(hashlib.sha256(payload).digest(), "big")


def train_test_split(dataset, train_split_proportion, seed):
    """Split ``dataset`` into ``(train, test)`` lists using a seeded shuffle.

    The indices ``0 .. len(dataset) - 1`` are shuffled with
    ``deterministic_shuffle``. The first
    ``int(train_split_proportion * len(dataset))`` shuffled indices select the
    training items and the remaining indices select the test items.
    """
    n_total = len(dataset)
    n_train = int(train_split_proportion * n_total)
    order = deterministic_shuffle(list(range(n_total)), seed)
    train_part = [dataset[k] for k in order[:n_train]]
    test_part = [dataset[k] for k in order[n_train:]]
    return train_part, test_part

In [91]:
# Build the experiment configuration and seed torch so that model
# initialisation and batch sampling are reproducible.
params = ExperimentParams()
torch.manual_seed(params.random_seed)

# Full dataset over all operand pairs, split into train/test with a seeded shuffle.
dataset = make_dataset(params.p)
train_data, test_data = train_test_split(dataset, params.train_frac, params.random_seed)

df = train(train_dataset=train_data, test_dataset=test_data, params=params)

# Plot against the logged batch number (not the DataFrame row index) and
# label the curves so validation and training loss can be told apart.
plt.plot(df["batch"], df["val_loss"], label="val_loss")
plt.plot(df["batch"], df["train_loss"], label="train_loss")
plt.xlabel("batch")
plt.ylabel("loss")
plt.legend()

Training:  24%|██▍       | 360/1500 [00:08<00:27, 41.61it/s, batch=360/1500, train_loss=0.0518, val_loss=0.0431, val_acc=1.0000]

KeyboardInterrupt: 

Training:  24%|██▍       | 360/1500 [00:27<00:27, 41.61it/s, batch=360/1500, train_loss=0.0518, val_loss=0.0431, val_acc=1.0000]

## RLCT estimation hyperparameter tuning

In [None]:
# torch.save(grid_search, "../data/grokking-rlct-sweep.pt")