# Evaluating your SAE

## Set Up

In [1]:
import os
import sys
import torch
import json
import plotly.express as px
from transformer_lens import utils
from datasets import load_dataset
from typing import Dict
from pathlib import Path

from functools import partial

sys.path.append("..")

from sae.language.utils import LMSparseAutoencoderSessionloader
from sae.language.sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

torch.set_grad_enabled(False)

  from .autonotebook import tqdm as notebook_tqdm
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]


<torch.autograd.grad_mode.set_grad_enabled at 0x22ebdce2e90>

# Load your Autoencoder



In [2]:
# Start by downloading them from huggingface
# from huggingface_hub import hf_hub_download

# REPO_ID = "jbloom/GPT2-Small-SAEs"


# layer = 2  # any layer from 0 - 11 works here
# FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"

# # this is great because if you've already downloaded the SAE it won't download it twice!
# path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, cache_dir="F:/huggingfacecache/hub")


path = r"F:\ViT-Prisma_fork\data\sae_checkpoints\d581udr4\final_sae_group_gpt2-small_blocks.2.hook_resid_pre_49152.pt"

In [3]:
# We can then load the SAE, dataset and model using the session loader
print(path)


model, sparse_autoencoders, activation_store = (
   LMSparseAutoencoderSessionloader.load_session_from_pretrained(path=path)
)

F:\ViT-Prisma_fork\data\sae_checkpoints\d581udr4\final_sae_group_gpt2-small_blocks.2.hook_resid_pre_49152.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Dataset is not tokenized! Updating config.
Run name: 49152-L1-8e-05-LR-0.0004-Tokens-3.000e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 73242
Total wandb updates: 732
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 2621.44
Using Ghost Grads.
We will reset the sparsity calculation 73 times.
Number tokens in sparsity calculation window: 4.10e+06
Run name: 49152-L1-8e-05-LR-0.0004-Tokens-3.000e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 73242
Total wandb updates: 732
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 2621.44
Using Ghost Grads.
We will reset the sparsity calculation 73 times.
Number tokens in sparsity calculation window: 4.10e+06


In [6]:
for i, sae in enumerate(sparse_autoencoders):
    hyp = sae.cfg
    print(
        f"{i}: Layer {hyp.hook_point_layer}, p_norm {hyp.lp_norm}, alpha {hyp.l1_coefficient}"
    )

0: Layer 2, p_norm 1, alpha 8e-05


In [7]:
# pick which sae you wnat to evaluate. Default is 0
sparse_autoencoder = sparse_autoencoders.autoencoders[0]

## Test the Autoencoder

### L0 Test and Reconstruction Test

In [8]:
sparse_autoencoder.eval()  # prevents error if we're expecting a dead neuron mask for who grads
with torch.no_grad():
    batch_tokens = activation_store.get_batch_tokens()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
        cache[sparse_autoencoder.cfg.hook_point]
    )
    del cache

    # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

average l0 28.237205505371094


In [9]:
# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    return sae_out


def zero_abl_hook(activation, hook):
    return torch.zeros_like(activation)


print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                utils.get_act_name("resid_pre", 10),
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(utils.get_act_name("resid_pre", 10), zero_abl_hook)],
    ).item(),
)

Orig 3.5783743858337402
reconstr 6.448983192443848
Zero 11.759973526000977


## Specific Capability Test

Validating model performance on specific tasks when using the reconstructed activation is quite important when studying specific tasks.

In [10]:
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
    cache[sparse_autoencoder.cfg.hook_point]
)


def reconstr_hook(activations, hook, sae_out):
    return sae_out


def zero_abl_hook(mlp_out, hook):
    return torch.zeros_like(mlp_out)


hook_point = sparse_autoencoder.cfg.hook_point

print("Orig", model(tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        tokens,
        fwd_hooks=[
            (
                hook_point,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(hook_point, zero_abl_hook)],
    ).item(),
)


with model.hooks(
    fwd_hooks=[
        (
            hook_point,
            partial(reconstr_hook, sae_out=sae_out),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.19 Prob: 69.93% Token: | Mary|
Top 1th token. Logit: 15.82 Prob:  6.49% Token: | them|
Top 2th token. Logit: 15.48 Prob:  4.66% Token: | the|
Top 3th token. Logit: 14.93 Prob:  2.66% Token: | his|
Top 4th token. Logit: 14.86 Prob:  2.49% Token: | John|
Top 5th token. Logit: 14.12 Prob:  1.19% Token: | her|
Top 6th token. Logit: 13.99 Prob:  1.04% Token: | their|
Top 7th token. Logit: 13.70 Prob:  0.78% Token: | a|
Top 8th token. Logit: 13.53 Prob:  0.66% Token: | him|
Top 9th token. Logit: 13.39 Prob:  0.57% Token: | Mrs|


Orig 3.979093551635742
reconstr 4.050536155700684
Zero 7.470455169677734
Tokenized prompt: ['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' John', ' gave', ' the', ' bag', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.21 Prob: 75.11% Token: | Mary|
Top 1th token. Logit: 15.30 Prob:  4.10% Token: | the|
Top 2th token. Logit: 15.23 Prob:  3.83% Token: | them|
Top 3th token. Logit: 14.79 Prob:  2.48% Token: | his|
Top 4th token. Logit: 14.36 Prob:  1.61% Token: | John|
Top 5th token. Logit: 14.19 Prob:  1.36% Token: | her|
Top 6th token. Logit: 14.03 Prob:  1.15% Token: | their|
Top 7th token. Logit: 13.68 Prob:  0.81% Token: | a|
Top 8th token. Logit: 13.55 Prob:  0.71% Token: | Mrs|
Top 9th token. Logit: 12.91 Prob:  0.38% Token: | him|


# Generating Feature Interfaces

In [11]:
vals, inds = torch.topk(feature_acts[0, -1].detach().cpu(), 10)
px.bar(x=[str(i) for i in inds], y=vals).show()

In [14]:
vocab_dict = model.tokenizer.vocab
vocab_dict = {
    v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items()
}

vocab_dict_filepath = Path(os.getcwd()) / "vocab_dict.json"
if not vocab_dict_filepath.exists():
    with open(vocab_dict_filepath, "w") as f:
        json.dump(vocab_dict, f)


os.environ["TOKENIZERS_PARALLELISM"] = "false"
data = load_dataset(
    "NeelNanda/c4-code-20k", split="train"
)  # currently use this dataset to avoid deal with tokenization while streaming
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]


# Currently, don't think much more time can be squeezed out of it. Maybe the best saving would be to
# make the entire sequence indexing parallelized, but that's possibly not worth it right now.

max_batch_size = 512
total_batch_size = 4096 * 5
feature_idx = list(inds.flatten().cpu().numpy())
# max_batch_size = 512
# total_batch_size = 16384
# feature_idx = list(range(1000))

tokens = all_tokens[:total_batch_size]

feature_data: Dict[int, FeatureData] = get_feature_data(
    encoder=sparse_autoencoder,
    # encoder_B=sparse_autoencoder,
    model=model,
    hook_point=sparse_autoencoder.cfg.hook_point,
    hook_point_layer=sparse_autoencoder.cfg.hook_point_layer,
    hook_point_head_index=0,
    tokens=tokens,
    feature_idx=feature_idx,
    max_batch_size=max_batch_size,
    left_hand_k=3,
    buffer=(5, 5),
    n_groups=10,
    first_group_size=20,
    other_groups_size=5,
    verbose=True,
)


for test_idx in list(inds.flatten().cpu().numpy()):
    html_str = feature_data[test_idx].get_all_html()
    with open(f"data_{test_idx:04}.html", "w") as f:
        f.write(html_str)

Storing model activations:   0%|          | 0/40 [00:00<?, ?it/s]


RuntimeError: output with shape [] doesn't match the broadcast shape [10]

This will produce a number of html files which each contain a dashboard showing feature activation on the sample data. It currently doesn't process that much data so it isn't that useful. 