In [44]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch

# Disable autograd globally for the cells that follow. The trailing `;` keeps the
# call's return value from being echoed as cell output.
torch.set_grad_enabled(False);

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
# CUDA availability is only queried when MPS is not available.
device = (
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)
print(f"Device: {device}")

Device: cuda


In [46]:
from datasets import load_dataset  
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Load GPT-2 small as a HookedTransformer on the device selected above.
model = HookedTransformer.from_pretrained("gpt2-small", device = device)

Loaded pretrained model gpt2-small into HookedTransformer


In [47]:
# ckpt_dir = "checkpoints/d543gzxy/final_204800000/"
# W&B: https://wandb.ai/shehper/gpt2-small-attn-4-sae/runs/pumu7rz3?nw=nwusershehper

ckpt_dir = "checkpoints/5eu9598y/final_409600000/"
# W&B: https://wandb.ai/shehper/gpt2-small-attn-5-sae/runs/s4om7ilc?nw=nwusershehper

# Load the SAE onto the device already chosen for the model instead of hard-coding
# "cuda", so the model and the SAE stay on the same device on MPS/CPU machines too.
sae = SAE.load_from_pretrained(path=ckpt_dir, device=device)
orig_architecure = sae.cfg.architecture  # (sic) spelling kept so existing references still resolve

# Replace the weight attributes with the values returned by the SAE's accessors.
# NOTE(review): presumably these materialise the effective weights for this
# architecture — confirm against the SAE implementation.
sae.W_dec = sae.get_W_dec()
sae.b_dec = sae.get_b_dec()
sae.W_enc = sae.get_W_enc()
sae.b_enc = sae.get_b_enc()

# Fold the decoder norms into the encoder: every W_dec row (norm taken over the
# last dim) is rescaled to unit norm, and the same per-row factor is multiplied
# into W_enc (broadcast over its last dim) and into b_enc.
# The norms are computed before W_dec is modified, so the order below matters.
dec_norms = sae.W_dec.norm(dim=-1)
sae.W_enc *= dec_norms
sae.b_enc *= dec_norms
sae.W_dec /= dec_norms[:, None]

In [48]:
from sae_lens import ActivationsStore

# Build an activation store from the SAE's configuration, backed by `model`.
# In the cells shown here it is only used as a source of token batches.
activations_store = ActivationsStore.from_sae(
    model = model,
    sae = sae,
    streaming=True,
    store_batch_size_prompts=8,
    n_batches_in_buffer=8,
)



In [49]:
from tqdm import tqdm 

def get_tokens(
    activations_store: ActivationsStore,
    n_batches_to_sample_from: int = 4096 * 6,
    n_prompts_to_select: int = 4096 * 6,
):
    """Pull token batches from the store and return a shuffled subset of prompts.

    Args:
        activations_store: Object providing `get_batch_tokens()`; each call is
            expected to return one batch of tokens with prompts along dim 0.
        n_batches_to_sample_from: Number of batches to draw from the store.
        n_prompts_to_select: Maximum number of prompts (rows) to return.

    Returns:
        The drawn batches concatenated along dim 0, shuffled along dim 0, and
        truncated to at most `n_prompts_to_select` rows.
    """
    shuffled_batches = []
    for _ in tqdm(range(n_batches_to_sample_from)):
        tokens = activations_store.get_batch_tokens()
        # Shuffle prompts within the batch (all rows are kept).
        row_order = torch.randperm(tokens.shape[0])
        shuffled_batches.append(tokens[row_order])

    pooled = torch.cat(shuffled_batches, dim=0)
    # Shuffle again across batches before taking the requested number of prompts.
    pooled = pooled[torch.randperm(pooled.shape[0])]
    return pooled[:n_prompts_to_select]

# Draw 128 batches from the store and keep 128 shuffled prompts; the evaluation
# loops below consume them as 16 batches of 8 prompts.
token_dataset = get_tokens(activations_store, 128, 128)

  0%|          | 0/128 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 128/128 [00:02<00:00, 47.63it/s]


In [50]:
# Estimate the typical activation norm at the SAE's hook point. For each batch the
# last two dims of the cached activation are flattened into one, the L2 norm is
# taken over that flattened dim, and the mean norm is accumulated; the running
# total is averaged over batches at the end.
act_scale = 0
num_batches = 16
for b in range(num_batches):
    # Consecutive slices of 8 prompts from the pre-sampled token_dataset.
    batch_tokens = token_dataset[b * 8: (b + 1) * 8].clone()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    act_scale += cache[sae.cfg.hook_name].flatten(start_dim=-2, end_dim=-1).norm(dim=-1).mean()
    # Drop the per-batch references before asking CUDA to release cached blocks.
    del batch_tokens, cache; torch.cuda.empty_cache()
act_scale /= num_batches

In [51]:
import numpy as np
# Factor that rescales activations so their mean norm (act_scale) becomes sqrt(768).
# NOTE(review): 768 is presumably the flattened activation width at the hook
# (n_heads * d_head for gpt2-small) — confirm rather than relying on the literal.
scaling_factor = np.sqrt(768)/act_scale
scaling_factor

tensor(3.6288, device='cuda:0')

In [52]:
from transformer_lens import utils
from functools import partial

# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    """Forward hook that replaces the hooked activation with `sae_out`.

    `activation` and `hook` are accepted to match the hook call signature but
    are not used.
    """
    return sae_out

def zero_abl_hook(activation, hook):
    """Forward hook that replaces the hooked activation with zeros of the same shape."""
    return torch.zeros_like(activation)

# Accumulate three losses per batch (summed here, averaged in a later cell):
#   orig    - the unmodified model
#   reconst - the hooked activation replaced by the SAE reconstruction
#   zero    - the hooked activation replaced by zeros
orig, reconst, zero = 0, 0, 0
for b in range(num_batches):
    # Consecutive slices of 8 prompts from the pre-sampled token_dataset.
    batch_tokens = token_dataset[b * 8: (b + 1) * 8].clone()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    # The SAE is fed activations multiplied by scaling_factor.
    activations = cache[sae.cfg.hook_name] * scaling_factor
    
    # print(activations.flatten(start_dim=-2, end_dim=-1).norm(dim=-1).mean()**2)
    # Only the hooked activation is needed; free the rest of the cache early.
    del cache; torch.cuda.empty_cache()
    
    encode_out, _ = sae.encode_fn(activations)
    # Undo the input scaling so the reconstruction is back in the model's own scale.
    sae_out = sae.decode_fn(encode_out) / scaling_factor

    del encode_out, activations; torch.cuda.empty_cache()
    
    orig += model(batch_tokens, return_type="loss").item()
    
    reconst += model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
        ).item()
    
    zero += model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
        ).item()

    del batch_tokens, sae_out; torch.cuda.empty_cache()

In [53]:
# Mean per-batch loss for: the clean model, the SAE reconstruction spliced in, zero ablation.
orig/num_batches, reconst/num_batches, zero/num_batches

(3.066639795899391, 3.083420529961586, 3.1721465438604355)

In [54]:
# Fraction of the zero-ablation loss increase that the SAE reconstruction recovers
# (1.0 = matches the clean model, 0.0 = no better than zero ablation).
(zero - reconst)/(zero - orig)

0.8409510824047891

In [55]:
from transformer_lens import utils
from functools import partial

# next we want to do a reconstruction test.
def single_head_reconstr_hook(activation, hook, sae_out, head_id):
    """Overwrite one head's slice of the hooked activation with the SAE output.

    The head is selected by indexing dim 2. `activation` is modified in place
    and the same tensor object is returned; `hook` is unused.
    """
    activation[:, :, head_id] = sae_out[:, :, head_id]
    return activation

def single_head_zero_abl_hook(activation, hook, head_id):
    """Zero out one head's slice (index `head_id` on dim 2) of the hooked activation.

    `activation` is modified in place and the same tensor object is returned;
    `hook` is unused.
    """
    activation[:, :, head_id, :] = 0
    return activation

# Per-head evaluation: for every head, measure the loss when only that head's slice
# of the hooked activation is replaced by the SAE reconstruction, and when only that
# head's slice is zero-ablated.
#
# The batch loop is outermost so that the cache run, the SAE encode/decode and the
# clean-model loss — none of which depend on the head — are computed once per batch
# instead of once per (head, batch) pair.
n_heads = model.cfg.n_heads # 12 for gpt2-small
orig = 0  # clean-model loss summed over batches (each batch counted once)
reconst_per_head = {i: 0 for i in range(n_heads)}
zero_per_head = {i: 0 for i in range(n_heads)}

for b in range(num_batches):
    # Consecutive slices of 8 prompts from the pre-sampled token_dataset.
    batch_tokens = token_dataset[b * 8: (b + 1) * 8].clone()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    # The SAE is fed activations multiplied by scaling_factor.
    activations = cache[sae.cfg.hook_name] * scaling_factor

    # Only the hooked activation is needed; free the rest of the cache early.
    del cache; torch.cuda.empty_cache()

    encode_out, _ = sae.encode_fn(activations)
    # Undo the input scaling so the reconstruction is back in the model's own scale.
    sae_out = sae.decode_fn(encode_out) / scaling_factor

    del encode_out, activations; torch.cuda.empty_cache()

    orig += model(batch_tokens, return_type="loss").item()

    # The hooks only read `sae_out`, so it can be reused for every head.
    for head_id in range(n_heads):
        reconst_per_head[head_id] += model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[
                (
                    sae.cfg.hook_name,
                    partial(single_head_reconstr_hook, sae_out=sae_out, head_id=head_id),
                )
            ],
            return_type="loss",
            ).item()

        zero_per_head[head_id] += model.run_with_hooks(
            batch_tokens,
            return_type="loss",
            fwd_hooks=[
                (
                    sae.cfg.hook_name,
                    partial(single_head_zero_abl_hook, head_id=head_id),
                )
            ],
            ).item()

    del batch_tokens, sae_out; torch.cuda.empty_cache()

# Convert the accumulated sums into per-batch means.
orig_loss = orig / num_batches
reconst_per_head = {k: v / num_batches for k, v in reconst_per_head.items()}
zero_per_head = {k: v / num_batches for k, v in zero_per_head.items()}

In [56]:
# Per-head CE-loss recovered: the share of that head's zero-ablation loss increase
# that is removed when the head's output is replaced by the SAE reconstruction.
# The two dicts are paired by head id rather than by iteration position, so the
# result does not depend on both dicts having the same insertion order.
ce_score_per_head = {}
for k, zero_ in zero_per_head.items():
    reconst_ = reconst_per_head[k]
    ce_score_per_head[k] = (zero_ - reconst_)/(zero_ - orig_loss)
    print(k, f"{ce_score_per_head[k]:.4f}")

0 0.8851
1 0.9931
2 0.7398
3 0.5608
4 0.6781
5 0.9070
6 0.8570
7 0.7490
8 0.8904
9 0.8608
10 0.7694
11 0.2328


In [57]:
# Ask PyTorch to release cached CUDA memory before the next evaluation pass.
torch.cuda.empty_cache()

In [58]:
# Switch the SAE to eval mode. Per the original note, this avoids an error about an
# expected dead-neuron mask.
sae.eval()

# L0 = number of SAE features with activation > 0 per token, computed over all
# features (`l0`) and separately for each head's block of features (`l0_heads`).
l0 = []
n_heads = model.cfg.n_heads
# Width of each head's contiguous block along the feature dim.
# NOTE(review): presumably d_sae / n_heads for this SAE — confirm against the config.
features_per_head = 2048
l0_heads = {i: [] for i in range(n_heads)}
for b in range(num_batches):
    # Consecutive slices of 8 prompts from the pre-sampled token_dataset.
    batch_tokens = token_dataset[b * 8: (b + 1) * 8].clone()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    activations = cache[sae.cfg.hook_name] * scaling_factor

    # Only the encoder output is needed for L0; the reconstruction is not decoded
    # here because it was never used.
    feature_acts, _ = sae.encode_fn(activations)

    # Save some room.
    del batch_tokens, activations, cache; torch.cuda.empty_cache()

    # Skip position 0 (treated as the BOS token) and count, per token, how many
    # features are active.
    l0_batch = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0_batch.mean().item())

    for head_id in range(n_heads):
        head_slice = slice(head_id * features_per_head, (head_id + 1) * features_per_head)
        l0_head_batch = (feature_acts[:, 1:, head_slice] > 0).float().sum(-1).detach()
        l0_heads[head_id].append(l0_head_batch)

    del feature_acts; torch.cuda.empty_cache()
    l0.append(l0_batch)

# Stack the per-batch results along the prompt dimension.
l0 = torch.cat(l0, dim=0)

average l0 276.4577331542969
average l0 270.0892028808594
average l0 270.93438720703125
average l0 271.114013671875
average l0 279.5993347167969
average l0 273.111572265625
average l0 286.69439697265625
average l0 294.6562805175781
average l0 266.6430969238281
average l0 268.17413330078125
average l0 299.94207763671875
average l0 293.4535827636719
average l0 277.634521484375
average l0 299.9217834472656
average l0 272.2970275878906
average l0 277.5822448730469


In [59]:
# Sanity check: one L0 value per (prompt, non-BOS position).
l0.shape

torch.Size([128, 1023])

In [60]:
# Merge each head's per-batch L0 tensors into a single tensor along dim 0.
# This overwrites the lists in `l0_heads`, so entries that are no longer lists are
# skipped, which keeps the cell safe to re-run.
for head_id in range(n_heads):
    per_batch_l0 = l0_heads[head_id]
    if isinstance(per_batch_l0, list):
        l0_heads[head_id] = torch.cat(per_batch_l0, dim=0)

In [61]:
# Report the mean L0 of each head's feature block over all prompts and positions.
for head_id in range(n_heads):
    print(f"head = {head_id}, l0: {l0_heads[head_id].mean().item():.4f}")

head = 0, l0: 33.9917
head = 1, l0: 13.2530
head = 2, l0: 23.3034
head = 3, l0: 18.0465
head = 4, l0: 13.1367
head = 5, l0: 48.7831
head = 6, l0: 8.2849
head = 7, l0: 9.8118
head = 8, l0: 23.8324
head = 9, l0: 41.0594
head = 10, l0: 32.0946
head = 11, l0: 14.2964


### Layer 4 result

In [43]:
# layer 4
# import plotly.graph_objects as go

# # Sample data
# x = [l0_heads[head_id].mean().item() for head_id in range(n_heads)]
# y = [ce_score_per_head[head_id] for head_id in range(n_heads)]
# z = [str(head_id) for head_id in range(n_heads)]

# # Create scatter plot with labels
# fig = go.Figure(data=go.Scatter(x=x, y=y, mode='markers+text', text=z, textposition='top center'))

# # Add title and show plot
# fig.update_layout(title='Scatter Plot with Labels using plotly.graph_objects')
# fig.show()

### Layer 5 result

In [62]:
import plotly.graph_objects as go

# Per-head trade-off between sparsity and fidelity:
#   x - mean L0 of the head's feature block
#   y - CE-loss recovered when only that head is replaced by the SAE reconstruction
#   z - head index, drawn as the point label
x = [l0_heads[head_id].mean().item() for head_id in range(n_heads)]
y = [ce_score_per_head[head_id] for head_id in range(n_heads)]
z = [str(head_id) for head_id in range(n_heads)]

# Scatter plot with one labelled marker per head.
fig = go.Figure(data=go.Scatter(x=x, y=y, mode='markers+text', text=z, textposition='top center'))

# A descriptive title and axis labels so the figure can be read on its own.
fig.update_layout(
    title='Per-head CE loss recovered vs. mean L0',
    xaxis_title='Mean L0 (active features per token)',
    yaxis_title='CE loss recovered',
)
fig.show()