# Loading and Analysing Pre-Trained Sparse Autoencoders

## Imports & Installs

In [None]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
# First port used to serve visualisations in Colab. `display_vis_inline` declares
# this name `global` and bumps it on every Colab call so each vis gets its own port.
PORT = 8000

# Disable autograd globally for the whole notebook. The trailing `;` keeps the
# notebook from echoing the call's return value as cell output.
torch.set_grad_enabled(False);

## Set Up

In [None]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# CUDA index this notebook prefers (the model-loading cell also targets "cuda:6").
PREFERRED_CUDA_INDEX = 6

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    # Only pick the preferred index when that GPU actually exists; otherwise
    # fall back to the first GPU instead of producing an invalid device string.
    has_preferred_gpu = torch.cuda.device_count() > PREFERRED_CUDA_INDEX
    cuda_index = PREFERRED_CUDA_INDEX if has_preferred_gpu else 0
    device = f"cuda:{cuda_index}"
else:
    device = "cpu"

print(f"Device: {device}")

In [None]:
def display_vis_inline(filename: str, height: int = 850) -> None:
    """
    Display a saved HTML visualisation.

    Outside Colab the file is simply opened with `webbrowser.open`. Inside Colab
    a background HTTP server is started that serves `/content`, and the file is
    embedded in the notebook as an iframe pointing at that server.

    Every Colab call consumes one port: the module-level `PORT` is read and then
    incremented, so each vis gets a unique port without the caller having to
    choose one.

    Args:
        filename: Name of the HTML file. In Colab it is requested as
            `/{filename}` from the server rooted at `/content`.
        height: Height passed to the Colab iframe helper.
    """
    if not COLAB:
        webbrowser.open(filename)
        return

    # Local import: `functools` is not imported before this cell in the notebook.
    from functools import partial

    global PORT

    # Reserve the port up front and keep it in a local. The server thread must
    # not read the global, because the global is advanced for the next call and
    # the thread could otherwise bind to the wrong port.
    port = PORT
    PORT += 1

    directory = "/content"

    # Serve `directory` through the handler's own `directory` argument instead of
    # `os.chdir`, which would change the working directory of the whole process
    # from a background thread.
    handler = partial(http.server.SimpleHTTPRequestHandler, directory=directory)

    def serve() -> None:
        # Create a socket server with the handler and serve until the process ends.
        with socketserver.TCPServer(("", port), handler) as httpd:
            print(f"Serving files from {directory} on port {port}")
            httpd.serve_forever()

    # Daemon thread: `serve_forever` never returns, so a non-daemon thread would
    # keep the interpreter alive on shutdown.
    threading.Thread(target=serve, daemon=True).start()

    output.serve_kernel_port_as_iframe(port, path=f"/{filename}", height=height, cache_in_notebook=True)

# Loading a pretrained Sparse Autoencoder

Below we load a TransformerLens model, a pretrained SAE, and a dataset from Hugging Face.

In [None]:
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset,load_from_disk
from transformer_lens import HookedTransformer
from typing import Any, Generator, Iterator, Literal, cast
from sae_lens import SAE
from transformers import (
    AutoTokenizer,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    AutoModelForCausalLM,
)

from transformer_lens.HookedLlava import HookedLlava
# Model identifier handed to `HookedLlava.from_pretrained` further down.
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
# Location the processor and the LLaVA-NeXT weights are loaded from.
# NOTE(review): empty in this notebook -- presumably has to be filled in before running.
model_path = ""

# Placement and precision shared by every LLaVA component loaded in this cell.
llava_device = "cuda:6"
llava_dtype = torch.float32

processor = LlavaNextProcessor.from_pretrained(model_path)
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=llava_dtype,
)

# Move the vision tower and the multi-modal projector onto the model's device;
# both are handed to the hooked model below.
vision_tower = vision_model.vision_tower.to(llava_device)
multi_modal_projector = vision_model.multi_modal_projector.to(llava_device)

# Wrap the language model of the loaded checkpoint in a hooked model. Weight
# folding/centering is switched off, so the weights stay as loaded.
hook_language_model = HookedLlava.from_pretrained(
    MODEL_NAME,
    hf_model=vision_model.language_model,
    vision_tower=vision_tower,
    multi_modal_projector=multi_modal_projector,
    tokenizer=None,
    device=llava_device,
    n_devices=2,
    dtype=llava_dtype,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

# The SAE is loaded onto a different GPU ("cuda:7") than the one named above.
# NOTE(review): `path` is empty here -- presumably a placeholder for the SAE checkpoint.
sae = SAE.load_from_pretrained(path="", device="cuda:7")

In [None]:
from transformer_lens.utils import tokenize_and_concatenate
import transformer_lens.utils as utils
# Load the "train" split without streaming.
# NOTE(review): `path` is empty here -- presumably a placeholder for the dataset location.
dataset = load_dataset(path="", split="train", streaming=False)

# Tokenize with the hooked model's tokenizer. Sequence length and BOS handling
# are taken from the SAE config.
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=hook_language_model.tokenizer,  # type: ignore
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
    streaming=True,
)

## Basic Analysis

Let's check some basic stats on this SAE in order to see how some basic functionality in the codebase works.

We'll calculate:
- L0 (the number of features that fire per activation)
- The cross entropy loss when the output of the SAE is used in place of the activations

### L0 Test and Reconstruction Test

In [None]:
# Show the SAE's configuration object (as the cell's last expression it is displayed by the notebook).
sae.cfg

In [None]:
# Put the SAE in eval mode first (per the original note, this avoids an error
# around an expected dead-neuron mask).
sae.eval()

with torch.no_grad():
    # The first two rows of the tokenized dataset form the evaluation batch.
    batch_tokens = token_dataset[:2]["tokens"]

    # Cache only the activation at the SAE's hook point.
    def _is_sae_hook(name):
        return name == sae.cfg.hook_name

    _, cache = hook_language_model.run_with_cache(
        batch_tokens,
        prepend_bos=True,
        names_filter=_is_sae_hook,
    )

    # Run the SAE: encode the cached activation, then decode the features again.
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # The cache is no longer needed; drop it to save memory.
    del cache

    # Skip position 0 (the BOS token) and count the features that are active at
    # each remaining position.
    l0 = feature_acts[:, 1:].gt(0).float().sum(dim=-1).detach()
    print("average l0", l0.mean().item())

    # Histogram of the per-position counts.
    l0_hist = px.histogram(l0.flatten().cpu().numpy())
    l0_hist.show()
    

Note that while the mean L0 is 64, it varies with the specific activation.

To estimate reconstruction performance, we calculate the CE loss of the model with and without the SAE being used in place of the activations. This will vary depending on the tokens.

In [None]:
from transformer_lens import utils
from functools import partial
torch.cuda.empty_cache()
# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    """Forward hook that discards `activation` and returns `sae_out` in its place.

    `hook` is accepted but unused. `sae_out` is meant to be bound ahead of time
    (the notebook does this with `functools.partial`).
    """
    return sae_out


def zero_abl_hook(activation, hook):
    """Forward hook that returns an all-zero tensor shaped like `activation`.

    The input tensor itself is left untouched; `hook` is accepted but unused.
    """
    zeroed = torch.zeros_like(activation)
    return zeroed


# Baseline: loss of the unmodified model on the batch.
clean_loss = hook_language_model(batch_tokens, return_type="loss")
print("Orig", clean_loss.item())

# Loss when the activation at the SAE's hook point is replaced by the SAE output.
splice_sae_out = partial(reconstr_hook, sae_out=sae_out)
reconstr_loss = hook_language_model.run_with_hooks(
    batch_tokens,
    return_type="loss",
    fwd_hooks=[(sae.cfg.hook_name, splice_sae_out)],
)
print("reconstr", reconstr_loss.item())

# Loss when the same activation is replaced by zeros.
zero_loss = hook_language_model.run_with_hooks(
    batch_tokens,
    return_type="loss",
    fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
)
print("Zero", zero_loss.item())

## Specific Capability Test

When studying a specific task, it is important to validate that the model still performs that task when its activations are replaced by the SAE reconstruction.

In [None]:
# Prompt and expected completion used for the capability check.
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
# Optional check of the unmodified model:
# utils.test_prompt(example_prompt, example_answer, hook_language_model, prepend_bos=True)

# Token ids of the prompt, used for the loss computations.
tokens = hook_language_model.to_tokens(example_prompt)

# Run the model on the prompt and keep its activations.
logits, cache = hook_language_model.run_with_cache(example_prompt, prepend_bos=True)

# Pass the activation at the SAE's hook point through the SAE by calling it directly.
hook_point_acts = cache[sae.cfg.hook_name]
sae_out = sae(hook_point_acts)


def reconstr_hook(activations, hook, sae_out):
    """Forward hook that ignores `activations` and hands back `sae_out` instead.

    `hook` is accepted but unused. `sae_out` is expected to be pre-bound, e.g.
    with `functools.partial`.
    """
    return sae_out


def zero_abl_hook(mlp_out, hook):
    """Forward hook that returns zeros with the same shape and dtype as `mlp_out`.

    `mlp_out` is not modified; `hook` is accepted but unused.
    """
    zeroed = torch.zeros_like(mlp_out)
    return zeroed


hook_name = sae.cfg.hook_name
# print(batch_tokens)
print(tokens)

# Baseline: loss of the unmodified model on the prompt tokens.
print("Orig", hook_language_model(tokens, return_type="loss").item())

# Each intervention replaces the activation at `hook_name`, in the order listed:
# first with the SAE output, then with zeros.
interventions = {
    "reconstr": partial(reconstr_hook, sae_out=sae_out),
    "Zero": zero_abl_hook,
}
for label, hook_fn in interventions.items():
    patched_loss = hook_language_model.run_with_hooks(
        tokens,
        fwd_hooks=[(hook_name, hook_fn)],
        return_type="loss",
    )
    print(label, patched_loss.item())


# Re-run the prompt test while the activation at `hook_name` is replaced by the SAE output.
sae_splice = (hook_name, partial(reconstr_hook, sae_out=sae_out))

with hook_language_model.hooks(fwd_hooks=[sae_splice]):
    utils.test_prompt(example_prompt, example_answer, hook_language_model, prepend_bos=True)

# Generating Feature Interfaces

Feature dashboards are an important part of SAE Evaluation. They work by:
1. Collecting feature activations over a larger number of examples.
2. Aggregating feature-specific statistics (such as max activating examples).
3. Representing that information in a standardized way.

In [None]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData
# Release cached GPU memory and make sure the SAE is in eval mode.
torch.cuda.empty_cache()
sae.eval()

# Features to build dashboards for: indices 0-9 plus feature 14057.
test_feature_idx_gpt = [*range(10), 14057]
hook_name = sae.cfg.hook_name

# Dashboard configuration (the original cell left a `batch_size=2048` option commented out).
feature_vis_config_gpt = SaeVisConfig(
    features=test_feature_idx_gpt,
    hook_point=hook_name,
    minibatch_size_tokens=128,
    verbose=True,
)

torch.cuda.empty_cache()

# Build the dashboard data from the first row of the tokenized dataset.
vis_tokens = token_dataset[:1]["tokens"]  # type: ignore
sae_vis_data_gpt = SaeVisData.create(
    encoder=sae,
    model=hook_language_model,  # type: ignore
    tokens=vis_tokens,
    cfg=feature_vis_config_gpt,
)

In [None]:
# Save one feature-centric dashboard per selected feature, named after the feature index.
for feature in test_feature_idx_gpt:
    filename = f"{feature}_feature_vis_demo_gpt.html"
    sae_vis_data_gpt.save_feature_centric_vis(
        filename,
        feature,
    )
    # Uncomment to also show each saved dashboard:
    # display_vis_inline(filename)

Since feature dashboards only need to be generated once per sparse autoencoder, everyone can share the same dashboards for publicly available pre-trained SAEs. Neuronpedia hosts such dashboards, which we can load via the integration.

In [None]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# Build a Neuronpedia quick list for the selected features.
# NOTE(review): per the original comment, this call is expected to open the list itself -- confirm.
neuronpedia_quick_list = get_neuronpedia_quick_list(
    sae=sae,
    features=test_feature_idx_gpt,
    name="A quick list we made",
)

# On Colab, print the returned value so it can be clicked (the original comment calls it a link).
if COLAB:
    print(neuronpedia_quick_list)