# experiment with SAE
https://transformer-circuits.pub/2023/monosemantic-features

In [1]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformer_lens
from tqdm.notebook import tqdm
import wandb
from functools import lru_cache
from dataclasses import dataclass
from jaxtyping import Float, Int
import einops

if t.backends.mps.is_available():
    device = t.device("mps")
elif t.backends.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

In [2]:
@dataclass
class CFG:
    sae_layer: int = 0
    n_in: int = 3072
    n_hidden: int = n_in * 8
    batch_size: int = 64
    max_context: int = 600

cfg = CFG()
print(cfg.n_hidden)

24576


In [3]:
gpt2 = transformer_lens.HookedTransformer.from_pretrained("gpt2-small")
gpt2_dataset = gpt2.load_sample_training_dataset()

Loaded pretrained model gpt2-small into HookedTransformer


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [4]:
def create_dataset(layer=0, batch_size=32, write_after=1000):
    filter_mlp_out = lambda name: ("blocks.0.mlp.hook_post" == name)
    gpt2.reset_hooks()
    cache = {}
    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()
    gpt2.add_hook(filter_mlp_out, forward_cache_hook, "fwd")

    output = []
    with t.no_grad():
        for slice in tqdm(range(0, len(gpt2_dataset), batch_size)):
            print("looking at slice", slice)
            if slice % write_after < batch_size and slice > 0:
                print("writing to disk")
                t.save(t.cat(output, dim=0), f"new_activations_{slice}.pt")
                output = []

            if slice+batch_size >= len(gpt2_dataset):
                sliced_dataset = gpt2_dataset[slice:]['text']
            else:
                sliced_dataset = gpt2_dataset[slice:slice+batch_size]['text']

            gpt2(sliced_dataset)
            activations_for_batch = cache['blocks.0.mlp.hook_post']
            output.append(activations_for_batch)
    t.save(t.cat(output, dim=0), f"new_activations_final.pt")

# create_dataset(batch_size=64)

In [5]:
import os
files = [fname for fname in os.listdir('./') if 'new_activations' in fname]
activation_tensors = [t.load(fname, mmap=True) for fname in files[:3]] # only load the first 5 * 1024 activations because loading all of them all crashes the kernel


In [7]:
class Dataset(Dataset):
    def __init__(self, activation_tensors):
        self.activation_tensors = activation_tensors

    def __len__(self):
        return len(self.activation_tensors)

    def __getitem__(self, idx):
        return self.activation_tensors[idx]

concat = t.utils.data.ConcatDataset([Dataset(act) for act in activation_tensors]) 
train_dataloader = DataLoader(concat, batch_size=cfg.batch_size, shuffle=False)

for batch in train_dataloader:
    print(batch[0].shape) # should be (64, 1024, 768)
    break

torch.Size([1024, 3072])


## model

In [8]:
class SAE(nn.Module):
    def __init__(self, n_in, n_hidden):
        super(SAE, self).__init__()
        self.b_a = nn.Parameter(t.zeros(n_in))
        self.b_d = nn.Parameter(t.zeros(n_in))
        self.b_e = nn.Parameter(t.zeros(n_hidden))

        self.W_e = nn.Linear(n_in, n_hidden, bias=False)
        self.W_d = nn.Linear(n_hidden, n_in, bias=False)

    def forward(self, act):
        x = act - self.b_d
        x = self.W_e(x) + self.b_e
        x = F.relu(x)
        hidden_activations = x
        x = self.W_d(x) + self.b_d
        return x, hidden_activations


## train

In [10]:
model = SAE(n_in=cfg.n_in, n_hidden=cfg.n_hidden).to(device)
opt = t.optim.Adam(model.parameters(), lr=1e-3)
def train(model, opt, dataloader, epochs=100, l1_factor=0.1, wnb=True):
    if wnb:
        wandb.init(project='SAE')

    for epoch in tqdm(range(epochs)):
        for batch in tqdm(dataloader):
            batch = batch[0].to(device)
            logits, hidden_activations = model(batch)
            print((logits - batch).shape)
            reconstruction_loss = (logits - batch).pow(2).mean(-1)
            sparsity_loss = hidden_activations.abs().sum(-1)
            loss = (reconstruction_loss + l1_factor * sparsity_loss).mean(0).sum()
            if wnb:
                # wandb.log({"reconstruction_loss": reconstruction_loss.item()})
                # wandb.log({"sparsity_loss": sparsity_loss.item()})
                wandb.log({"loss": loss.item()})
            opt.zero_grad()
            loss.backward()
            opt.step()

train(model, opt, train_dataloader)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167812044408896, max=1.0…

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/46 [00:00<?, ?it/s]

torch.Size([1024, 3072])


RuntimeError: a Tensor with 1024 elements cannot be converted to Scalar

In [None]:
fact_question_dataset = [
    {
        "fact": "The heart of a blue whale is so big, a human could swim through its arteries.",
        "question": "Is it true that a human could swim through the arteries of a blue whale's heart?"
    },
    {
        "fact": "Sloths can hold their breath for up to 40 minutes.",
        "question": "Can sloths really hold their breath for up to 40 minutes?"
    },
    {
        "fact": "A snail can sleep for up to three years.",
        "question": "Can a snail sleep for up to three years?"
    },
    {
        "fact": "Penguins have knees.",
        "question": "Do penguins really have knees?"
    },
    {
        "fact": "The male emperor penguin protects the eggs from the cold.",
        "question": "Is it the male emperor penguin that protects the eggs from the cold?"
    },
    {
        "fact": "The largest land animal is the African elephant.",
        "question": "Is the African elephant the largest land animal?"
    },
    {
        "fact": "Kangaroos use their tails for balance.",
        "question": "Do kangaroos use their tails for balance?"
    },
    {
        "fact": "A group of flamingos is called a 'flamboyance.'",
        "question": "Is a group of flamingos called a 'flamboyance'?"
    },
    {
        "fact": "Rabbits and hares are different species.",
        "question": "Are rabbits and hares different species?"
    },
    {
        "fact": "Cows have four stomachs.",
        "question": "Do cows have four stomachs?"
    }
]

In [None]:
filter_mlp_out = lambda name: ("blocks.0.mlp.hook_post" == name)
gpt2.reset_hooks()
cache = {}
def forward_cache_hook(act, hook):
    with t.no_grad():
        logits, hidden_activations = model(act)
        cache[hook.name] = hidden_activations.detach()
        return logits

gpt2.add_hook(filter_mlp_out, forward_cache_hook, "fwd")

questions = [pair["question"] for pair in fact_question_dataset]
facts = [pair["fact"] for pair in fact_question_dataset]
all_sentences = facts + questions

gpt2(all_sentences)
activations = cache['blocks.0.mlp.hook_post']
print(activations.shape)

questions, facts = t.split(activations, 10, dim=0)
print(f"questions: {questions.shape}, facts: {facts.shape}")

max_neuron_for_fact = facts.amax(dim=(0, 1)) # (n_batch, n_words, n_neurons) 
min_neuron_for_question = questions.amin(dim=(0, 1)) # (n_batch, n_words, n_neurons) 

print(f"shapes: {max_neuron_for_fact.shape=}, {min_neuron_for_question.shape=}")

differences_by_neuron = (max_neuron_for_fact - min_neuron_for_question)/min_neuron_for_question
most_different_neuron = differences_by_neuron.argmax()

print(questions.mean(dim=(0, 1))[most_different_neuron])
print(facts.mean(dim=(0, 1))[most_different_neuron])

In [98]:
@t.no_grad()
def get_feature_probability(
    text: list[str], # n_batch
    model: transformer_lens.HookedTransformer,
    autoencoder: SAE,
) -> Float[t.Tensor, "instance n_hidden_ae"]:
    '''
    Returns the feature probabilities (i.e. fraction of time the feature is active) for each feature in the
    autoencoder, averaged over all `batch * seq` tokens.
    '''
    logits, cache = model.run_with_cache(text, names_filter=['blocks.0.mlp.hook_post']) 
    print(cache.keys())
    hidden_activations = cache['blocks.0.mlp.hook_post'] # n_batch, n_seq, d_model
    rearranged = einops.rearrange(hidden_activations, 'b s d -> (b s) d') # (n_batch * n_seq, d_model)
    print(f"Rearranged shape: {rearranged.shape=}")
    sae_activations = autoencoder(rearranged) # (n_batch * n_seq, d_model * 8)
    return sae_activations.mean(0)

# Get a batch of feature probabilities & average them (so we don't put strain on the GPU)
print(gpt2.cfg)

feature_probability = [
    get_feature_probability(gpt2_dataset[i:i+32]["text"], gpt2, model)
    for i in tqdm(range(0, 1000, 32))
]
# feature_probability = sum(feature_probability) / len(feature_probability)

# log_freq = (feature_probability + 1e-10).log10()

# # Visualise sparsities for each instance
# for i, lf in enumerate(log_freq):
#     hist(
#         lf,
#         title=f"Instance #{i+1}: Log Frequency of Features",
#         labels={"x": "log<sub>10</sub>(freq)"},
#         histnorm="percent",
#         template="ggplot2"
#     )

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='mps'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPT2LMHeadModel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_adjacent_pairs': False,
 'rotary_base': 10000,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': N

  0%|          | 0/32 [00:00<?, ?it/s]

dict_keys(['blocks.0.mlp.hook_post'])
Rearranged shape: rearranged.shape=torch.Size([32768, 3072])


RuntimeError: The size of tensor a (3072) must match the size of tensor b (768) at non-singleton dimension 1