# experiment with SAE
https://transformer-circuits.pub/2023/monosemantic-features

In [None]:
%pip list

In [None]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformer_lens
from tqdm.notebook import tqdm
import wandb
from functools import lru_cache

In [None]:
device = t.device("mps" if t.backends.mps.is_available() else "cpu")
# device = t.device("cpu")
# print(t.backends.mps.is_available())

In [None]:
class CFG():
    sae_layer = 0
    n_in = 3072
    n_hidden = n_in * 8
    batch_size = 16
    max_context = 600

cfg = CFG()
print(cfg.n_hidden)

In [None]:
gpt2 = transformer_lens.HookedTransformer.from_pretrained("gpt2-small")
# logits, activations = gpt2.run_with_cache("Hello World")
gpt2_dataset = gpt2.load_sample_training_dataset()

In [None]:
logits, cache = gpt2.run_with_cache("a black cat")

In [None]:
cache

In [None]:
class Dataset(Dataset):
    def __init__(self, gpt2, gpt2_dataset):
        self.gpt2 = gpt2
        self.gpt2_dataset = gpt2_dataset

    def __len__(self):
        return len(self.gpt2_dataset)

    # @lru_cache(maxsize=None)
    def __getitem__(self, idx):
        sentence = self.gpt2_dataset[idx]
        tokens = self.gpt2.to_tokens(sentence['text'], padding_side="right")
        _, activations = self.gpt2.run_with_cache(tokens)
        return activations['mlp_out', cfg.sae_layer]
    
dataset = Dataset(gpt2, gpt2_dataset)
train_dataloader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

In [None]:
tokens = gpt2.to_tokens([gpt2_dataset[1002]['text'], gpt2_dataset[0]['text']], padding_side="right")
tokens.shape

In [None]:
# def hook_mlp_out(gpt2):
#     filter_mlp_out = lambda name: ("blocks.0.hook_mlp_out" == name)
#     gpt2.reset_hooks()
#     cache = {}
#     def forward_cache_hook(act, hook):
#         cache[hook.name] = act.detach()
#     gpt2.add_hook(filter_mlp_out, forward_cache_hook, "fwd")
#     return cache

# cache = hook_mlp_out(gpt2)
# gpt2(gpt2_dataset[:10]['text'])
# cache['blocks.0.hook_mlp_out'].shape

In [None]:
def create_dataset(batch_size=32, write_after=1000):
    filter_mlp_out = lambda name: ("blocks.0.mlp.hook_post" == name)
    gpt2.reset_hooks()
    cache = {}
    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()
        raise Exception("end execution")
    gpt2.add_hook(filter_mlp_out, forward_cache_hook, "fwd")

    output = []
    with t.no_grad():
        for slice in tqdm(range(0, len(gpt2_dataset), batch_size)):
            print("looking at slice", slice)
            if slice % write_after < batch_size and slice > 0:
                print("writing to disk")
                t.save(t.cat(output, dim=0), f"activations_{slice}.pt")
                output = []

            if slice+batch_size >= len(gpt2_dataset):
                sliced_dataset = gpt2_dataset[slice:]['text']
            else:
                sliced_dataset = gpt2_dataset[slice:slice+batch_size]['text']

            try:
                gpt2(sliced_dataset)
            except Exception:
                pass

            activations_for_batch = cache['blocks.0.mlp.hook_post']
            output.append(activations_for_batch)
    # print(f"{output=}")
    t.save(t.cat(output, dim=0), f"activations_final.pt")

# create_dataset(write_after=100)

In [None]:
%pip install ipywidgets

In [None]:
import os
files = [fname for fname in os.listdir('./') if 'activations' in fname]
activation_tensors = [t.load(fname) for fname in files[:5]]
activation_tensors = t.cat(activation_tensors, dim=0)
print(activation_tensors.shape) # should be (10000, 1024, 768)

class Dataset(Dataset):
    def __init__(self, activation_tensors):
        self.activation_tensors = activation_tensors

    def __len__(self):
        return len(self.activation_tensors)

    def __getitem__(self, idx):
        return self.activation_tensors[idx]
    
train_dataloader = DataLoader(t.utils.data.TensorDataset(activation_tensors), batch_size=cfg.batch_size, shuffle=True)
for batch in train_dataloader:
    print(batch[0].shape) # should be (1, 1024, 768)
    break

## model

In [None]:
class SAE(nn.Module):
    def __init__(self, n_in, n_hidden):
        super(SAE, self).__init__()
        self.b_a = nn.Parameter(t.zeros(n_in))
        self.b_d = nn.Parameter(t.zeros(n_in))
        self.b_e = nn.Parameter(t.zeros(n_hidden))

        self.W_e = nn.Linear(n_in, n_hidden, bias=False)
        self.W_d = nn.Linear(n_hidden, n_in, bias=False)

    def forward(self, act):
        x = act + self.b_a
        x = self.W_e(x) + self.b_e
        x = F.relu(x)
        hidden_activations = x
        x = self.W_d(x) + self.b_d
        return x, hidden_activations


## train

In [None]:
model = SAE(n_in=cfg.n_in, n_hidden=cfg.n_hidden).to(device)
opt = t.optim.Adam(model.parameters(), lr=1e-3)
def train(model, opt, dataloader, epochs=100, l1_factor=0.1, wnb=True):
    if wnb:
        wandb.init(project='SAE')

    for epoch in tqdm(range(epochs)):
        for batch in tqdm(dataloader):
            batch = batch[0].to(device)
            logits, hidden_activations = model(batch)
            reconstruction_loss = F.mse_loss(logits, batch)
            sparsity_loss = hidden_activations.abs().mean()
            loss = reconstruction_loss + l1_factor * sparsity_loss
            if wnb:
                wandb.log({"reconstruction_loss": reconstruction_loss.item()})
                wandb.log({"sparsity_loss": sparsity_loss.item()})
                wandb.log({"loss": loss.item()})
            opt.zero_grad()
            loss.backward()
            opt.step()
        # TODO: validation loss
        # if wnb:
            # wandb.log({"epoch": epoch})

train(model, opt, train_dataloader)

In [None]:
t.save(model.state_dict(), 'model_intermediate.pt')

In [None]:
"a".upper()

In [None]:
input = "The heart of a blue whale is so" #.upper()

for i in range(10):
    logits = gpt2(input)
    most_probable_token = logits[0, -1].argmax().item()
    output = gpt2.to_single_str_token(most_probable_token)
    input += output
    print(input)

# logits = gpt2("THE HEART OF A BLUE WHALE IS A BL")
# print(logits.shape)
# most_probable_token = logits[0, -1].argmax().item()
# gpt2.to_single_str_token(most_probable_token)

In [None]:
fact_question_dataset = [
    {
        "fact": "The heart of a blue whale is so big, a human could swim through the arteries of its heart.",
        "question": "Is it true that a human could swim through the arteries of a blue whale's heart?"
    },
    {
        "fact": "Sloths can hold their breath for up to 40",
        "question": "Can sloths really hold their breath for up to 40 minutes?"
    },
    {
        "fact": "A snail can sleep for up to three years.",
        "question": "Can a snail sleep for up to three years?"
    },
    {
        "fact": "Penguins have knees.",
        "question": "Do penguins really have knees?"
    },
    {
        "fact": "The male emperor penguin protects the eggs from the cold.",
        "question": "Is it the male emperor penguin that protects the eggs from the cold?"
    },
    {
        "fact": "The largest land animal is the African elephant.",
        "question": "Is the African elephant the largest land animal?"
    },
    {
        "fact": "Kangaroos use their tails for balance.",
        "question": "Do kangaroos use their tails for balance?"
    },
    {
        "fact": "A group of flamingos is called a 'flamboyance.'",
        "question": "Is a group of flamingos called a 'flamboyance'?"
    },
    {
        "fact": "Rabbits and hares are different species.",
        "question": "Are rabbits and hares different species?"
    },
    {
        "fact": "Cows have four stomachs.",
        "question": "Do cows have four stomachs?"
    }
]


In [None]:
filter_mlp_out = lambda name: ("blocks.0.mlp.hook_post" == name)
gpt2.reset_hooks()
cache = {}
def forward_cache_hook(act, hook):
    with t.no_grad():
        logits, hidden_activations = model(act)
        cache[hook.name] = hidden_activations.detach()
        return logits

gpt2.add_hook(filter_mlp_out, forward_cache_hook, "fwd")

In [None]:
lower_caches = []
upper_caches = []
for pair in fact_question_dataset:
    lower = pair["fact"]
    upper = lower.upper()
    gpt2(lower)
    lower_caches.append(cache["blocks.0.mlp.hook_post"])
    gpt2(upper)
    upper_caches.append(cache["blocks.0.mlp.hook_post"])

In [None]:
lower_caches[0].shape

In [None]:
max_lower = t.stack([x.amax(dim=(0,1)) for x in lower_caches]).amax(0)
print(max_lower)
min_upper = t.stack([x.amin(dim=(0,1)) for x in upper_caches]).amin(0)
print(min_upper.amax())
neuron_idx = (min_upper - max_lower).argmax()
print(min_upper[neuron_idx])
print(max_lower[neuron_idx]) 

In [None]:
a = t.rand(3,5)
print(a)
print(a.topk(2, dim=1))

In [None]:
%pip install matplotlib

In [60]:
import matplotlib.pyplot as plt

frequency_neuron_counter = {}
flat_count = []
for l in lower_caches:
    v, i= l.topk(100, dim=2)
    flat_count.append(i.view(-1))

bincount = t.cat(flat_count, dim=0).bincount().cpu()
print(bincount.topk(100))



torch.return_types.topk(
values=tensor([111, 110, 110, 110, 106, 105, 104, 102, 102, 102, 102, 102, 102, 102,
        102, 102, 102, 102, 102, 102, 102, 102, 102, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
        100, 100, 100, 100, 100,  98,  97,  97,  97,  97,  94,  92,  91,  88,
         71,  63]),
indices=tensor([15663,  1391, 17581,  7686, 23529,  6271,   347,     8,     3,    11,
         4037,     2,     1,     6,    10,     0,    12,     4,    22,    13,
            9,     5,     7,    59,    29,    61,    57,    56,    27,    55,
           53,    25,    51,    49,    48,    23,    47,    45,    21,    43,
           44,    46,    20,    41,    24,    50,    52,    26,    54,    42,
    

In [59]:
frequency_neuron_counter = {}
flat_count = []
for l in upper_caches:
    v, i= l.topk(100, dim=2)
    flat_count.append(i.view(-1))
bincount = t.cat(flat_count, dim=0).bincount().cpu()
print(bincount.topk(100))

torch.return_types.topk(
values=tensor([166, 163, 161, 161, 160, 160, 160, 160, 160, 160, 160, 160, 160, 160,
        160, 160, 160, 159, 159, 159, 159, 159, 159, 159, 159, 159, 159, 159,
        159, 159, 159, 159, 159, 159, 159, 158, 158, 158, 158, 158, 158, 158,
        158, 157, 157, 157, 157, 157, 157, 157, 157, 157, 157, 157, 157, 157,
        157, 157, 157, 157, 157, 157, 157, 157, 157, 157, 157, 156, 155, 155,
        155, 155, 155, 155, 154, 154, 154, 154, 154, 154, 154, 154, 154, 154,
        154, 152, 151, 151, 150, 150, 149, 148, 147, 145, 140, 134, 129, 124,
        116,  95]),
indices=tensor([15663, 17581,  1391,  7686,    12,     5,     9,     6,     4,     2,
           10,     3,     8,    11,     7,     1,     0,    28,    27,    25,
           13,    24,    23,    26,    14,    21, 23529,    22,    19,    18,
           17,    29,    16,    15,    20,    30,    37,    36,    35,    34,
           33,    32,    31,    60,    59,    57,    56,    55,    53,    54,
    

In [None]:





output = gpt2(fact_question_dataset[0]['fact']) 

most_probable_token = t.argmax(output[0, -1]).item()
gpt2.to_single_str_token(most_probable_token)


