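# Experiment with sparse autoencoders (SAE)
Trains a sparse autoencoder on the MLP-output activations of one GPT-2 small layer.

Reference: [Towards Monosemanticity: Decomposing Language Models With Dictionary Learning](https://transformer-circuits.pub/2023/monosemantic-features)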

In [7]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import transformer_lens
from tqdm.notebook import tqdm
import wandb
from functools import lru_cache

In [8]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

In [9]:
class CFG:
    """Hyper-parameters shared by the cells of this notebook."""

    # Index of the transformer block whose MLP output is fed to the autoencoder.
    sae_layer = 0
    # Width of the activation vectors given to the autoencoder.
    n_in = 768
    # Autoencoder dictionary size: eight times the input width.
    n_hidden = 8 * n_in


cfg = CFG()

In [10]:
# Load the pretrained "gpt2-small" checkpoint as a TransformerLens HookedTransformer.
gpt2 = transformer_lens.HookedTransformer.from_pretrained("gpt2-small")
# logits, activations = gpt2.run_with_cache("Hello World")
# Sample dataset provided by the model object; later cells read each entry's
# 'text' field. NOTE(review): exact contents/size are not visible here — confirm.
gpt2_dataset = gpt2.load_sample_training_dataset()

Loaded pretrained model gpt2-small into HookedTransformer


In [11]:
# Sanity check: run the model once on the first dataset entry, keeping the
# activation cache alongside the logits.
sentence = gpt2_dataset[0]
logits, activations = gpt2.run_with_cache(sentence['text'])
# Display the cache to see which hook names are available.
activations

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

In [12]:
# Shape of the MLP output at the configured layer; its last dimension should
# match cfg.n_in, the autoencoder input width.
activations['mlp_out', cfg.sae_layer].shape

torch.Size([1, 1024, 768])

In [13]:
class Dataset(t.utils.data.Dataset):
    """Map-style dataset that yields MLP-output activations, one text sample per item.

    Each item is produced by running ``gpt2`` on ``gpt2_dataset[idx]['text']`` and
    reading the ``'mlp_out'`` activation of one layer from the returned cache.
    Results are memoised per instance, so the model runs once per index.

    The base class is referenced through ``t.utils.data`` rather than the imported
    ``Dataset`` name so that re-running this cell does not make the class inherit
    from its own previous definition.

    Args:
        gpt2: object exposing ``run_with_cache(text) -> (logits, cache)``.
        gpt2_dataset: indexable collection whose items have a ``'text'`` entry.
        layer: layer index to read; ``None`` (default) uses ``cfg.sae_layer``.
    """

    def __init__(self, gpt2, gpt2_dataset, layer=None):
        self.gpt2 = gpt2
        self.gpt2_dataset = gpt2_dataset
        self.layer = layer
        # Per-instance cache: unlike a class-level lru_cache on the method, it is
        # released together with the dataset object and never shared between
        # instances. NOTE(review): it is unbounded, so every cached activation
        # stays in memory for the lifetime of the dataset.
        self._cache = {}

    def __len__(self):
        return len(self.gpt2_dataset)

    def __getitem__(self, idx):
        """Return the (cached) activation tensor for sample ``idx``.

        An out-of-range ``idx`` propagates whatever the underlying dataset
        raises, which is what lets ``for act in dataset`` terminate.
        """
        if idx not in self._cache:
            layer = cfg.sae_layer if self.layer is None else self.layer
            sentence = self.gpt2_dataset[idx]
            # The activations are training data, not something to differentiate
            # through: skip building the language model's autograd graph.
            with t.no_grad():
                _, activations = self.gpt2.run_with_cache(sentence['text'])
            self._cache[idx] = activations['mlp_out', layer].detach()
        return self._cache[idx]
    
# Wrap the model and the text samples so that indexing yields activations.
dataset = Dataset(gpt2, gpt2_dataset)

## model

In [14]:
class SAE(nn.Module):
    """Sparse autoencoder with a ReLU hidden layer and a linear decoder.

    Forward computation::

        hidden         = relu(W_e (act + b_a) + b_e)
        reconstruction = W_d hidden + b_d

    The decoder output is deliberately left linear: the reconstruction target is
    a real-valued activation vector, and a ReLU on the output would make every
    negative component impossible to reconstruct.

    Args:
        n_in: size of the last dimension of the input activations.
        n_hidden: number of hidden (dictionary) units.
    """

    def __init__(self, n_in, n_hidden):
        super().__init__()
        # Learned shift applied to the input before encoding.
        self.b_a = nn.Parameter(t.zeros(n_in))
        # Encoder weight and bias (bias kept as a separate parameter).
        self.W_e = nn.Linear(n_in, n_hidden, bias=False)
        self.b_e = nn.Parameter(t.zeros(n_hidden))
        # Decoder weight and bias.
        self.W_d = nn.Linear(n_hidden, n_in, bias=False)
        self.b_d = nn.Parameter(t.zeros(n_in))

    def forward(self, act):
        """Encode and reconstruct ``act``.

        Args:
            act: tensor whose last dimension is ``n_in``.

        Returns:
            Tuple ``(reconstruction, hidden_activations)``; the reconstruction has
            the same shape as ``act`` and the hidden activations are non-negative
            with last dimension ``n_hidden``.
        """
        shifted = act + self.b_a
        hidden_activations = F.relu(self.W_e(shifted) + self.b_e)
        reconstruction = self.W_d(hidden_activations) + self.b_d
        return reconstruction, hidden_activations

## train

In [15]:
# Build the autoencoder with the sizes from cfg and move it to the selected device.
model = SAE(n_in=cfg.n_in, n_hidden=cfg.n_hidden).to(device)
# Adam over every autoencoder parameter, learning rate 1e-3.
opt = t.optim.Adam(model.parameters(), lr=1e-3)

In [16]:
def train(model, opt, dataset, epochs=100, l1_factor=0.1, wnb=True):
    """Train an autoencoder with an MSE reconstruction loss plus an L1 sparsity penalty.

    For every item yielded by ``dataset`` the loss
    ``mse(reconstruction, act) + l1_factor * mean(|hidden_activations|)``
    is minimised with one optimizer step.

    Args:
        model: module whose forward returns ``(reconstruction, hidden_activations)``.
        opt: optimizer built over ``model``'s parameters.
        dataset: iterable of activation tensors; iterated once per epoch.
        epochs: number of passes over ``dataset``.
        l1_factor: weight of the sparsity term in the total loss.
        wnb: when True, start a Weights & Biases run in project ``'SAE'`` and log
            the three losses and the epoch index once per optimizer step.
    """
    if wnb:
        wandb.init(project='SAE')

    for epoch in tqdm(range(epochs)):
        for act in dataset:
            # The activation is both the input and the regression target; detach
            # it so backward() stops at the autoencoder instead of flowing into
            # (and freeing) whatever graph produced the activation.
            act = act.detach()
            reconstruction, hidden_activations = model(act)
            reconstruction_loss = F.mse_loss(reconstruction, act)
            sparsity_loss = hidden_activations.abs().mean()
            loss = reconstruction_loss + l1_factor * sparsity_loss

            opt.zero_grad()
            loss.backward()
            opt.step()

            if wnb:
                # One log call per step keeps all metrics on the same wandb step;
                # separate calls would each advance the step counter.
                wandb.log({
                    "reconstruction_loss": reconstruction_loss.item(),
                    "sparsity_loss": sparsity_loss.item(),
                    "loss": loss.item(),
                    "epoch": epoch,
                })
        # TODO: validation loss

# Launch training with the defaults of train(): 100 epochs, l1_factor=0.1, wandb logging on.
train(model, opt, dataset)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpeluche[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/100 [00:00<?, ?it/s]