In [1]:
from itertools import cycle

import datasets
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

from sae.data import chunk_and_tokenize

%load_ext autoreload
%autoreload 2

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x1542644afdd0>

In [3]:
dataset = datasets.load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample", split="train[:1000]"
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = chunk_and_tokenize(dataset, tokenizer, max_seq_len=32)
model = AutoModel.from_pretrained("gpt2", device_map={"": 0})

In [4]:
class Regression(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.linear = nn.Linear(d_in, d_out)

    def forward(self, x):
        x = self.linear(x)
        return x


class NuclearNormLoss(nn.Module):
    def __init__(self, lam=0.01, rank=None):
        super().__init__()
        self.lam = lam
        self.rank = rank or 1

    def forward(self, pred, target, weight):
        mse_loss = torch.mean((pred - target) ** 2)
        _, S, _ = torch.linalg.svd(weight)
        # penalize all other ranks
        nuclear_norm = torch.sum(S[self.rank :])
        return mse_loss + self.lam * nuclear_norm, mse_loss, nuclear_norm

In [5]:
def train_and_evaluate(
    model,
    dataset,
    i,
    j,
    rank,
    batch_size=64,
    lam=0.1,
    lr=1e-2,
    steps=20,
    eval_steps=10,
    device="cuda",
):
    split_dataset = dataset.train_test_split(test_size=0.2)
    train_dataloader = DataLoader(
        split_dataset["train"], batch_size=batch_size, shuffle=True
    )
    test_dataloader = DataLoader(
        split_dataset["test"], batch_size=batch_size, shuffle=False
    )

    linear_model = Regression(d_in=model.config.n_embd, d_out=model.config.n_embd).to(
        device
    )

    nn.init.zeros_(linear_model.linear.weight)
    nn.init.zeros_(linear_model.linear.bias)

    criterion = NuclearNormLoss(lam=lam, rank=rank)
    opt = torch.optim.SGD(linear_model.parameters(), lr=lr)

    train_iter = cycle(train_dataloader)

    for step in range(steps):
        train_batch = next(train_iter)
        opt.zero_grad()

        with torch.no_grad():
            model_out = model(
                train_batch["input_ids"].to(device), output_hidden_states=True
            )
        layer_i_acts, layer_j_acts = (
            model_out.hidden_states[i],
            model_out.hidden_states[j],
        )

        pred, target = (
            linear_model(F.normalize(layer_i_acts, p=2, dim=-1)),
            F.normalize(layer_j_acts, p=2, dim=-1),
        )
        loss, l2, nnorm = criterion(pred, target, linear_model.linear.weight)

        l2_unreduced = torch.mean((pred - target) ** 2, dim=0)
        var = torch.var(target, dim=0)

        # print(l2_unreduced.mean(-1), var.mean(-1))

        fvu = (l2_unreduced / var).mean()
        
        print(f"FVU: {fvu}")

        loss.backward()
        opt.step()

    # Evaluate on test set
    linear_model.eval()
    total_fvu = 0

    with torch.no_grad():
        for i, test_batch in enumerate(test_dataloader):
            model_out = model(
                test_batch["input_ids"].to(device), output_hidden_states=True
            )
            layer_i_acts, layer_j_acts = (
                model_out.hidden_states[i],
                model_out.hidden_states[j],
            )
            pred, target = (
                linear_model(F.normalize(layer_i_acts, p=2, dim=-1)),
                F.normalize(layer_j_acts, p=2, dim=-1),
            )
            l2_unreduced = torch.mean((pred - target) ** 2, dim=0)
            var = torch.var(target, dim=0)
            fvu = (l2_unreduced / var).mean()

            total_fvu += fvu.item()

            if i >= eval_steps:
                break

    test_fvu = total_fvu / eval_steps
    test_r2 = 1 - test_fvu

    return test_r2


def create_scree_plot(model, dataset, layer_pairs, ranks):
    plt.figure(figsize=(12, 8))

    colors = ["b", "g", "r", "c", "m", "y", "k"]  # Add more colors if needed

    for idx, (i, j) in enumerate(layer_pairs):
        r2_values = []
        for rank in ranks:
            r2 = train_and_evaluate(model, dataset, i, j, rank, steps=10)
            r2_values.append(r2)

        color = colors[idx % len(colors)]
        plt.plot(ranks, r2_values, f"{color}o-", label=f"Layers {i} to {j}")

    plt.xlabel("Rank")
    plt.ylabel("R^2")
    plt.title("Scree Plot for Multiple Layer Pairs")
    plt.grid(True)
    plt.legend()
    plt.show()


# Usage
ranks = [1, 2, 4, 8, 16, 32, 384, 768]
layer_pairs = [(1, 1), (1, 4), (1, 8)]  # Add more layer pairs as needed
create_scree_plot(model, dataset, layer_pairs, ranks)

FVU: 3.9277420043945312
FVU: 3.9332618713378906
FVU: 3.9656481742858887
FVU: 3.9733245372772217
FVU: 3.9706435203552246
FVU: 4.08003568649292
FVU: 3.8771989345550537
FVU: 4.0151262283325195
FVU: 4.004677772521973
FVU: 4.017975807189941
FVU: 3.8584885597229004
FVU: 4.045174598693848
FVU: 3.9922409057617188
FVU: 4.01267671585083
FVU: 4.126893043518066
FVU: 4.110725402832031
FVU: 4.150299072265625
FVU: 3.9668264389038086
FVU: 3.9783287048339844
FVU: 3.8852291107177734
FVU: 3.9805259704589844
FVU: 4.005075454711914
FVU: 3.909733295440674
FVU: 3.857121467590332
FVU: 3.813244581222534
FVU: 3.9511611461639404
FVU: 3.895009994506836
FVU: 3.986253499984741
FVU: 4.015557289123535
FVU: 3.9707953929901123
FVU: 3.913961410522461
FVU: 3.94234037399292
FVU: 3.9815196990966797
FVU: 3.9262137413024902
FVU: 3.8375720977783203
FVU: 3.9302451610565186
FVU: 3.9152674674987793
FVU: 4.061468124389648
FVU: 4.047924995422363
FVU: 4.0532426834106445
FVU: 4.088166236877441
FVU: 3.941070079803467
FVU: 4.030582904

KeyboardInterrupt: 

<Figure size 1200x800 with 0 Axes>