# experiment with SAE
https://transformer-circuits.pub/2023/monosemantic-features

In [1]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformer_lens
from tqdm.notebook import tqdm
import wandb
from functools import lru_cache

In [2]:
device = t.device("cuda" if t.cuda.is_available() else "cpu")

In [3]:
class CFG():
    sae_layer = 0
    n_in = 768
    n_hidden = n_in * 8
    batch_size = 1
    max_context = 600

cfg = CFG()

In [4]:
gpt2 = transformer_lens.HookedTransformer.from_pretrained("gpt2-small")
# logits, activations = gpt2.run_with_cache("Hello World")
gpt2_dataset = gpt2.load_sample_training_dataset()

Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
class Dataset(Dataset):
    def __init__(self, gpt2, gpt2_dataset):
        self.gpt2 = gpt2
        self.gpt2_dataset = gpt2_dataset

    def __len__(self):
        return len(self.gpt2_dataset)

    # @lru_cache(maxsize=None)
    def __getitem__(self, idx):
        sentence = self.gpt2_dataset[idx]
        tokens = self.gpt2.to_tokens(sentence['text'], padding_side="right")
        _, activations = self.gpt2.run_with_cache(tokens)
        return activations['mlp_out', cfg.sae_layer]
    
dataset = Dataset(gpt2, gpt2_dataset)
train_dataloader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

In [6]:
tokens = gpt2.to_tokens([gpt2_dataset[1002]['text'], gpt2_dataset[0]['text']], padding_side="right")
tokens.shape

torch.Size([2, 1024])

In [11]:
def hook_mlp_out(gpt2):
    filter_mlp_out = lambda name: ("blocks.0.hook_mlp_out" == name)
    gpt2.reset_hooks()
    cache = {}
    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()
    gpt2.add_hook(filter_mlp_out, forward_cache_hook, "fwd")
    return cache

cache = hook_mlp_out(gpt2)
gpt2(gpt2_dataset[:10]['text'])
cache['blocks.0.hook_mlp_out'].shape

torch.Size([10, 1024, 768])

In [9]:
with t.no_grad():
    filter_mlp_out = lambda name: ("blocks.0.hook_mlp_out" == name)
    gpt2.reset_hooks()
    cache = {}
    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()
    gpt2.add_hook(filter_mlp_out, forward_cache_hook, "fwd")

    gpt2(gpt2_dataset[:10]['text'])
    cache['blocks.0.hook_mlp_out'].shape

In [10]:
4 * 768

3072

In [17]:
sum([len(x) for x in gpt2.parameters()])

204690

In [12]:
cache['blocks.0.hook_mlp_out'].shape

torch.Size([5, 1024, 768])

In [11]:
l, c = gpt2.run_with_cache(gpt2_dataset[:5]['text'])
# [x['text'] for x in gpt2_dataset[:5]]
c['mlp_out', 0].shape

torch.Size([5, 1024, 768])

In [56]:
for d in train_dataloader:
    print(d)
    break

tensor([[[[-0.5169,  0.2836,  0.4329,  ...,  1.6439,  1.4973, -0.0420],
          [-0.7826,  0.6905,  1.7465,  ..., -0.1071,  0.1652, -0.0113],
          [-1.1851,  0.1309, -1.0010,  ...,  1.2678,  0.1318,  0.5170],
          ...,
          [ 1.0835, -1.0322, -1.0582,  ...,  0.5411, -0.9987,  0.0215],
          [ 0.8386,  1.4588, -1.9745,  ...,  0.3897,  0.3583,  0.0359],
          [ 0.6674,  0.1444, -0.6658,  ...,  0.4925,  0.2368, -0.5140]]]],
       device='cuda:0')


## model

In [14]:
class SAE(nn.Module):
    def __init__(self, n_in, n_hidden):
        super(SAE, self).__init__()
        self.b_a = nn.Parameter(t.zeros(n_in))
        self.W_e = nn.Linear(n_in, n_hidden, bias=False)
        self.b_e = nn.Parameter(t.zeros(n_hidden))
        self.W_d = nn.Linear(n_hidden, n_in, bias=False)
        self.b_d = nn.Parameter(t.zeros(n_in))

    def forward(self, act):
        x = act + self.b_a
        x = self.W_e(x) + self.b_e
        x = F.relu(x)
        hidden_activations = x
        x = self.W_d(x) + self.b_d
        x = F.relu(x)
        return x, hidden_activations

## train

In [15]:
model = SAE(n_in=cfg.n_in, n_hidden=cfg.n_hidden).to(device)
opt = t.optim.Adam(model.parameters(), lr=1e-3)

In [16]:
def train(model, opt, dataset, epochs=100, l1_factor=0.1, wnb=True):
    if wnb:
        wandb.init(project='SAE')

    for epoch in tqdm(range(epochs)):
        for act in tqdm(dataset):
            logits, hidden_activations = model(act)
            reconstruction_loss = F.mse_loss(logits, act)
            sparsity_loss = hidden_activations.abs().mean()
            loss = reconstruction_loss + l1_factor * sparsity_loss
            if wnb:
                wandb.log({"reconstruction_loss": reconstruction_loss.item()})
                wandb.log({"sparsity_loss": sparsity_loss.item()})
                wandb.log({"loss": loss.item()})
            opt.zero_grad()
            loss.backward()
            opt.step()
        if wnb:
            # TODO: validation loss
            # wandb.log({"epoch": epoch})

train(model, opt, dataset)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpeluche[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 