In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch

torch.set_grad_enabled(False);

In [2]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


In [3]:
from datasets import load_dataset  
from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2-small", device = device)

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
ckpt_dir = "checkpoints/d543gzxy/final_204800000/"
# W&B: https://wandb.ai/shehper/gpt2-small-attn-4-sae/runs/pumu7rz3?nw=nwusershehper

# ckpt_dir = "checkpoints/5eu9598y/final_409600000/"
# # W&B: https://wandb.ai/shehper/gpt2-small-attn-5-sae/runs/s4om7ilc?nw=nwusershehper

device = "cuda"
sae = SAE.load_from_pretrained(path=ckpt_dir,
                                    device=device)
orig_architecure = sae.cfg.architecture

sae.W_dec = sae.get_W_dec()
sae.b_dec = sae.get_b_dec()
sae.W_enc = sae.get_W_enc()
sae.b_enc = sae.get_b_enc()

dec_norms = sae.W_dec.norm(dim=-1)
sae.W_enc *= dec_norms
sae.b_enc *= dec_norms
sae.W_dec /= dec_norms[:, None]

In [5]:
from sae_lens import ActivationsStore

activations_store = ActivationsStore.from_sae(
    model = model,
    sae = sae,
    streaming=True,
    store_batch_size_prompts=8,
    n_batches_in_buffer=8,
)





In [6]:
from tqdm import tqdm 

def get_tokens(
    activations_store: ActivationsStore,
    n_batches_to_sample_from: int = 4096 * 6,
    n_prompts_to_select: int = 4096 * 6,
):
    all_tokens_list = []
    pbar = tqdm(range(n_batches_to_sample_from))
    for _ in pbar:
        batch_tokens = activations_store.get_batch_tokens()
        batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][
            : batch_tokens.shape[0]
        ]
        all_tokens_list.append(batch_tokens)

    all_tokens = torch.cat(all_tokens_list, dim=0)
    all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]
    return all_tokens[:n_prompts_to_select]

# 1000 prompts is plenty for a demo.
token_dataset = get_tokens(activations_store, 128, 128)

  0%|          | 0/128 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 128/128 [00:02<00:00, 43.64it/s]


In [7]:
act_scale = 0
num_batches = 16
for b in range(num_batches):
    # activation store can give us tokens.
    batch_tokens = token_dataset[b * 8: (b + 1) * 8].clone()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    act_scale += cache[sae.cfg.hook_name].flatten(start_dim=-2, end_dim=-1).norm(dim=-1).mean()
    del batch_tokens, cache; torch.cuda.empty_cache()
act_scale /= num_batches

In [8]:
import numpy as np
scaling_factor = np.sqrt(768)/act_scale
scaling_factor

tensor(3.8212, device='cuda:0')

In [9]:
from transformer_lens import utils
from functools import partial

# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    return sae_out

def zero_abl_hook(activation, hook):
    return torch.zeros_like(activation)

orig, reconst, zero = 0, 0, 0
for b in range(num_batches):
    # activation store can give us tokens.
    batch_tokens = token_dataset[b * 8: (b + 1) * 8].clone()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    activations = cache[sae.cfg.hook_name] * scaling_factor
    
    # print(activations.flatten(start_dim=-2, end_dim=-1).norm(dim=-1).mean()**2)
    del cache; torch.cuda.empty_cache()
    
    encode_out, _ = sae.encode_fn(activations)
    sae_out = sae.decode_fn(encode_out) / scaling_factor

    del encode_out, activations; torch.cuda.empty_cache()
    
    orig += model(batch_tokens, return_type="loss").item()
    
    reconst += model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
        ).item()
    
    zero += model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
        ).item()

    del batch_tokens, sae_out; torch.cuda.empty_cache()

In [18]:
orig/num_batches, reconst/num_batches, zero/num_batches

(3.0895193219184875, 3.0992945432662964, 3.1744395196437836)

In [None]:
# TODO: perhaps the next thing to do is to check CE loss score by splicing in individual heads.  