# Experiment with a sparse autoencoder (SAE)
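Reference: [Towards Monosemanticity: Decomposing Language Models With Dictionary Learning](https://transformer-circuits.pub/2023/monosemantic-features) (Anthropic, 2023).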

In [5]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformer_lens
from tqdm.notebook import tqdm
import wandb
from functools import lru_cache

In [6]:
# Pick the fastest available backend: CUDA first, then Apple-silicon MPS, then CPU.
# (Checking only MPS made CUDA machines silently fall back to CPU.)
if t.cuda.is_available():
    device = t.device("cuda")
elif t.backends.mps.is_available():
    device = t.device("mps")
else:
    device = t.device("cpu")
# device = t.device("cpu")  # uncomment to force CPU

In [7]:
class CFG:
    """Experiment hyper-parameters, read as class attributes through the `cfg` instance."""

    sae_layer = 0            # layer index whose `mlp_out` activations the SAE is trained on
    n_in = 768               # width of the activations fed into the SAE
    n_hidden = 8 * n_in      # SAE dictionary size: an 8x expansion of n_in
    batch_size = 64          # DataLoader batch size
    max_context = 600        # NOTE(review): not read by any cell visible here — confirm it is still needed


cfg = CFG()
print(cfg.n_hidden)

6144


In [8]:
# GPT-2 small wrapped as a TransformerLens HookedTransformer, together with the sample
# training corpus returned by `load_sample_training_dataset` (rows have a 'text' field).
gpt2 = transformer_lens.HookedTransformer.from_pretrained("gpt2-small")
gpt2_dataset = gpt2.load_sample_training_dataset()
# Quick check of the activation-cache API:
# logits, activations = gpt2.run_with_cache("Hello World")

Loaded pretrained model gpt2-small into HookedTransformer


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [9]:
class Dataset(t.utils.data.Dataset):
    """Map-style dataset that computes GPT-2 MLP-output activations on the fly.

    Item `idx` is ``gpt2_dataset[idx]['text']`` tokenized and run through `gpt2`; the
    returned tensor is the cached ``mlp_out`` activation of layer ``cfg.sae_layer``
    (with the leading batch dimension of 1 that the cache returns).

    Items have different sequence lengths, and the DataLoader's default collate stacks
    items, so batching with ``batch_size > 1`` fails unless the sampled texts happen to
    tokenize to the same length.
    """

    # The base is named explicitly: `class Dataset(Dataset)` shadows the torch import and,
    # when the cell is re-run, would inherit from the previous run's class instead.

    def __init__(self, gpt2, gpt2_dataset):
        self.gpt2 = gpt2
        self.gpt2_dataset = gpt2_dataset

    def __len__(self):
        return len(self.gpt2_dataset)

    def __getitem__(self, idx):
        sentence = self.gpt2_dataset[idx]
        tokens = self.gpt2.to_tokens(sentence['text'], padding_side="right")
        hook_name = f"blocks.{cfg.sae_layer}.hook_mlp_out"
        # No gradients are needed. Cache only the single hook we return, and skip the
        # unembedding (return_type=None) because the logits are discarded anyway.
        with t.no_grad():
            _, activations = self.gpt2.run_with_cache(
                tokens, names_filter=hook_name, return_type=None
            )
        return activations['mlp_out', cfg.sae_layer]


dataset = Dataset(gpt2, gpt2_dataset)
train_dataloader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

In [10]:
# Sanity check: tokenize two corpus texts as one right-padded batch and inspect its shape.
tokens = gpt2.to_tokens(
    [gpt2_dataset[i]['text'] for i in (1002, 0)],
    padding_side="right",
)
tokens.shape

torch.Size([2, 1024])

In [11]:
# Earlier prototype of the activation-caching hook (same filter and hook as in
# `create_dataset` in the next cell); kept commented out for reference only.
# def hook_mlp_out(gpt2):
#     filter_mlp_out = lambda name: ("blocks.0.hook_mlp_out" == name)
#     gpt2.reset_hooks()
#     cache = {}
#     def forward_cache_hook(act, hook):
#         cache[hook.name] = act.detach()
#     gpt2.add_hook(filter_mlp_out, forward_cache_hook, "fwd")
#     return cache

# cache = hook_mlp_out(gpt2)
# gpt2(gpt2_dataset[:10]['text'])
# cache['blocks.0.hook_mlp_out'].shape

In [12]:
def create_dataset(batch_size=32, write_after=1000):
    """Run GPT-2 over `gpt2_dataset` and save the layer-``cfg.sae_layer`` MLP outputs to disk.

    A forward hook on ``blocks.{cfg.sae_layer}.hook_mlp_out`` captures each batch's
    activations. They are buffered in memory, and once at least `write_after` rows have
    accumulated the buffer is concatenated and saved as ``activations_{n}.pt``, where
    ``n`` is the number of dataset rows processed so far. Any remainder is saved as
    ``activations_final.pt``. The hook is always removed again, even on error or interrupt.

    Args:
        batch_size: number of texts per GPT-2 forward pass.
        write_after: minimum number of buffered rows before a shard is written.

    Note: each batch is padded to its own longest text, so ``t.cat`` over a shard raises
    if the batches in it end up with different sequence lengths.
    """
    hook_name = f"blocks.{cfg.sae_layer}.hook_mlp_out"
    cache = {}

    def forward_cache_hook(act, hook):
        # Store a detached copy so the buffered tensors hold no autograd graph.
        cache[hook.name] = act.detach()

    gpt2.reset_hooks()
    gpt2.add_hook(lambda name: name == hook_name, forward_cache_hook, "fwd")

    output = []
    rows_buffered = 0
    try:
        with t.no_grad():
            for start in tqdm(range(0, len(gpt2_dataset), batch_size)):
                # Slicing past the end simply yields the remaining rows.
                texts = gpt2_dataset[start:start + batch_size]['text']
                # return_type=None skips the unembedding: only the hooked activations are needed.
                gpt2(texts, return_type=None)
                batch_acts = cache[hook_name]
                output.append(batch_acts)
                rows_buffered += batch_acts.shape[0]

                if rows_buffered >= write_after:
                    path = f"activations_{start + batch_acts.shape[0]}.pt"
                    tqdm.write(f"writing {path}")
                    t.save(t.cat(output, dim=0), path)
                    output = []
                    rows_buffered = 0

        # Only write a final shard if something is left (t.cat([]) would raise).
        if output:
            t.save(t.cat(output, dim=0), "activations_final.pt")
    finally:
        # Remove the caching hook so later gpt2 calls are not silently intercepted.
        gpt2.reset_hooks()

# create_dataset()

In [13]:
import os

# Activation shards written by `create_dataset`. Match the exact naming pattern and sort:
# os.listdir order is arbitrary, so without sorting `files[:5]` would pick a different
# subset of shards on different machines or runs.
files = sorted(
    fname for fname in os.listdir('./')
    if fname.startswith('activations_') and fname.endswith('.pt')
)
if not files:
    raise FileNotFoundError(
        "no activations_*.pt files in the working directory; run create_dataset() first"
    )
# Only the first 5 shards are loaded, to keep memory bounded.
activation_tensors = t.cat([t.load(fname) for fname in files[:5]], dim=0)
print(activation_tensors.shape)  # (n_sequences, seq_len, d_model), e.g. (5040, 1024, 768)


class Dataset(t.utils.data.Dataset):
    """Map-style view over a precomputed activation tensor; item `idx` is row `idx`.

    NOTE(review): not used below — the DataLoader wraps `activation_tensors` in a
    TensorDataset instead. The base is named explicitly so re-running the cell does not
    make this class inherit from a previously defined `Dataset`.
    """

    def __init__(self, activation_tensors):
        self.activation_tensors = activation_tensors

    def __len__(self):
        return len(self.activation_tensors)

    def __getitem__(self, idx):
        return self.activation_tensors[idx]


train_dataloader = DataLoader(t.utils.data.TensorDataset(activation_tensors), batch_size=cfg.batch_size, shuffle=True)
# Peek at one batch. TensorDataset yields 1-tuples, so the activations are batch[0].
batch = next(iter(train_dataloader))
print(batch[0].shape)  # (cfg.batch_size, seq_len, d_model)

torch.Size([5040, 1024, 768])
torch.Size([64, 1024, 768])


## Model

In [14]:
class SAE(nn.Module):
    """Sparse autoencoder with a single ReLU hidden layer.

    forward(act) returns ``(reconstruction, hidden_activations)``, where
        hidden_activations = relu(W_e(act + b_a) + b_e)
        reconstruction     = W_d(hidden_activations) + b_d
    """

    def __init__(self, n_in, n_hidden):
        super().__init__()
        # Explicit biases: input shift (b_a), decoder/output bias (b_d), encoder bias (b_e).
        self.b_a = nn.Parameter(t.zeros(n_in))
        self.b_d = nn.Parameter(t.zeros(n_in))
        self.b_e = nn.Parameter(t.zeros(n_hidden))

        # Bias-free linear maps; the biases above are added by hand in forward().
        self.W_e = nn.Linear(n_in, n_hidden, bias=False)
        self.W_d = nn.Linear(n_hidden, n_in, bias=False)

    def forward(self, act):
        pre_activation = self.W_e(act + self.b_a) + self.b_e
        hidden_activations = F.relu(pre_activation)
        reconstruction = self.W_d(hidden_activations) + self.b_d
        return reconstruction, hidden_activations


## Train

In [15]:
model = SAE(n_in=cfg.n_in, n_hidden=cfg.n_hidden).to(device)
opt = t.optim.Adam(model.parameters(), lr=1e-3)


def _normalize_decoder_columns(model):
    """Rescale every feature's decoder direction (a column of ``model.W_d.weight``) to unit L2 norm."""
    with t.no_grad():
        # nn.Linear(n_hidden, n_in).weight has shape (n_in, n_hidden): column j decodes feature j.
        weight = model.W_d.weight
        weight.div_(weight.norm(dim=0, keepdim=True).clamp_min(1e-8))


def train(model, opt, dataloader, epochs=100, l1_factor=0.1, wnb=True, normalize_decoder=True):
    """Train an SAE to reconstruct activations under an L1 sparsity penalty.

    Per batch: ``loss = mse(reconstruction, acts) + l1_factor * mean(|hidden_activations|)``.

    Args:
        model: the SAE, called as ``model(acts) -> (reconstruction, hidden_activations)``.
            When `normalize_decoder` is True it must expose ``W_d`` (an ``nn.Linear``).
        opt: optimizer over ``model.parameters()``.
        dataloader: iterable of batches whose first element holds the activations
            (e.g. a DataLoader over a ``TensorDataset``); each batch is moved to ``device``.
        epochs: number of passes over `dataloader`.
        l1_factor: weight of the sparsity term.
        wnb: if True, open a Weights & Biases run in project ``'SAE'``, log all losses
            once per step, and close the run when training ends or is interrupted.
        normalize_decoder: if True, keep every decoder column at unit norm (before
            training and after each optimizer step). Without this constraint the L1 term
            can be shrunk by scaling hidden activations down and ``W_d`` up, so it stops
            encouraging sparsity.
    """
    if wnb:
        wandb.init(project='SAE')
    try:
        if normalize_decoder:
            _normalize_decoder_columns(model)
        for epoch in tqdm(range(epochs)):
            # leave=False keeps a finished inner bar from piling up for every epoch.
            for batch in tqdm(dataloader, leave=False):
                acts = batch[0].to(device)
                reconstruction, hidden_activations = model(acts)
                reconstruction_loss = F.mse_loss(reconstruction, acts)
                sparsity_loss = hidden_activations.abs().mean()
                loss = reconstruction_loss + l1_factor * sparsity_loss
                if wnb:
                    # A single call per step, so all metrics share the same wandb step.
                    wandb.log({
                        "reconstruction_loss": reconstruction_loss.item(),
                        "sparsity_loss": sparsity_loss.item(),
                        "loss": loss.item(),
                        "epoch": epoch,
                    })
                opt.zero_grad()
                loss.backward()
                opt.step()
                if normalize_decoder:
                    _normalize_decoder_columns(model)
            # TODO: validation loss
    finally:
        # Close the run even on KeyboardInterrupt so it is not left dangling.
        if wnb:
            wandb.finish()


train(model, opt, train_dataloader)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: Currently logged in as: [33mturbochardo[0m. Use [1m`wandb login --relogin`[0m to force relogin
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` 

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [16]:
# Persist only the learned parameters (the state_dict), not the pickled module object.
t.save(obj=model.state_dict(), f='model.pt')

In [29]:
# (fact, question) pairs about animals. The first two facts stop mid-sentence; entry 0
# is used as a next-token-prediction prompt in the next cell.
fact_question_dataset = [
    {"fact": fact, "question": question}
    for fact, question in (
        ("The heart of a blue whale is so big, a human could swim through its",
         "Is it true that a human could swim through the arteries of a blue whale's heart?"),
        ("Sloths can hold their breath for up to 40",
         "Can sloths really hold their breath for up to 40 minutes?"),
        ("A snail can sleep for up to three years.",
         "Can a snail sleep for up to three years?"),
        ("Penguins have knees.",
         "Do penguins really have knees?"),
        ("The male emperor penguin protects the eggs from the cold.",
         "Is it the male emperor penguin that protects the eggs from the cold?"),
        ("The largest land animal is the African elephant.",
         "Is the African elephant the largest land animal?"),
        ("Kangaroos use their tails for balance.",
         "Do kangaroos use their tails for balance?"),
        ("A group of flamingos is called a 'flamboyance.'",
         "Is a group of flamingos called a 'flamboyance'?"),
        ("Rabbits and hares are different species.",
         "Are rabbits and hares different species?"),
        ("Cows have four stomachs.",
         "Do cows have four stomachs?"),
    )
]




In [40]:
# Splice the trained SAE into GPT-2: the hook replaces the MLP output of layer
# `cfg.sae_layer` with its SAE reconstruction and records the SAE hidden activations in
# `cache`. The hook stays registered on `gpt2` until `gpt2.reset_hooks()` is called.
filter_mlp_out = lambda name: name == f"blocks.{cfg.sae_layer}.hook_mlp_out"
gpt2.reset_hooks()
cache = {}


def forward_cache_hook(act, hook):
    """Return the SAE reconstruction of `act` on `act`'s device, caching the hidden activations."""
    with t.no_grad():
        # The SAE lives on `device`, while GPT-2 may run elsewhere, so move in and back out.
        reconstruction, hidden_activations = model(act.to(device))
        cache[hook.name] = hidden_activations.detach()
        return reconstruction.to(act.device)


gpt2.add_hook(filter_mlp_out, forward_cache_hook, "fwd")

output = gpt2(fact_question_dataset[0]['fact'])

# Greedy next-token prediction for the prompt, with the SAE in the loop.
most_probable_token = t.argmax(output[0, -1]).item()
gpt2.to_single_str_token(most_probable_token)




' mouth'