# SAE-Lens eval repro

The original demo notebook was broken; this is a fixed version.

### Preamble

In [None]:
import json 
import os
import numpy as np
import torch
import plotly.express as px

from datasets import load_dataset
from functools import partial
from pathlib import Path
from transformer_lens import utils
from typing import Dict

from huggingface_hub import hf_hub_download
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
# from sae_lens.analysis.visualizer.data_fns import get_feature_data, FeatureData
from sae_vis.data_fetching_fns import get_feature_data, FeatureData

In [None]:
# Adapted from https://github.com/jbloomAus/SAELens/blob/main/tutorials/evaluating_your_sae.ipynb

# Device preference: Apple MPS first, then CUDA, otherwise CPU.
device = (
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

# Everything below is inference only, so switch autograd off globally.
torch.set_grad_enabled(False)


In [None]:
# Fetch the SAE checkpoint for residual-stream layer `layer` from the Hugging Face Hub.
# Layer 0 is the one with the Magikarps.
REPO_ID = "jbloom/GPT2-Small-SAEs"
layer = 0  # any layer from 0 to 11 works here
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"

path = hf_hub_download(filename=FILENAME, repo_id=REPO_ID)

In [None]:
# Restore the language model, the SAE group and an activation store from the
# downloaded checkpoint in one call to the session loader.
model, sparse_autoencoders, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path=path
)

In [None]:
# Print one summary line per loaded SAE: its index, hook layer, Lp norm and L1 coefficient.
for idx, sae in enumerate(sparse_autoencoders):
    sae_cfg = sae.cfg
    summary = f"{idx}: Layer {sae_cfg.hook_point_layer}, p_norm {sae_cfg.lp_norm}, alpha {sae_cfg.l1_coefficient}"
    print(summary)

In [None]:
# Pick which SAE to evaluate (index into the loaded group; default is the first, 0).
sparse_autoencoder = [*sparse_autoencoders][0]

In [None]:
# Sanity-check the SAE on one batch from the activation store: encode the
# activations at its hook point and measure how many features fire per token.

# Per the original note, eval() prevents an error if a dead-neuron mask for grads is expected.
sparse_autoencoder.eval()
with torch.no_grad():
    batch_tokens = activation_store.get_batch_tokens()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    (
        sae_out,
        feature_acts,
        loss,
        mse_loss,
        l1_loss,
        _,
    ) = sparse_autoencoder(cache[sparse_autoencoder.cfg.hook_point])
    # The cached activations are no longer needed once the SAE has consumed them.
    del cache

    # Skip position 0 (BOS), count active features per token, then average
    # across batch and position.
    l0 = feature_acts[:, 1:].gt(0).float().sum(dim=-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

In [None]:
# Reconstruction test: compare the LM loss with (a) the original activations,
# (b) the SAE reconstruction spliced in, and (c) the activations zero-ablated.
def reconstr_hook(activation, hook, sae_out):
    """Forward hook: replace `activation` with the precomputed SAE output `sae_out`."""
    return sae_out


def zero_abl_hook(activation, hook):
    """Forward hook: zero-ablate `activation` (zeros with the same shape, dtype and device)."""
    return torch.zeros_like(activation)


# Interventions must target the hook the SAE was trained on and that produced
# `sae_out`. A hard-coded layer would splice this SAE's output into the wrong
# layer, and the matching shapes would hide the mistake.
hook_point = sparse_autoencoder.cfg.hook_point

print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                hook_point,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(hook_point, zero_abl_hook)],
    ).item(),
)

In [None]:
# Specific capability test: does the model still predict `example_answer` for
# `example_prompt` when the SAE reconstruction replaces the activations at the
# SAE's hook point?
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

# The hook the SAE reads from; every intervention below targets the same point.
hook_point = sparse_autoencoder.cfg.hook_point

_, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(cache[hook_point])
# Only the SAE input was needed from the cache; free it, as the batch test does.
del cache


def reconstr_hook(activations, hook, sae_out):
    """Forward hook: replace the hooked activations with the precomputed `sae_out`."""
    return sae_out


def zero_abl_hook(activations, hook):
    """Forward hook: zero-ablate the hooked activations (same shape, dtype and device)."""
    return torch.zeros_like(activations)


print("Orig", model(tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        tokens,
        fwd_hooks=[
            (
                hook_point,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(hook_point, zero_abl_hook)],
    ).item(),
)

# Re-run the prompt test with the SAE reconstruction patched in for the whole call.
with model.hooks(
    fwd_hooks=[
        (
            hook_point,
            partial(reconstr_hook, sae_out=sae_out),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

In [None]:
# Generating feature interfaces: the 10 most active SAE features at the last
# position of the example prompt (batch element 0).
vals, inds = torch.topk(feature_acts[0, -1].detach().cpu(), 10)
# .tolist() yields plain ints, so the labels read "1234" rather than "tensor(1234)";
# plotly gets a numpy array for y, matching the histogram cell above.
px.bar(x=[str(i) for i in inds.tolist()], y=vals.numpy()).show()

In [None]:
from sae_vis.data_config_classes import (
    # ActsHistogramConfig,
    # Column,
    # FeatureTablesConfig,
    SaeVisConfig)
import pickle

In [None]:
import hashlib

# Build sae_vis feature-centric dashboards for the top features selected above.
DASHBOARD_FOLDER = 'dashboards'
# The HTML files are written into this folder, so make sure it exists first.
Path(DASHBOARD_FOLDER).mkdir(parents=True, exist_ok=True)


# Token-id -> display-string map: "Ġ" becomes a space and newlines are escaped.
vocab_dict = model.tokenizer.vocab
vocab_dict = {
    v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items()
}

vocab_dict_filepath = Path(os.getcwd()) / "vocab_dict.json"
if not vocab_dict_filepath.exists():
    with open(vocab_dict_filepath, "w", encoding="utf-8") as f:
        json.dump(vocab_dict, f)


os.environ["TOKENIZERS_PARALLELISM"] = "false"
data = load_dataset(
    "NeelNanda/c4-code-20k", split="train"
)  # currently use this dataset to avoid dealing with tokenization while streaming
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]


# Currently, don't think much more time can be squeezed out of it. Maybe the best saving would be to
# make the entire sequence indexing parallelized, but that's possibly not worth it right now.

max_batch_size = 512
total_batch_size = 4096 * 5
# Plain Python ints (not numpy.int64 scalars) for the config, cache key and file names.
feature_idx = inds.flatten().tolist()
# max_batch_size = 512
# total_batch_size = 16384
# feature_idx = list(range(1000))

tokens = all_tokens[:total_batch_size]


feature_vis_params = SaeVisConfig(
    hook_point=sparse_autoencoder.cfg.hook_point,
    minibatch_size_features=256,
    minibatch_size_tokens=64,
    features=feature_idx,
    verbose=True,
)

# Key the on-disk cache on what determines the result here. Otherwise changing
# `layer` or the selected features silently reloads stale feature data.
cache_key = hashlib.sha1(
    f"{sparse_autoencoder.cfg.hook_point}|{feature_idx}|{total_batch_size}".encode()
).hexdigest()[:12]
pickle_filepath = Path(os.getcwd()) / f"feature_data_{cache_key}.pickle"

if not pickle_filepath.exists():
    # NOTE: returns a results object (it has .feature_data_dict and
    # .save_feature_centric_vis), not a plain Dict[int, FeatureData].
    feature_data = get_feature_data(
        encoder=sparse_autoencoder,
        model=model,
        tokens=tokens,
        cfg=feature_vis_params,
    )

    # Don't serialise the whole transformer into the cache; it is re-attached below.
    feature_data.model = None
    with open(pickle_filepath, "wb") as f:
        pickle.dump(feature_data, f)

else:
    # NOTE(review): pickle.load can execute arbitrary code. Only load cache
    # files that this notebook wrote itself.
    with open(pickle_filepath, "rb") as f:
        feature_data = pickle.load(f)

# Attach the live model before rendering, as the original flow did. Presumably
# dashboard rendering needs it (e.g. for the tokenizer); confirm against sae_vis.
feature_data.model = model

for test_idx in feature_data.feature_data_dict:
    feature_data.save_feature_centric_vis(
        f"{DASHBOARD_FOLDER}/data_{test_idx:04}.html",
        feature_idx=test_idx,
    )

In [None]:
# Bare expression as the cell's last line: the notebook displays the
# `feature_data` object built or loaded in the previous cell.
feature_data