## Исследование эффекта гроккинга

За основу взята статья Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets:
https://arxiv.org/pdf/2201.02177.pdf

## 1. Вступление

В данной статье описывается т. н. эффект гроккинга: нейросеть резко переходит от качества случайного угадывания к идеальному качеству на тестовой выборке, причём происходит это спустя много итераций после точки переобучения.

Авторы работы наблюдают этот эффект на данных вида aob=c, где "a", "b", "c" — числа, а "o" — некоторая операция. Составляется таблица, строки и столбцы которой соответствуют всевозможным значениям "a" и "b", а в ячейках хранятся соответствующие им значения "c". Далее часть ячеек случайным образом стирается (то есть выборка разбивается на train и test — стёртые ячейки). Задача состоит в том, чтобы заполнить пустые ячейки в соответствии с описанной выше операцией.

В этой научной работе авторы наблюдали этот эффект на многих операциях, но мы остановимся на нескольких из них. Тип нейросети — трансформер; в качестве оптимизатора будем использовать AdamW, поскольку данный эффект наиболее отчётливо наблюдается именно с ним. (В приведённом ниже эксперименте для сравнения используется SGD с weight decay.)

## 2. Программная реализация

### Библиотеки:

In [None]:
from torch import nn
import torch
import numpy as np
from torch.nn.functional import cross_entropy
from torch.optim import AdamW, Adam, SGD
from torch.optim.lr_scheduler import LambdaLR
# from net import Grokformer  # net - файл с реализацией трансформера
from tqdm.notebook import tqdm
import math
import matplotlib.pyplot as plt

In [None]:
import tqdm

In [None]:
%matplotlib inline

#для четкой прорисовки графиков
%config InlineBackend.figure_format = 'svg'

In [None]:
torch.cuda.is_available()


True

In [None]:
import os
import random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
from tqdm.auto import tqdm

import matplotlib.pyplot as plt

In [None]:
!pip install einops

In [None]:
# This code was taken directly from Neel Nanda's study of grokking:
# https://colab.research.google.com/drive/1F6_1_cWXE5M7WocUcpQWp3v8z4b1jL20

class HookPoint(nn.Module):
    """Identity module that exposes the tensor flowing through it to hooks.

    Placing a HookPoint inside a model lets callers record (or edit) the
    activation at that point without changing the forward computation.
    """

    def __init__(self):
        super().__init__()
        # Handles returned by register_*_hook, kept so the hooks can be removed.
        self.fwd_hooks = []
        self.bwd_hooks = []

    def give_name(self, name):
        """Store this point's module path; called by the model at initialisation."""
        self.name = name

    def add_hook(self, hook, dir='fwd'):
        """Register ``hook(tensor, name=...)`` on the forward or backward pass.

        Args:
            hook: callable receiving the activation (forward) or the tuple of
                gradients w.r.t. this module's output (backward), plus the
                keyword ``name``. Its return value is handed back to PyTorch
                unchanged (a non-None forward-hook return replaces the output).
            dir: ``'fwd'`` or ``'bwd'``.

        Raises:
            ValueError: if ``dir`` is neither ``'fwd'`` nor ``'bwd'``.
        """
        # Adapt to PyTorch's (module, input, output) hook signature; for an
        # identity module input and output carry the same tensor.
        def full_hook(module, module_input, module_output):
            return hook(module_output, name=self.name)

        if dir == 'fwd':
            handle = self.register_forward_hook(full_hook)
            self.fwd_hooks.append(handle)
        elif dir == 'bwd':
            # register_backward_hook is deprecated in PyTorch;
            # register_full_backward_hook is the documented replacement and
            # reports gradients w.r.t. this module's own output.
            handle = self.register_full_backward_hook(full_hook)
            self.bwd_hooks.append(handle)
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(self, dir='fwd'):
        """Remove hooks registered in direction ``'fwd'``, ``'bwd'`` or ``'both'``.

        Raises:
            ValueError: for any other ``dir`` (nothing is removed in that case).
        """
        if dir not in ('fwd', 'bwd', 'both'):
            raise ValueError(f"Invalid direction {dir}")
        if dir in ('fwd', 'both'):
            for handle in self.fwd_hooks:
                handle.remove()
            self.fwd_hooks = []
        if dir in ('bwd', 'both'):
            for handle in self.bwd_hooks:
                handle.remove()
            self.bwd_hooks = []

    def forward(self, x):
        """Return ``x`` unchanged; all the work happens in the hooks."""
        return x


In [None]:
# Embed & Unembed
class Embed(nn.Module):
    """Token embedding: maps integer token ids to d_model-dimensional vectors."""

    def __init__(self, d_vocab, d_model):
        super().__init__()
        # One column per vocabulary entry, shape (d_model, d_vocab).
        scale = np.sqrt(d_model)
        self.W_E = nn.Parameter(torch.randn(d_model, d_vocab) / scale)

    def forward(self, x):
        # Column gathering yields (d_model, batch, pos); move d_model last.
        gathered = self.W_E[:, x]
        return gathered.permute(1, 2, 0)

class Unembed(nn.Module):
    """Projects d_model-dimensional activations back to vocabulary logits."""

    def __init__(self, d_vocab, d_model):
        super().__init__()
        # Shape (d_model, d_vocab); the init is scaled by sqrt(d_vocab).
        init = torch.randn(d_model, d_vocab) / np.sqrt(d_vocab)
        self.W_U = nn.Parameter(init)

    def forward(self, x):
        return torch.matmul(x, self.W_U)

# Positional Embeddings
class PosEmbed(nn.Module):
    """Learned absolute positional embedding, added to the input."""

    def __init__(self, max_ctx, d_model):
        super().__init__()
        # One row per position, shape (max_ctx, d_model).
        init = torch.randn(max_ctx, d_model) / np.sqrt(d_model)
        self.W_pos = nn.Parameter(init)

    def forward(self, x):
        # Only the first seq_len rows are used, so shorter inputs work too.
        seq_len = x.shape[-2]
        positions = self.W_pos[:seq_len]
        return x + positions

# LayerNorm
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis that the owning model can switch off.

    Args:
        d_model: size of the normalised (last) axis.
        epsilon: added to the standard deviation to avoid division by zero.
        model: one-element list holding the owning model; its ``use_ln``
            attribute is read on every forward pass. The list wrapper keeps
            nn.Module from registering the owner as a child module. When
            omitted (or when it holds ``None``), normalisation is always applied.
    """

    def __init__(self, d_model, epsilon = 1e-4, model=None):
        super().__init__()
        # A fresh list per instance. The former ``[None]`` default was one
        # shared mutable object and made forward() fail on ``None.use_ln``.
        self.model = [None] if model is None else model
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon

    def _enabled(self):
        """Whether to normalise: the owner's ``use_ln``, or True without an owner."""
        owner = self.model[0]
        return True if owner is None else owner.use_ln

    def forward(self, x):
        if not self._enabled():
            return x
        x = x - x.mean(axis=-1)[..., None]
        # Tensor.std uses the unbiased (N-1) estimator; epsilon is added to
        # the std, not to the variance.
        x = x / (x.std(axis=-1)[..., None] + self.epsilon)
        x = x * self.w_ln
        x = x + self.b_ln
        return x

# Attention
class Attention(nn.Module):
    """Multi-head causal self-attention with a HookPoint on each intermediate.

    Einsum subscripts: ``b`` batch, ``p``/``q`` key/query position, ``i`` head,
    ``h`` head dimension, ``d`` model dimension, ``f`` concatenated heads.
    """

    def __init__(self, d_model, num_heads, d_head, n_ctx, model):
        super().__init__()
        self.model = model
        qkv_shape = (num_heads, d_head, d_model)
        init_scale = np.sqrt(d_model)
        # Creation order (K, Q, V, O) is kept so seeded runs draw identical weights.
        self.W_K = nn.Parameter(torch.randn(qkv_shape) / init_scale)
        self.W_Q = nn.Parameter(torch.randn(qkv_shape) / init_scale)
        self.W_V = nn.Parameter(torch.randn(qkv_shape) / init_scale)
        self.W_O = nn.Parameter(torch.randn(d_model, d_head * num_heads) / init_scale)
        # Lower-triangular ones: query position q may attend to keys p <= q.
        self.register_buffer('mask', torch.tril(torch.ones((n_ctx, n_ctx))))
        self.d_head = d_head
        self.hook_k = HookPoint()
        self.hook_q = HookPoint()
        self.hook_v = HookPoint()
        self.hook_z = HookPoint()
        self.hook_attn = HookPoint()
        self.hook_attn_pre = HookPoint()

    def forward(self, x):
        seq_len = x.shape[-2]
        keys = self.hook_k(torch.einsum('ihd,bpd->biph', self.W_K, x))
        queries = self.hook_q(torch.einsum('ihd,bpd->biph', self.W_Q, x))
        values = self.hook_v(torch.einsum('ihd,bpd->biph', self.W_V, x))
        # scores[b, i, q, p]: dot product of query q with key p in head i.
        scores = torch.einsum('biph,biqh->biqp', keys, queries)
        # Zero the future (p > q) entries, then shift them to -1e10 so the
        # softmax gives them (numerically) zero weight.
        future = 1 - self.mask[:seq_len, :seq_len]
        scores = torch.tril(scores) - 1e10 * future
        scaled = self.hook_attn_pre(scores / np.sqrt(self.d_head))
        pattern = self.hook_attn(F.softmax(scaled, dim=-1))
        z = self.hook_z(torch.einsum('biph,biqp->biqh', values, pattern))
        # Concatenate heads: (b, i, q, h) -> (b, q, i*h).
        merged = einops.rearrange(z, 'b i q h -> b q (i h)')
        return torch.einsum('df,bqf->bqd', self.W_O, merged)

# MLP Layers
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: W_in -> LayerNorm -> act -> W_out.

    Args:
        d_model: width of the residual stream (input and output size).
        d_mlp: hidden width.
        act_type: 'ReLU' or 'GeLU'.
        model: one-element list holding the owning Transformer; passed on to
            the hidden-layer LayerNorm, which reads ``use_ln`` from it.
    """

    def __init__(self, d_model, d_mlp, act_type, model):
        super().__init__()
        self.model = model
        self.W_in = nn.Parameter(torch.randn(d_mlp, d_model)/np.sqrt(d_model))
        self.b_in = nn.Parameter(torch.zeros(d_mlp))
        # NOTE: W_out is scaled by sqrt(d_model), as in the original source.
        self.W_out = nn.Parameter(torch.randn(d_model, d_mlp)/np.sqrt(d_model))
        self.b_out = nn.Parameter(torch.zeros(d_model))
        self.act_type = act_type
        self.ln = LayerNorm(d_mlp, model=self.model)
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()
        assert act_type in ['ReLU', 'GeLU']

    def forward(self, x):
        # The bias is added inside hook_pre, so the hook sees the full
        # pre-activation that is normalised and activated below. The forward
        # result is the same either way, since HookPoint is the identity.
        x = self.hook_pre(torch.einsum('md,bpd->bpm', self.W_in, x) + self.b_in)
        x = self.ln(x)
        if self.act_type == 'ReLU':
            x = F.relu(x)
        elif self.act_type == 'GeLU':
            x = F.gelu(x)
        x = self.hook_post(x)
        return torch.einsum('dm,bpm->bpd', self.W_out, x) + self.b_out

# Transformer Block
class TransformerBlock(nn.Module):
    """Pre-LayerNorm residual block: ``x + Attn(LN1(x))``, then ``x + MLP(LN2(x))``.

    ``model`` is the one-element list holding the owning Transformer; it is
    forwarded to the LayerNorms, the attention layer and the MLP.
    """

    def __init__(self, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model):
        super().__init__()
        self.model = model
        self.ln1 = LayerNorm(d_model, model=self.model)
        self.attn = Attention(d_model, num_heads, d_head, n_ctx, model=self.model)
        self.ln2 = LayerNorm(d_model, model=self.model)
        self.mlp = MLP(d_model, d_mlp, act_type, model=self.model)
        self.hook_attn_out = HookPoint()
        self.hook_mlp_out = HookPoint()
        self.hook_resid_pre = HookPoint()
        self.hook_resid_mid = HookPoint()
        self.hook_resid_post = HookPoint()

    def forward(self, x):
        # hook_resid_pre records the residual stream entering the block, like
        # hook_resid_mid/post do, rather than its LayerNorm-ed copy. Forward
        # output is unchanged because HookPoint is the identity.
        x = self.hook_resid_pre(x)
        attn_out = self.hook_attn_out(self.attn(self.ln1(x)))
        x = self.hook_resid_mid(x + attn_out)
        mlp_out = self.hook_mlp_out(self.mlp(self.ln2(x)))
        x = self.hook_resid_post(x + mlp_out)
        return x

# Full transformer
class Transformer(nn.Module):
    """Decoder-only transformer returning the logits at the last position.

    Args:
        num_layers: number of TransformerBlocks.
        d_vocab: vocabulary size (input ids and output logits).
        d_model, d_mlp, d_head, num_heads: layer widths / head configuration.
        n_ctx: maximum context length (positional table and causal mask size).
        act_type: 'ReLU' or 'GeLU' (validated in MLP).
        use_cache: stored on the instance; not read by forward() here.
        use_ln: read by every LayerNorm at forward time; False disables them.
    """

    def __init__(self, num_layers, d_vocab, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, use_cache=False, use_ln=True):
        super().__init__()
        self.cache = {}
        self.use_cache = use_cache

        self.embed = Embed(d_vocab, d_model)
        self.pos_embed = PosEmbed(n_ctx, d_model)
        # Sub-layers get the model wrapped in a list so nn.Module does not
        # register the parent model as a submodule of its own layers.
        self.blocks = nn.ModuleList([TransformerBlock(d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model=[self]) for _ in range(num_layers)])
        self.ln = LayerNorm(d_model, model=[self])
        self.unembed = Unembed(d_vocab, d_model)
        self.use_ln = use_ln

        # Name every HookPoint after its dotted module path.
        for name, module in self.named_modules():
            if isinstance(module, HookPoint):
                module.give_name(name)

    def forward(self, x):
        """Map token ids ``x`` (batch, pos) to logits of the final position."""
        x = self.embed(x)
        x = self.pos_embed(x)
        for block in self.blocks:
            x = block(x)

        x = self.ln(x)
        x = self.unembed(x)
        return x[:, -1]

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def hook_points(self):
        """Return every HookPoint in the model, in ``modules()`` order."""
        # Select by type rather than by 'hook' appearing in the module name.
        return [module for module in self.modules() if isinstance(module, HookPoint)]

    def remove_all_hooks(self):
        """Remove forward and backward hooks from every HookPoint."""
        for hp in self.hook_points():
            hp.remove_hooks('both')

    def cache_all(self, cache, incl_bwd=False):
        """Store detached activations (and, optionally, gradients) in ``cache``.

        Forward values are stored under the HookPoint name. When ``incl_bwd``
        is set, the first output-gradient tensor is stored under
        ``name + '_grad'``.
        """
        def save_hook(tensor, name):
            cache[name] = tensor.detach()

        def save_hook_back(tensor, name):
            cache[name + '_grad'] = tensor[0].detach()

        for hp in self.hook_points():
            hp.add_hook(save_hook, 'fwd')
            if incl_bwd:
                hp.add_hook(save_hook_back, 'bwd')

### Функция генерации данных:
p — модуль: результат операции берётся по модулю p
function — бинарная операция над a и b

In [None]:
# Fix RNG seeds (PyTorch CPU, NumPy, current CUDA device) for reproducible runs.
torch.manual_seed(14)
np.random.seed(14)
torch.cuda.manual_seed(14)

In [None]:
def create_data_p(p: int, function):
    """Build the full table of ``(a, b, function(a, b) mod p)`` rows.

    Args:
        p: modulus; ``a`` and ``b`` each range over ``0 .. p-1``.
        function: binary operation applied elementwise to two integer tensors.

    Returns:
        Tensor of shape ``(p * p, 3)`` with columns ``a``, ``b`` and the
        result. Rows are ordered with ``a`` major and ``b`` minor.
    """
    values = torch.arange(p)  # 0 .. p-1 (arange excludes the end point)
    a, b = torch.cartesian_prod(values, values).T  # every (a, b) pair
    c = function(a, b) % p
    return torch.stack([a, b, c]).T

In [None]:
def prod(a, b):
    """Multiplication operation: return ``a * b``."""
    product = a * b
    return product

In [None]:
def summ(a, b):
    """Addition operation: return ``a + b``."""
    total = a + b
    return total

In [None]:
def sinm(a, b, scale=None):
    """Return ``int(|sin(a + b)| * scale)`` elementwise as an int64 tensor.

    Args:
        a, b: integer tensors (``torch.sin`` needs tensor input).
        scale: multiplier, so results lie in ``0 .. scale``. When omitted,
            the module-level ``sinp`` is looked up at call time, as before.
    """
    if scale is None:
        scale = sinp
    # |sin| is non-negative, so truncation with .to(int) equals floor.
    return (abs(torch.sin(a + b)) * scale).to(int)

In [None]:
def nesim(a, b):
    """Asymmetric operation ``a*b + b*b``: swapping a and b changes the result."""
    ab = a * b
    bb = b * b
    return ab + bb

In [None]:
p = 97  # modulus of the arithmetic table
# Use the first GPU when present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_ratio = 0.4  # fraction of the table used for training
batch_size = 512
budget = 50000  # optimisation-step budget; epochs = budget // steps_per_epoch
sinp = 3*p  # scale used by sinm, whose raw values lie in 0..sinp
func = prod  # operation under study

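Авторы статьи подавали трансформеру токены "a", "o", "b", "=", "c", но мы будем использовать только "a", "b" и "c" (на вход подаются "a" и "b", а "c" — целевое значение). Как нам кажется, токены "o" и "=" не несут для нейросети никакой информации: в пределах одного датасета операция фиксирована, поэтому эти токены одинаковы во всех примерах.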

In [None]:
# Build and show the full table; columns 1, 2, 3 are "a", "b", "c".
example = create_data_p(p=p, function=func)
print(example)

tensor([[ 0,  0,  0],
        [ 0,  1,  0],
        [ 0,  2,  0],
        ...,
        [96, 94,  3],
        [96, 95,  2],
        [96, 96,  1]])


Перемешиваем выборку и разбиваем на train и val:

In [None]:
# Build the full table on the target device and shuffle its row indices.
data = create_data_p(p, func).to(device)
data_index = torch.randperm(len(data), device=device)
split = int(len(data) * train_ratio)
# The first `split` shuffled rows go to train, the rest to validation.
training_set = data[data_index[:split]]
validation_set = data[data_index[split:]]

In [None]:
training_set.shape

torch.Size([3763, 3])

In [None]:
validation_set

tensor([[ 0, 50,  0],
        [78, 32, 71],
        [63, 86, 83],
        ...,
        [72, 50, 11],
        [24, 60, 82],
        [54, 43, 91]], device='cuda:0')

In [None]:
# 2-layer transformer over the vocabulary 0..p-1; it is fed the pair ("a", "b").
net = Transformer(
    num_layers=2,
    d_vocab=p,
    d_model=128,
    d_mlp=512,
    d_head=32,
    num_heads=4,
    n_ctx=3,  # context length
    act_type='ReLU',
    use_cache=False,
    use_ln=True,  # use LayerNorm
).to(device)
# NOTE: this run uses plain SGD with weight decay, not AdamW.
optimizer = SGD(params=net.parameters(), lr=0.1, weight_decay=1e-3)

In [None]:
# Optimisation steps per epoch: ceil(train size / batch size), in integer math.
steps_per_epoch = (training_set.shape[0] + batch_size - 1) // batch_size

In [None]:
def get_ravel_weights(model):
    """Return all parameters of *model* flattened into one 1-D NumPy array.

    Parameters are copied to the CPU and concatenated in
    ``model.parameters()`` order.
    """
    flat = [param.detach().cpu().numpy().ravel() for param in model.parameters()]
    return np.concatenate(flat)

In [None]:
def isbatchnorm(module):
    """Return True if *module* is an instance of a PyTorch batch-norm class."""
    return isinstance(module, torch.nn.modules.batchnorm._BatchNorm)


def _check_bn(module, flag):
    """Set ``flag[0] = True`` if *module* is batch-norm; otherwise leave it alone."""
    if not isbatchnorm(module):
        return
    flag[0] = True


def check_bn(model):
    """Return True if *model* itself or any of its submodules is batch-norm."""
    return any(isbatchnorm(submodule) for submodule in model.modules())

## Connecting two checkpoints: 100% train / 0% val accuracy and 100% train / 100% val accuracy

In [None]:
def calc_norm(model):
    """L2 norm of all parameters of *model*, viewed as one flat vector."""
    total = 0
    for param in model.parameters():
        total += param.pow(2).sum().item()
    return np.sqrt(total)

In [None]:
def calc_si_norm(model):
    """L2 norm over the trainable parameters only (``requires_grad=True``)."""
    squares = (
        param.pow(2).sum().item()
        for param in model.parameters()
        if param.requires_grad
    )
    return np.sqrt(sum(squares))

In [None]:
def calc_grad_norm(model):
    """L2 norm of the gradients of *model*'s parameters, as one flat vector.

    Parameters whose ``.grad`` is ``None`` (frozen, or not yet reached by a
    backward pass) contribute zero instead of raising AttributeError.
    """
    return np.sqrt(sum(param.grad.pow(2).sum().item()
                       for param in model.parameters()
                       if param.grad is not None))

In [None]:
def calc_norm_wot_last_layer(model):
    """L2 norm of the *gradients* of parameters that currently have one.

    Parameters with ``grad is None`` are skipped (they count as zero);
    presumably that is the frozen last layer the name refers to.
    """
    total = 0
    for param in model.parameters():
        if param.grad is not None:
            total += param.grad.pow(2).sum().item()
    return np.sqrt(total)

In [None]:
def calc_norm_last_layer(model):
    """L2 norm of the *weights* of parameters whose ``.grad`` is None."""
    without_grad = [param for param in model.parameters() if param.grad is None]
    return np.sqrt(sum(param.pow(2).sum().item() for param in without_grad))

In [None]:
# Per-epoch accuracy / loss history for train and validation.
train_acc = []
val_acc = []
train_loss = []
val_loss = []
# Per-step buffers, emptied by the training loops at the end of each epoch.
weights_norm = []
grad_norms = []
# Per-epoch histories of norms and derived quantities.
norms = []
effective_lr = []
effective_grad = []
mean_effictive_grad = []
mean_grad_norms = []

In [None]:
# Phase 1: train until the network fits the training set perfectly (or the
# step budget runs out), then checkpoint it as 'net_train_100.pth'.
k = 0
for _ in range(int(budget) // steps_per_epoch):
    k += 1
    # Reshuffle the training rows every epoch.
    training_set = training_set[torch.randperm(training_set.shape[0]), :]

    # `split_data` / `batch` rather than `data` / `input`, so the global
    # dataset tensor is not overwritten and the builtin is not shadowed.
    for split_data, is_train in [(training_set, True), (validation_set, False)]:
        total_acc = 0
        total_loss = 0
        net.train(is_train)

        for batch in torch.split(split_data, batch_size, dim=0):
            batch = batch.to(device)
            with torch.set_grad_enabled(is_train):
                logits = net(batch[:, :-1])  # predict "c" from ("a", "b")
                loss = cross_entropy(
                    logits, batch[:, -1].flatten().to(torch.long))
                total_loss += loss.item() * batch.shape[0]

            if is_train:
                # Optimisation step. No LR scheduler is stepped here, so the
                # learning rate stays constant.
                net.zero_grad()
                loss.backward()
                optimizer.step()

                norm = calc_si_norm(net)
                grad = calc_norm_wot_last_layer(net)
                grad_norms.append(grad)
                effective_grad.append(grad * norm)
                weights_norm.append(norm)

            acc = (logits.argmax(-1) == batch[:, -1]).float().mean()
            total_acc += acc.item() * batch.shape[0]

        if is_train:
            train_acc.append(total_acc / training_set.shape[0])
            train_loss.append(total_loss / training_set.shape[0])
            norms.append(norm)  # weight norm after the epoch's last step
        else:
            val_acc.append(total_acc / validation_set.shape[0])
            val_loss.append(total_loss / validation_set.shape[0])

    # Epoch summaries of the per-step buffers, which are then reset.
    # effective lr = lr / (mean weight norm over this epoch's steps) ** 2
    effective_lr.append(optimizer.state_dict()['param_groups'][0]['lr'] / np.mean(weights_norm) ** 2)
    mean_effictive_grad.append(np.mean(effective_grad))
    mean_grad_norms.append(np.mean(grad_norms))

    effective_grad = []
    grad_norms = []
    weights_norm = []

    if train_acc[-1] == 1:
        torch.save(net, 'net_train_100.pth')
        break
    print(f'Epoch {k}: Train / Val acc: {round(train_acc[-1], 4)} / {round(val_acc[-1], 4)}')


In [None]:
# Phase 2: keep training the same network until it also reaches 100%
# validation accuracy. `k` continues from phase 1, so the 20000 cap applies to
# the total epoch count. The model is saved to 'net_val_100.pth' when either
# condition stops the loop, including when the cap is hit first.
while True:
    k += 1
    # Reshuffle the training rows every epoch.
    training_set = training_set[torch.randperm(training_set.shape[0]), :]

    # `split_data` / `batch` rather than `data` / `input`, so the global
    # dataset tensor is not overwritten and the builtin is not shadowed.
    for split_data, is_train in [(training_set, True), (validation_set, False)]:
        total_acc = 0
        total_loss = 0
        net.train(is_train)

        for batch in torch.split(split_data, batch_size, dim=0):
            batch = batch.to(device)
            with torch.set_grad_enabled(is_train):
                logits = net(batch[:, :-1])  # predict "c" from ("a", "b")
                loss = cross_entropy(
                    logits, batch[:, -1].flatten().to(torch.long))
                total_loss += loss.item() * batch.shape[0]

            if is_train:
                # Optimisation step. No LR scheduler is stepped here, so the
                # learning rate stays constant.
                net.zero_grad()
                loss.backward()
                optimizer.step()

                norm = calc_si_norm(net)
                grad = calc_norm_wot_last_layer(net)
                grad_norms.append(grad)
                effective_grad.append(grad * norm)
                weights_norm.append(norm)

            acc = (logits.argmax(-1) == batch[:, -1]).float().mean()
            total_acc += acc.item() * batch.shape[0]

        if is_train:
            train_acc.append(total_acc / training_set.shape[0])
            train_loss.append(total_loss / training_set.shape[0])
            norms.append(norm)  # weight norm after the epoch's last step
        else:
            val_acc.append(total_acc / validation_set.shape[0])
            val_loss.append(total_loss / validation_set.shape[0])

    # Epoch summaries of the per-step buffers, which are then reset.
    effective_lr.append(optimizer.state_dict()['param_groups'][0]['lr'] / np.mean(weights_norm) ** 2)
    mean_effictive_grad.append(np.mean(effective_grad))
    mean_grad_norms.append(np.mean(grad_norms))

    effective_grad = []
    grad_norms = []
    weights_norm = []

    # `>=` rather than `==`: re-running this cell after the cap was reached
    # starts with k >= 20000, and an equality test would never stop the loop.
    if val_acc[-1] == 1 or k >= 20000:
        torch.save(net, 'net_val_100.pth')
        break
    print(f'Epoch {k}: Train / Val acc: {round(train_acc[-1], 4)} / {round(val_acc[-1], 4)}')


In [None]:
# Train vs. validation accuracy per epoch.
plt.plot(train_acc, label='train')
plt.plot(val_acc, label='val', alpha=0.7)
plt.title('2L Transformer with SGDopt\n  lr=1e-1, weight_decay=1e-3')
plt.xlabel('epoch')
plt.ylabel('Accuracy')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Train vs. validation loss per epoch, on a logarithmic y axis.
plt.plot(train_loss, label='train')
plt.plot(val_loss, label='val', alpha=0.7)
plt.title('2L Transformer with SGDopt\n  lr=1e-1, weight_decay=1e-3')
plt.xlabel('epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Effective learning rate per epoch (lr / mean weight norm squared), log scale.
plt.plot(effective_lr, label='SGD: lr=1e-1, weight_decay=1e-3')
plt.title('Effective_lr')  # spelling matches the 'Effective_grad' plot
plt.xlabel('epoch')
plt.yscale('log')
plt.legend()
plt.grid()

In [None]:
# Epoch mean of (gradient norm * weight norm), log scale.
plt.plot(mean_effictive_grad, label='SGD: lr=1e-1, weight_decay=1e-3')
plt.xlabel('epoch')
plt.title('Effective_grad')
plt.yscale('log')
plt.grid()
plt.legend()

In [None]:
# Weight norm at the end of each training epoch, log scale.
plt.plot(norms, label='SGD: lr=1e-1, weight_decay=1e-3')
plt.xlabel('epoch')
plt.title('weights_norm')
plt.yscale('log')
plt.grid()
plt.legend()