In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# Build a table with one row per pretrained-SAE "release". Each release has
# multiple SAEs which may have different configs / match different hook points
# in a model.
# TODO: Make this nicer.
df = pd.DataFrame.from_records(
    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T.drop(
    # Columns dropped from the overview table.
    columns=[
        "expected_var_explained",
        "expected_l0",
        "config_overrides",
        "conversion_func",
    ],
)
# print(df.head())

# print all columns
print(df.columns)
# # print(df.model)
# # print(df.release)
# print(df['model']["sae_bench_pythia70m_sweep_topk_ctx128_0730"])
# print(df.saes_map["sae_bench_pythia70m_sweep_topk_ctx128_0730"])

# Show only the releases whose name contains "bench". This has to be the LAST
# expression of the cell: Jupyter renders only a cell's final expression, so
# a bare expression placed mid-cell is evaluated and silently discarded.
# `na=False` keeps the mask boolean even if a release name is missing.
df[df.release.str.contains("bench", na=False)]

In [None]:
# Hand-written example record of basic per-SAE information, keyed by a
# "<sweep>/<hook point>/<trainer>" path string.
sae_basic_info = {
    "pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_9": dict(
        sae_config_dictionary_learning=dict(),
        basic_eval_results=dict(l0=80, frac_recovered=0.99),
    ),
}


# Hand-written example of a custom-eval payload: one shared eval config plus
# per-SAE results keyed by the same kind of path string as above.
custom_results_dict = dict(
    custom_eval_config=dict(dataset="Bib", n_inputs=100),
    custom_eval_results={
        "pythia70m_sweep_topk_ctx128_0730/resid_post_layer_4/trainer_9": dict(
            sparse_probing_k_1=0.55,
        ),
    },
)

In [None]:
# Standard imports
# Imports for displaying vis in Colab / notebook

import plotly.express as px
import torch

# NOTE(review): PORT is not referenced by any other cell visible in this part
# of the notebook — presumably left over from a visualisation server; confirm
# before removing.
PORT = 8000

# Turn autograd off globally; no cell visible here computes gradients.
torch.set_grad_enabled(False)

In [None]:
# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

In [None]:
from datasets import load_dataset
from sae_lens import SAE
from transformer_lens import HookedTransformer

# Load the "pythia-70m-deduped" model through TransformerLens onto the
# selected device.
model = HookedTransformer.from_pretrained("pythia-70m-deduped", device=device)

In [None]:
# `SAE.from_pretrained` is unpacked into three values here: the SAE itself, a
# cfg dict, and the feature sparsities.
# The cfg dict is returned alongside the SAE since it may contain useful
# information for analysing the SAE (e.g. instantiating an activation store).
# Note that this is not the same as the SAE's config dict; rather, it is
# whatever was in the HF repo, from which the SAE config dict can be extracted.
# The feature sparsities are stored in HF and returned for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="sae_bench_pythia70m_sweep_topk_ctx128_0730",
    sae_id="blocks.4.hook_resid_post__trainer_10",
    device=device,
)
# NOTE(review): `device` was already passed to `from_pretrained` above, so this
# move is presumably redundant — confirm before removing.
sae = sae.to(device=device)

In [None]:
# Dump the loaded config dict, the feature sparsities and the SAE, one per line.
print(cfg_dict, sparsity, sae, sep="\n")

In [None]:
from transformer_lens.utils import tokenize_and_concatenate

# Load the train split of the "NeelNanda/pile-10k" dataset (not streamed).
dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

# Tokenize with the model's tokenizer, using the SAE's context size as the
# maximum length and the SAE's `prepend_bos` setting for the BOS token.
# NOTE(review): `streaming=True` is passed here although the dataset above was
# loaded with `streaming=False` — presumably to avoid multiprocessing during
# tokenization; confirm against the transformer_lens documentation.
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

In [None]:
# Sanity check: the selected device and the device the SAE lives on should agree.
print(device, sae.device)

In [None]:
# Inspect the first row of the tokenized dataset.
print(token_dataset[0])

In [None]:

from functools import partial

# Put the SAE in eval mode. Per the original note, this prevents an error when
# a dead-neuron mask is expected (NOTE(review): presumably a training-only
# code path — confirm).
sae.eval()

# Number of dataset rows sent through the model per forward pass.
batch_size = 32
# NOTE(review): `seq_len` and `d_model` are not used by any cell visible in
# this part of the notebook.
seq_len = sae.cfg.context_size
d_model = sae.cfg.d_in
layer = sae.cfg.hook_layer

# Filled by `activation_hook`: one tensor is appended per hooked forward pass.
# The `_BLD` suffix presumably stands for (batch, length, d_model) — TODO confirm.
all_acts_list_BLD = []


def activation_hook(resid_BLD: torch.Tensor, hook) -> torch.Tensor:
    """Record the hooked activation and pass it through unchanged.

    Args:
        resid_BLD: Activation tensor handed to the hook.
        hook: Hook object supplied by the caller; not used here.

    Returns:
        The same ``resid_BLD`` tensor that was passed in.
    """
    # Side effect: stores the tensor in the module-level `all_acts_list_BLD`.
    all_acts_list_BLD.append(resid_BLD)
    return resid_BLD


# Hook the exact point the SAE is attached to. Reading the name from the SAE
# config (as the later cells do) instead of hard-coding `hook_resid_post`
# keeps this cell correct if an SAE for a different hook point is loaded.
hook_name = sae.cfg.hook_name

# Number of batches (of `batch_size` rows each) to run through the model.
batches = 10

with torch.no_grad():
    for i in range(batches):
        batch_start = i * batch_size
        batch_tokens = token_dataset[batch_start : batch_start + batch_size]["tokens"]

        # Only the hook's side effect (appending to `all_acts_list_BLD`) is
        # needed, so no model output is requested.
        model.run_with_hooks(
            batch_tokens, return_type=None, fwd_hooks=[(hook_name, activation_hook)]
        )

# Stack the per-batch activations along the batch dimension.
acts_BLD = torch.cat(all_acts_list_BLD, dim=0)
print(acts_BLD.shape)

In [None]:
# Put the SAE in eval mode. Per the original note, this prevents an error when
# a dead-neuron mask is expected (NOTE(review): presumably a training-only
# code path — confirm).
sae.eval()

with torch.no_grad():
    # activation store can give us tokens.
    batch_tokens = token_dataset[:32]["tokens"]

    # Cache only the SAE's hook point: it is the only cache entry read below,
    # so there is no need to keep every activation of the model in memory.
    # NOTE(review): `batch_tokens` is already tokenized, so `prepend_bos`
    # presumably has no effect here — confirm.
    _, cache = model.run_with_cache(
        batch_tokens, prepend_bos=True, names_filter=sae.cfg.hook_name
    )

    print(cache[sae.cfg.hook_name].device, sae.device)

    # Use the SAE: encode the hooked activations, then reconstruct them.
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # Skip position 0 (the BOS token) and count, for every remaining token,
    # how many features are active; then average across batch and position.
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(
        l0.flatten().cpu().numpy(),
        title="L0 per token (active SAE features)",
    ).show()

In [None]:


# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    """Replace the hooked activation with a precomputed SAE output.

    Args:
        activation: Activation handed to the hook; ignored.
        hook: Hook object supplied by the caller; not used here.
        sae_out: Value returned in place of ``activation``.

    Returns:
        ``sae_out`` unchanged.
    """
    return sae_out


def zero_abl_hook(activation, hook):
    """Zero-ablate the hooked activation.

    Args:
        activation: Activation tensor handed to the hook.
        hook: Hook object supplied by the caller; not used here.

    Returns:
        The result of ``torch.zeros_like(activation)``.
    """
    return torch.zeros_like(activation)


# Compare the model's loss on `batch_tokens` under three conditions.

# 1) Unmodified forward pass.
orig_loss = model(batch_tokens, return_type="loss")
print("Orig", orig_loss.item())

# 2) Activation at the SAE's hook point replaced by the SAE output.
sae_splice_hook = (sae.cfg.hook_name, partial(reconstr_hook, sae_out=sae_out))
reconstr_loss = model.run_with_hooks(
    batch_tokens,
    fwd_hooks=[sae_splice_hook],
    return_type="loss",
)
print("reconstr", reconstr_loss.item())

# 3) Activation at the SAE's hook point replaced by zeros.
zero_ablation_hook = (sae.cfg.hook_name, zero_abl_hook)
zero_abl_loss = model.run_with_hooks(
    batch_tokens,
    return_type="loss",
    fwd_hooks=[zero_ablation_hook],
)
print("Zero", zero_abl_loss.item())