# Crosscoder

- GPT-2.
- An SAE
- Get the interested logits

In [1]:
import json
import math
import pickle
import sqlite3
from datetime import timedelta

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import lightning as L
from transformers import AutoTokenizer, AutoModelForCausalLM
from lightning.pytorch.callbacks import ModelCheckpoint

from dawnet import Inspector, op
from dawnet.utils.notebook import run_in_process, is_ready
from dawnet.utils.numpy import NpyAppendArray

In [2]:
model_id = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

ckpt_callback = ModelCheckpoint(train_time_interval=timedelta(minutes=30))

## Get the data

In [3]:
def tokenize_jsonl_text(jsonl_file, tokenizer, context_length, start_line=0, max_line=100000):
    """Get the tokenized jsonl data depends on

    - Which tokenizer is used to parse the text
    - What is the maximum context length that we care about
    - How the targeted text is stored inside jsonl_file
    - Are there extra information we want to store
    """
    import json
    report_every = 10000
    result = []
    with open(jsonl_file, "r") as fi:
        for idx, text_line in enumerate(fi):
            if idx < start_line:
                continue
            tokens = tokenizer.encode(json.loads(text_line)["text"])
            for i in range(0, len(tokens), context_length):
                if len(tokens[i:i+context_length]) == context_length:
                    result.append(tokens[i:i+context_length])
            if idx % report_every == 0:
                print(idx, len(result))
            if idx == (max_line - 1):
                break

    print(f"Loaded {idx+1} lines into {len(result)} chunks") 
    return result

In [None]:
start, stop = 1400000, 1500000

f2 = tokenize_jsonl_text(jsonl_file="/data2/datasets/thepile/train/00.jsonl", tokenizer=tokenizer, context_length=1024, start_line=start, max_line=stop)
f2 = np.asarray(f2, dtype=np.uint16)
np.save(f"/home/john/dawnet/experiments/crosscoders/{start}_{stop}_lines.npy", f2)

## Get the hidden

In [4]:
inspector = Inspector(model)
print(inspector)

Inspector(
  (_original_model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2SdpaAttention(
            (c_attn): Conv1D(nf=2304, nx=768)
            (c_proj): Conv1D(nf=768, nx=768)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D(nf=3072, nx=768)
            (c_proj): Conv1D(nf=768, nx=3072)
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, o

In [5]:
data = np.load("/home/john/dawnet/experiments/crosscoders/100000_lines.npy", mmap_mode="r")

In [6]:
print(data.shape)

(121274, 1024)


In [8]:
batch_size = 8
layer_name = "transformer.h.9"

hidden_size = 768
layer_output = []
file_idx = 1

hook1 = inspector.add_op(name=layer_name, op=op.CacheModuleInputOutput(no_output=True))
with NpyAppendArray(f"/data2/mech/internals/{layer_name}.npy") as npaa:
    with torch.no_grad():
        for idx in tqdm(range(0, data.shape[0], batch_size)):
            input_ids = torch.tensor(data[:batch_size]).to("cuda").long()
            output = inspector._model(input_ids)
            layer_output_ = inspector.state.input[layer_name][0][0].cpu().numpy()
            layer_output.append(layer_output_)
            if (len(layer_output) * batch_size) >= 1000:
                npaa.append(np.concatenate(layer_output))
                layer_output = []

100%|█████████████████████████████████████| 15160/15160 [1:02:49<00:00,  4.02it/s]


## Get the idea

In [17]:
data1 = np.load("/home/john/dawnet/experiments/crosscoders/intermediate_transformer.h.8_1.npy", mmap_mode="r")
data2 = np.load("/home/john/dawnet/experiments/crosscoders/intermediate_transformer.h.9_1.npy", mmap_mode="r")

In [67]:
%%time
item1 = data1[0:5]
item2 = data2[0:5]
print(item1.shape)
print(item2.shape)
item = np.stack([item1, item2], axis=0)
print(item.shape)

(5, 1024, 768)
(5, 1024, 768)
(2, 5, 1024, 768)
CPU times: user 4.71 ms, sys: 4.61 ms, total: 9.32 ms
Wall time: 7.45 ms


In [41]:
item.transpose(1, 0, 2, 3).shape

(5, 2, 1024, 768)

In [9]:
class IntermediateStateDataset(torch.utils.data.Dataset):
    def __init__(self, path1, path2):
        from pathlib import Path

        self.layer1 = np.load(path1, mmap_mode="r")
        self.layer2 = np.load(path2, mmap_mode="r")

    def __len__(self):
        return self.layer1.shape[0]

    def __getitem__(self, idx):
        item1 = self.layer1[idx]
        item2 = self.layer2[idx]
        item = np.stack([item1, item2], axis=0)

        if not isinstance(idx, int):
            item = item.reshape(item.shape[0], -1, item.shape[-1])

        item = item.transpose(1, 0, 2)
        return item

In [10]:
train_dataset = IntermediateStateDataset(
    path1="/data2/mech/internals/transformer.h.8.npy",
    path2="/data2/mech/internals/transformer.h.9.npy"
)

In [11]:
train_dataset[0].shape

(1024, 2, 768)

In [12]:
train_dataset[0:6].shape

(6144, 2, 768)

In [16]:
class CrossCoder(L.LightningModule):
    def __init__(self, n_hidden, n_features, n_layers):
        super().__init__()

        self._n_hidden = n_hidden
        self._n_features = n_features
        self._n_layers = n_layers
        self.W_enc_1 = nn.Parameter(torch.empty(n_hidden, n_features))
        self.W_enc_2 = nn.Parameter(torch.empty(n_hidden, n_features))
        self.b_enc = nn.Parameter(torch.empty(n_features))
        self.W_dec = nn.Parameter(torch.empty(n_layers, n_features, n_hidden))
        self.b_dec = nn.Parameter(torch.empty(n_layers, n_hidden))

        self.loss = nn.MSELoss()
        self.reset_parameters()
        self.save_hyperparameters()

    def encode(self, x):
        """x has shape: n_batch x n_layers x n_hidden"""
        z_1 = torch.matmul(x[:,0,:], self.W_enc_1)    # n_batch, n_features
        z_2 = torch.matmul(x[:,1,:], self.W_enc_2)    # n_batch, n_features
        z = z_1 + z_2
        z = z + self.b_enc    # n_batch, n_features
        z = nn.functional.relu(z)
        return z              # n_batch, n_features

    def decode(self, a):
        """a has shape: n_batch x n_features"""
        z = torch.matmul(a, self.W_dec)   # n_layers, n_batch, n_hidden
        n_layers, n_batch, n_hidden = z.shape
        z = z.view(n_batch, n_layers, n_hidden)
        y = z + self.b_dec    # n_batch, n_layers, n_hidden
        return y

    def forward(self, x):
        """x has shape: n_layers x n_hidden"""
        a_ = self.encode(x)   # n_batch, n_features
        y = self.decode(a_)   # n_batch, n_hidden
        return a_, y

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-4)

    def training_step(self, batch, batch_nb):
        act, output = self.forward(batch[0])

        # reconstruction mse term
        reconstruction = self.loss(batch[0], output)

        # regularization term
        W_dec_norm = self.W_dec.norm(dim=2)    # n_layers, n_features
        W_dec_sum = W_dec_norm.sum(dim=0)      # n_features
        reg = W_dec_sum * act                  # n_features
        reg = reg.sum()

        # loss
        loss = reconstruction + 1e-4 * reg
        if batch_nb % 10 == 0:
            print(loss.item(), reconstruction.item(), reg.item())
        return loss

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.W_enc_1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_enc_2, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_dec, a=math.sqrt(5))

        _, fan_in = nn.init._calculate_fan_in_and_fan_out(self.W_enc_1)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.b_enc, -bound, bound)

        _, fan_in = nn.init._calculate_fan_in_and_fan_out(self.W_dec)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.b_dec, -bound, bound)

In [27]:
torch.autograd.set_grad_enabled(True)
model = CrossCoder(n_features=768*16, n_hidden=768, n_layers=2).cuda()
with torch.no_grad():
    input_ = train_dataset[:20]
    input_ = torch.tensor(input_).cuda()
    print(input_.shape)
    feat, recon = model(input_)
    print(feat.shape, recon.shape)

torch.Size([20480, 2, 768])
torch.Size([20480, 12288]) torch.Size([20480, 2, 768])


In [14]:
def collate_fn(*args, **kwargs):
    x = torch.from_numpy(np.concatenate(args[0]))
    return x

In [17]:
model = CrossCoder(n_features=768*16, n_hidden=768, n_layers=2)
trainer = L.Trainer(accelerator="gpu", callbacks=[ckpt_callback], max_epochs=2)
trainer.fit(
    model,
    train_dataloaders=[torch.utils.data.DataLoader(train_dataset, batch_size=16, collate_fn=collate_fn)]
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type    | Params | Mode 
-------------------------------------------------
0 | loss         | MSELoss | 0      | train
  | other params | n/a     | 37.8 M | n/a  
-------------------------------------------------
37.8 M    Trainable params
0         Non-trainable params
37.8 M    Total params
151.050   Total estimated model params size (MB)
1         Modules in train mode
0         Modules in eval mode


Training: |                                                 | 0/? [00:00<?, ?it/s]

99.86021423339844 28.52557373046875 713346.4375
33.731056213378906 28.338796615600586 53922.578125
29.393239974975586 28.32080078125 10724.38671875
28.7664852142334 28.281600952148438 4848.83984375
28.590755462646484 28.247955322265625 3428.009521484375
28.529359817504883 28.218460083007812 3108.991943359375
28.49751091003418 28.184743881225586 3127.66845703125
28.47687530517578 28.143505096435547 3333.7099609375
28.461158752441406 28.0928955078125 3682.636474609375
28.447568893432617 28.03264617919922 4149.2314453125
28.43524742126465 27.9636287689209 4716.18505859375
28.42396354675293 27.8878173828125 5361.45751953125
28.413883209228516 27.80856704711914 6053.1552734375
28.405242919921875 27.7294979095459 6757.44287109375
28.39820671081543 27.655128479003906 7430.77685546875
28.392702102661133 27.589080810546875 8036.20556640625
28.388437271118164 27.534343719482422 8540.9423828125
28.38503074645996 27.491897583007812 8931.330078125
28.382169723510742 27.46114730834961 9210.225585937

`Trainer.fit` stopped: `max_epochs=2` reached.
