# Grokking

In [6]:
from typing import Dict, Optional, Any
import logging
import sys


import torch
from torch import nn, optim
from torch.nn import functional as F
from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt
import sys
sys.path.append("/home/paperspace/devinterp/")
from devinterp.zoo.arithmetic import ModularArithmeticConfig, ModularArithmetic
from devinterp.zoo.transformer import TransformerConfig, Transformer
from devinterp.slt.sampler import estimate_rlct

from devinfra.learner import LearnerConfig, Learner
from devinfra.utils.device import get_default_device
from devinfra.utils.tensors import Reduction
from devinfra.evals import CombineEvaluators, Evaluator, RepeatEvaluator
from devinfra.optim.schedulers import LRScheduler


# Device used for the model and for every evaluation batch in this notebook
# (the config printed later in this run reports device(type='cuda')).
device = get_default_device()
# Configure root logging at INFO level; basicConfig does nothing if the root
# logger already has handlers installed.
logging.basicConfig(level=logging.INFO)

In [7]:
# Data: modular division dataset, split into train and test sets.

MODULUS = 113

# NOTE(review): the config printed once the Learner is built reports
# num_training_samples=5107, which equals int(0.4 * 113 * 113) -- so `split`
# is presumably the training fraction over all operand pairs. Confirm against
# ModularArithmeticConfig.
dataset_config = ModularArithmeticConfig(
    modulus=MODULUS,
    operator="/",
    split=0.4,
    seed=0,
)
trainset, testset = dataset_config.factory_split()

In [8]:
# Evals

def cross_entropy_last_token(outputs, targets, reduction: Reduction = "sum"):
    """
    Cross-entropy loss evaluated at the final sequence position only.

    Args:
        outputs: Model logits, indexed as ``outputs[:, -1]``; the gather below
            needs that slice to be 2-D, i.e. ``outputs`` is expected to be
            (batch, position, class).
        targets: Integer class index per batch element, shape (batch,).
        reduction: ``"sum"`` (default) or ``"mean"`` over the batch.

    Returns:
        Scalar float32 tensor: the negated sum/mean of the target log-probabilities.

    Raises:
        ValueError: If ``reduction`` is neither ``"mean"`` nor ``"sum"``.
    """
    # Keep only the logits for the last position and promote to float32
    # before the softmax.
    final_logits = outputs[:, -1].to(torch.float32)
    log_probs = F.log_softmax(final_logits, dim=-1)

    # Log-probability assigned to the correct class of each example, shape (batch, 1).
    target_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(1))

    # The reduction is validated after the tensor work so that malformed
    # inputs surface their own errors first.
    if reduction == "mean":
        reduce = torch.mean
    elif reduction == "sum":
        reduce = torch.sum
    else:
        reduce = None

    if reduce is None:
        raise ValueError("Invalid reduction argument.")

    return -reduce(target_log_probs)

# Loaders used by the evaluators below (eval_loss_and_acc iterates both;
# eval_rlct passes trainloader to estimate_rlct). Their batch size of 1024 is
# set here independently of the Learner's batch_size configured further down.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1024, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=1024, shuffle=False)


def eval_loss_and_acc(model: nn.Module, *_) -> Dict[str, float]:
    """
    Compute last-token loss and accuracy over the full train and test loaders.

    The model is switched to eval mode and left in that mode on return. Any
    extra positional arguments are accepted and ignored.

    Returns:
        Dict with keys ``train/loss``, ``train/accuracy``, ``test/loss`` and
        ``test/accuracy``; each value is a per-example average over the
        corresponding loader's dataset.
    """
    model.eval()

    metrics = {}
    loaders = {"train": trainloader, "test": testloader}

    for split, loader in loaders.items():
        loss_sum = 0
        num_correct = 0

        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                logits = model(inputs)

            # Summed (not averaged) batch loss, so dividing by the dataset
            # size below gives the exact per-example mean.
            batch_loss = cross_entropy_last_token(logits, labels, reduction="sum")
            loss_sum += batch_loss.item()

            # argmax doesn't work for device=mps, hence max(...).indices
            predictions = logits[:, -1, :].max(dim=1).indices
            num_correct += (predictions == labels).sum().item()

        num_examples = len(loader.dataset)
        metrics[f"{split}/loss"] = loss_sum / num_examples
        metrics[f"{split}/accuracy"] = num_correct / num_examples

    return metrics


def eval_rlct(model: nn.Module, *_):
    """
    Estimate the RLCT of ``model`` with devinterp's ``estimate_rlct``.

    Sampling uses the module-level ``trainloader`` with
    ``cross_entropy_last_token`` as the criterion and the ``'sgld'`` sampler.
    Any extra positional arguments are accepted and ignored.

    Returns:
        Dict with the single key ``"rlct"`` holding whatever ``estimate_rlct`` returns.
    """
    # Keyword arguments forwarded to the SGLD optimizer; num_samples is tied
    # to the size of the training set.
    sgld_kwargs = {
        "lr": 1e-7,
        "noise_level": 1.,
        "weight_decay": 3e-7,
        "elasticity": 10.,
        "temperature": "adaptive",
        "num_samples": len(trainset),
    }

    rlct_estimate = estimate_rlct(
        model,
        trainloader,
        cross_entropy_last_token,
        'sgld',
        sgld_kwargs,
        num_draws=20,
        num_chains=5,
        num_burnin_steps=0,
        num_steps_bw_draws=1,
        cores=1,
        pbar=False,
    )

    return {"rlct": rlct_estimate}


# Evaluator handed to the Learner: loss/accuracy on both splits plus the RLCT
# estimate.
# NOTE(review): RepeatEvaluator(eval_rlct, 5) presumably invokes eval_rlct five
# times per evaluation -- confirm in devinfra.evals.
evals = CombineEvaluators([
    eval_loss_and_acc,
    RepeatEvaluator(eval_rlct, 5),
])


In [9]:

# Model: vocabulary has one entry more than the modulus.
# NOTE(review): the extra token is presumably a separator/operator symbol -- confirm.
model_config = TransformerConfig(d_vocab=MODULUS + 1)
model = model_config.factory().to(device)

optimizer_settings = {
    "optimizer_type": "AdamW",
    "lr": 1e-3,
    "weight_decay": 0.2,
    "betas": (0.9, 0.98),
}

# NOTE(review): both step ranges below end at or before step 100 while
# num_steps is 25_000 (the printed config reports "(0...100) 5 steps"), so
# logging and checkpointing appear to stop after step 100 -- confirm intended.
# NOTE(review): the printed MetricLoggingConfig lists `stdout=False` and has no
# `use_std` field, so "use_std" may be silently ignored -- verify the key name.
logging_settings = {
    # "project": "grokking",
    # "entity": "devinterp",
    "logging_steps": {
        "log_space": (1, 25, 1),
        "linear_space": (0, 100, 4),
    },
    "use_std": True,
}

checkpoint_settings = {
    "checkpoint_steps": {
        "log_space": (1, 25, 1),
        "linear_space": (0, 100, 4),
    },
    # "bucket": "devinterp",
    "project_dir": "div-mod-113",
    "local_root": "../",
}

# NOTE(review): in the recorded run, training with criterion="cross_entropy"
# fails with "RuntimeError: Expected target size [256, 114], got [256]". The
# targets are one label per sequence, whereas the evaluators above score only
# the last position via cross_entropy_last_token -- the training criterion
# likely needs the same last-token treatment. Confirm what LearnerConfig accepts.
learner_config = LearnerConfig(
    num_training_samples=len(trainset),
    batch_size=256,
    num_steps=25_000,
    criterion="cross_entropy",
    device=device,
    optimizer_config=optimizer_settings,
    logger_config=logging_settings,
    checkpointer_config=checkpoint_settings,
)

learner = Learner(
    model=model,
    dataset=trainset,
    evaluator=evals,
    config=learner_config,
)


num_training_samples=5107 batch_size=256 run_name=None num_steps=25000 logger_config=MetricLoggingConfig(project=None, entity=None, logging_steps=(0...100) 5 steps, out_file=None, use_df=False, stdout=False, run_id=None) checkpointer_config=CheckpointerConfig(bucket_name=None, project_dir=div-mod-113, local_root=../) optimizer_config=OptimizerConfig(optimizer_type='AdamW', lr=0.001, weight_decay=0.2, momentum=None, betas=(0.9, 0.98), noise_level=None, elasticity=None, temperature=None, num_samples=None) scheduler_config=None device=device(type='cuda') criterion='cross_entropy'


In [10]:
# learner.save_checkpoint(0)
# Run the training loop.
# NOTE(review): in the recorded output this call aborts at step 0 with
# "RuntimeError: Expected target size [256, 114], got [256]" -- the loss is
# being given logits for every position but a single label per sequence.
learner.train()

Training...:   0%|          | 0/25000 [00:00<?, ?it/s]

RuntimeError: Expected target size [256, 114], got [256]