# Main Features Demo

Start by installing the repo with pip.

In [None]:
# cd sae_auto_interp && pip install -e .

# 1️⃣ - Loading Your Own Autoencoders

We use the `nnsight` library to attach autoencoders to the module tree. 

At the time of writing (8/8/24), this feature isn't yet available on the main version of `nnsight`. Please install the `0.3` branch.

```
pip install git+https://github.com/ndif-team/nnsight.git@0.3
```

For this demo, we'll load, cache, and evaluate some layer zero features from the recent OpenAI topk autoencoders.

In [None]:
from functools import partial

import torch
from nnsight import LanguageModel

from sae_auto_interp.autoencoders.wrapper import AutoencoderLatents
from sae_auto_interp.autoencoders.OpenAI import Autoencoder

path = "/home/xuzhen/switch_sae/dictionaries/dict_class:_lb.MoEAutoEncoder'>_activation_dim:768_dict_size:24576_auxk_alpha:0.03125_decay_start:560_steps:700_seed:0_device:cuda:3_layer:8_lm_name:xuzhenswitch_saegpt2_wandb_name:MoEAutoEncoder_k:32_experts:64_e:2_heaviside:False/6.pt" # Change this line to your weights location.
state_dict = torch.load(path)
ae = Autoencoder.from_state_dict(state_dict=state_dict)
ae.to("cuda:0")

model = LanguageModel("/home/xuzhen/switch_sae/gpt2", device_map="auto", dispatch=True)

We provide a helpful wrapper for collecting autoencoder latents. The wrapper is a `torch.nn.Module` which calls a given `forward` method at every forward pass. We'll use `partial` here so we don't run into late binding issues. 

If we instead used a lambda like `lambda x: ae.encode(x)[0]` inside a loop over several autoencoders, every wrapper would end up with a reference to the last autoencoder's `encode` method.
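Here is a minimal, self-contained illustration of the late-binding pitfall. It uses a hypothetical `FakeAE` class in place of a real autoencoder, so it runs without `nnsight` or a GPU:

```python
from functools import partial


class FakeAE:
    """Hypothetical stand-in for an autoencoder with an `encode` method."""

    def __init__(self, name):
        self.name = name

    def encode(self, x):
        # Return a (latents, extra) tuple, mirroring `ae.encode(x)[0]` above.
        return (f"{self.name}({x})", None)


aes = [FakeAE("ae0"), FakeAE("ae1"), FakeAE("ae2")]

# Late binding: each lambda looks up `ae` when it is *called*, not when it is
# defined, so once the loop has finished every lambda sees the last autoencoder.
lambdas = [lambda x: ae.encode(x)[0] for ae in aes]


def _forward(ae, x):
    latents, _ = ae.encode(x)
    return latents


# `partial` stores the current `ae` object at creation time.
partials = [partial(_forward, ae) for ae in aes]

print([f("x") for f in lambdas])   # ['ae2(x)', 'ae2(x)', 'ae2(x)']
print([f("x") for f in partials])  # ['ae0(x)', 'ae1(x)', 'ae2(x)']
```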

In [None]:
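def _forward(ae, x):
    # `encode` returns a tuple; keep only the latent activations.
    latents, _ = ae.encode(x)
    return latents

# We can simply add the new module as an attribute to an existing
# submodule on GPT-2's module tree.
submodule = model.transformer.h[0]
submodule.ae = AutoencoderLatents(
    ae,
    partial(_forward, ae),  # bind `ae` now to avoid late binding
    width=131_072,  # number of latents; should match the autoencoder's dictionary size
)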

Next, we'll use `nnsight`'s `edit` context to set default interventions on the model's forward pass. 

Check out the [official demo](https://github.com/ndif-team/nnsight/blob/main/NNsight_v0_2.ipynb) to learn more about `nnsight` (which will be updated to 0.3 soon).

As a quick refresher, `nnsight` lets users execute PyTorch models lazily, with interventions. A context manager collects operations, then compiles and executes them when the context exits. The `.edit` context defines default nodes in the intervention graph, which are included every time the real model is executed.
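To make the collect-then-execute idea concrete, here is a toy model of deferred execution. Operations are recorded inside a `with` block and only run when the block exits. `DeferredGraph` is hypothetical and is not how `nnsight` is implemented internally. It only mirrors the general pattern:

```python
class DeferredGraph:
    """Toy illustration (not nnsight's implementation): record operations
    inside a `with` block and run them all only when the block exits."""

    def __init__(self, value):
        self.value = value
        self.ops = []   # recorded (name, fn) pairs
        self.log = []   # names of ops, in execution order

    def add(self, name, fn):
        # Record the operation instead of running it immediately.
        self.ops.append((name, fn))
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is None:
            # "Compile and execute": apply every recorded op in order.
            for name, fn in self.ops:
                self.value = fn(self.value)
                self.log.append(name)
        return False  # never suppress exceptions


with DeferredGraph(3) as g:
    g.add("double", lambda v: v * 2)
    g.add("inc", lambda v: v + 1)
    assert g.value == 3  # nothing has run yet inside the block

print(g.value, g.log)  # 7 ['double', 'inc']
```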

In [None]:
with model.edit(" "):
    acts = submodule.output[0]
    submodule.ae(acts, hook=True)

Awesome! Now collecting latents is as simple as saving the output of `submodule.ae` within a trace. This is helpful for two reasons. First, we only need to hold references to submodules and access their `.ae` attribute. Second, this saves us from storing a dictionary of submodules and their respective autoencoders, and from manually passing each submodule's activations through its autoencoder on every forward pass.
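For comparison, here is roughly what the manual alternative looks like in plain PyTorch: a dictionary mapping submodule names to autoencoders, plus one forward hook per submodule.

The `nn.Linear` layers are hypothetical stand-ins for GPT-2 blocks and autoencoder encoders, and `relu` stands in for whatever activation the real encoder uses. Real GPT-2 blocks return tuples, so a real hook would take `output[0]`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: two "submodules" and one "encoder" per submodule.
submodules = {"layer0": nn.Linear(8, 8), "layer1": nn.Linear(8, 8)}
encoders = {name: nn.Linear(8, 32) for name in submodules}

latents = {}


def make_hook(name):
    # `name` is bound when the hook is created, avoiding late binding.
    def hook(module, inputs, output):
        # Pass the submodule's output through its autoencoder's encoder.
        latents[name] = torch.relu(encoders[name](output))
    return hook


handles = [m.register_forward_hook(make_hook(n)) for n, m in submodules.items()]

x = torch.randn(2, 8)
with torch.no_grad():
    submodules["layer1"](submodules["layer0"](x))

# Hooks must be removed manually, or they keep firing on later passes.
for h in handles:
    h.remove()

print({k: tuple(v.shape) for k, v in latents.items()})
```

Every piece of bookkeeping here (the dictionary, hook creation, and hook removal) is what the `.ae` attribute approach avoids.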

In [None]:
with model.trace("hello, my name is"):
    latents = submodule.ae.output.save()

latents

The process above involves quite a bit of boilerplate, so we provide some starter code within the `.autoencoders` module. See the available options in the `__init__.py` file.

# 2️⃣ - Caching Activations

Now that we have an edited model, let's cache activations for the first one hundred features of the autoencoder across 500k tokens (`N_TOKENS` below). Ideally, you'll want to cache as many tokens as needed to get a wide distribution of activations for your autoencoder's rarer features.

Let's define a couple of constants for our cache and load tokens. Again, we provide utils for loading a `torch.utils.data.Dataset` of tokens, but feel free to load and tokenize however you want. Note that our tokenizer prepends padding to every sequence in the batch.

In [None]:
from sae_auto_interp.features import FeatureCache
from sae_auto_interp.utils import load_tokenized_data

CTX_LEN = 64
BATCH_SIZE = 32
N_TOKENS = 500_000
N_SPLITS = 2

tokens = load_tokenized_data(
    CTX_LEN,
    model.tokenizer,
    "kh4dien/fineweb-100m-sample",
    "train[:15%]",
)
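As a quick sanity check on these constants, here is the arithmetic for how much work the cache does. This assumes each batch holds `BATCH_SIZE` sequences of `CTX_LEN` tokens and that caching stops once `N_TOKENS` tokens have been seen. The library may round differently, for example by dropping the last partial batch.

```python
import math

CTX_LEN = 64
BATCH_SIZE = 32
N_TOKENS = 500_000

tokens_per_batch = BATCH_SIZE * CTX_LEN               # 2048 tokens per batch
n_batches = math.ceil(N_TOKENS / tokens_per_batch)    # 245 batches, the last one partial
n_sequences = math.ceil(N_TOKENS / CTX_LEN)           # 7813 sequences of CTX_LEN tokens

print(tokens_per_batch, n_batches, n_sequences)       # 2048 245 7813
```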

The cache accepts two dictionaries. 

`submodule_dict` is a `Dict[str, nnsight.Envoy]` which is iterated through during caching. 

`module_filter` is an optional `Dict[str, torch.Tensor]` of feature ids per submodule. Only activations for those features are kept during caching. Filtering is slower, especially for larger numbers of tokens, but it is very helpful for conserving CPU memory.
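Conceptually, filtering drops every cached activation whose feature id isn't in the filter. The plain-Python sketch below shows the idea, not `FeatureCache`'s actual code. It uses made-up data in the `[batch_idx, seq_pos, feature_id]` location layout described in this notebook. With tensors, `torch.isin` on the feature-id column would build the same mask.

```python
# Made-up cached rows: (batch_idx, seq_pos, feature_id), one activation each.
locations = [(0, 3, 5), (0, 7, 250), (1, 2, 99), (1, 9, 1024)]
activations = [1.5, 0.2, 3.1, 0.7]

# Keep only features 0..99, like `torch.arange(100)` in the filter.
keep = set(range(100))

filtered = [
    (loc, act)
    for loc, act in zip(locations, activations)
    if loc[2] in keep  # loc[2] is the feature_id column
]
print(filtered)  # [((0, 3, 5), 1.5), ((1, 2, 99), 3.1)]
```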

In [None]:
module_path = submodule._module_path

submodule_dict = {module_path : submodule}
module_filter = {module_path : torch.arange(100).to("cuda:0")}

cache = FeatureCache(
    model, 
    submodule_dict, 
    batch_size=BATCH_SIZE, 
    filters=module_filter
)

cache.run(N_TOKENS, tokens)

Raw features are saved as `safetensors` with the structure:

```python
{
    "location" : torch.Tensor["n_activations", 3],
    "activations" : torch.Tensor["n_activations"],
}
```

Each row of `location` points to one activation, with the data `[batch_idx, seq_pos, feature_id]`. `save_splits` also takes an `n_splits` parameter that divides the features across several `safetensors` files.
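To make this layout concrete, here is a plain-Python sketch (with made-up data) that groups cached activations by feature id. This is roughly the first step in reconstructing a feature's examples. It is an illustration, not the library's loader.

```python
from collections import defaultdict

# Made-up rows in the [batch_idx, seq_pos, feature_id] layout, plus activations.
locations = [[0, 3, 5], [0, 7, 42], [1, 2, 5], [1, 9, 42], [1, 10, 5]]
activations = [1.5, 0.2, 3.1, 0.7, 0.4]

# Group (batch_idx, seq_pos, activation) triples by feature id.
by_feature = defaultdict(list)
for (batch_idx, seq_pos, feature_id), act in zip(locations, activations):
    by_feature[feature_id].append((batch_idx, seq_pos, act))

# Strongest activation per feature.
max_act = {f: max(a for _, _, a in rows) for f, rows in by_feature.items()}
print(max_act)  # {5: 3.1, 42: 0.7}
```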

In [None]:
raw_dir = "raw_features/gpt2_128k" # Change this line to your save location.
cache.save_splits(
    n_splits=N_SPLITS,
    save_dir=raw_dir,
)
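`save_splits` writes one file per split. Suppose splits are contiguous, near-equal ranges of feature ids (an assumption about the on-disk layout, not confirmed here). Then the partition could look like the sketch below. `split_ranges` and `split_index` are hypothetical helpers. Under this assumption, every feature kept by the 100-feature filter would land in the first split.

```python
def split_ranges(width, n_splits):
    """Hypothetical helper: partition feature ids [0, width) into n_splits
    contiguous, near-equal ranges, returned as (start, end) with end exclusive."""
    base, extra = divmod(width, n_splits)
    ranges, start = [], 0
    for i in range(n_splits):
        # The first `extra` splits get one extra feature each.
        end = start + base + (1 if i < extra else 0)
        ranges.append((start, end))
        start = end
    return ranges


def split_index(feature_id, ranges):
    """Hypothetical helper: find which split a feature id would fall in."""
    for i, (start, end) in enumerate(ranges):
        if start <= feature_id < end:
            return i
    raise ValueError(f"feature {feature_id} is out of range")


ranges = split_ranges(131_072, 2)
print(ranges)                   # [(0, 65536), (65536, 131072)]
print(split_index(99, ranges))  # 0
```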

# 3️⃣ - Loading Activations

We provide a data loader for reconstructing features from their locations and activations. 

The loader requires a `FeatureConfig` which details how features were saved and how to reconstruct examples. 

The `ExperimentConfig` configures how train and test examples are sampled for explanation and scoring.

In [None]:
from sae_auto_interp.features import FeatureDataset, pool_max_activation_windows, sample
from sae_auto_interp.config import FeatureConfig, ExperimentConfig

cfg = FeatureConfig(
    width = 131_072,
    min_examples = 200,
    max_examples = 10_000,
    example_ctx_len = 20,
    n_splits = N_SPLITS  # must match the number of splits saved above
)

sample_cfg = ExperimentConfig()

dataset = FeatureDataset(
    raw_dir=raw_dir,
    cfg=cfg,
)

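The dataset's `.load` method accepts two functions. The constructor rebuilds examples from the cached locations and activations. The sampler selects the train and test examples from them.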

In [None]:
constructor=partial(
    pool_max_activation_windows,
    tokens=tokens,
    ctx_len=sample_cfg.example_ctx_len,
    max_examples=cfg.max_examples,
)

sampler = partial(
    sample,
    cfg=sample_cfg
)
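Conceptually, a max-activation-window constructor cuts the token stream into fixed-length windows and keeps the windows where the feature fires most strongly. The sketch below is a simplified, hypothetical version that works on plain lists for a single feature. The real `pool_max_activation_windows` operates on the cached tensors, and its details may differ.

```python
def max_activation_windows(token_acts, ctx_len, max_examples):
    """Simplified sketch (not the library's `pool_max_activation_windows`):
    cut each sequence into non-overlapping windows of `ctx_len` tokens, score
    each window by its max activation, and keep the top `max_examples` windows.

    token_acts: list of per-sequence activation lists for a single feature.
    Returns (seq_idx, start, max_act) tuples, strongest first.
    """
    windows = []
    for seq_idx, acts in enumerate(token_acts):
        # A trailing partial window (shorter than ctx_len) is dropped.
        for start in range(0, len(acts) - ctx_len + 1, ctx_len):
            peak = max(acts[start:start + ctx_len])
            if peak > 0:  # only keep windows where the feature fires
                windows.append((seq_idx, start, peak))
    windows.sort(key=lambda w: w[2], reverse=True)
    return windows[:max_examples]


acts = [
    [0, 0, 2.0, 0, 0, 0, 0.5, 0],  # sequence 0
    [0, 0, 0, 0, 3.0, 0, 0, 0],    # sequence 1
]
top = max_activation_windows(acts, ctx_len=4, max_examples=2)
print(top)  # [(1, 4, 3.0), (0, 0, 2.0)]
```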

Let's load a batch of records! The `.load` method is an iterator that yields all records in one split per iteration. Breaking after the first iteration gives us the records of the first split.

In [None]:
for records in dataset.load(constructor=constructor, sampler=sampler):
    break

record = records[0]

The `display` function in `.utils` renders examples as HTML, with their activating tokens highlighted.
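You may want similar output outside a notebook, for example to write an HTML file. A minimal version of the idea might look like the sketch below. `highlight_html` is a hypothetical helper, not the library's `display`. It assumes you already have decoded token strings and per-token activations.

```python
import html


def highlight_html(tokens, acts, threshold=0.0):
    """Hypothetical helper: wrap tokens whose activation exceeds `threshold`
    in <mark> tags, escaping token text so it renders safely as HTML."""
    parts = []
    for tok, act in zip(tokens, acts):
        text = html.escape(tok)
        if act > threshold:
            # Show the activation value as a tooltip.
            parts.append(f'<mark title="{act:.2f}">{text}</mark>')
        else:
            parts.append(text)
    return "".join(parts)


out = highlight_html(["The", " cat", " <sat>"], [0.0, 1.25, 0.0])
print(out)  # The<mark title="1.25"> cat</mark> &lt;sat&gt;
```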

In [None]:
from sae_auto_interp.utils import display

print(record.feature)
display(record, model.tokenizer, n=5)

# 4️⃣ - Explaining Activations

We define several clients for querying completion APIs such as vLLM and OpenRouter. For this example, we'll use the vLLM client with `TheBloke/Mistral-7B-Instruct-v0.2-GPTQ`.

In [None]:
from sae_auto_interp.clients import VLLM

client = VLLM('TheBloke/Mistral-7B-Instruct-v0.2-GPTQ')


Load an explainer, passing it the client and a tokenizer, plus generation settings as optional keyword arguments. The explainer outputs an `ExplainerResult` tuple.

In [None]:
from sae_auto_interp.explainers import SimpleExplainer

explainer = SimpleExplainer(
    client,
    model.tokenizer,
    max_new_tokens=50,
    temperature=0.0
)

explainer_result = await explainer(record)

print(explainer_result.explanation)
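The top-level `await` above works because Jupyter runs an event loop for you. In a plain Python script, you would wrap the calls in a coroutine and start it with `asyncio.run`. `asyncio.gather` also lets you explain several records concurrently. `StubExplainer` is a hypothetical placeholder for the real explainer, used so the sketch runs without a model server.

```python
import asyncio


class StubExplainer:
    """Hypothetical stand-in for an explainer: an async callable taking a record."""

    async def __call__(self, record):
        await asyncio.sleep(0)  # pretend to wait on the completion API
        return f"explanation for feature {record}"


async def main():
    explainer = StubExplainer()
    # gather runs the calls concurrently and returns results in input order.
    return await asyncio.gather(*(explainer(r) for r in [0, 1, 2]))


# asyncio.run starts a fresh event loop; don't call it inside Jupyter,
# where a loop is already running.
results = asyncio.run(main())
print(results)
```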

# 5️⃣ - Scoring Explanations

Similarly, we can score explanations by loading a scorer and passing it a feature record. The record should first be updated to contain the `.explanation` attribute.

In this example, we use the `RecallScorer`, which requires random, non-activating examples to measure precision. For simplicity, we didn't sample any earlier, so we'll reuse the train examples as stand-ins. Because these examples actually activate, the resulting score is only illustrative.

In [None]:
from sae_auto_interp.scorers import RecallScorer

scorer = RecallScorer(
    client,
    model.tokenizer,
    max_tokens=25,
    temperature=0.0,
    batch_size=10,
)

record.explanation = explainer_result.explanation
record.random_examples = record.train

score = await scorer(record)

Awesome! We got a score. The `.score` attribute contains a list of `ClassifierOutput`s. For each `ClassifierOutput`, we have the following attributes:

- `distance` : The quantile of the sample.
- `ground_truth` : Whether the sample actually activated or not.
- `prediction` : The model's prediction for whether the example activated. 
- `highlighted` : Whether the example was "highlighted" or not. Only True for the `FuzzScorer`.
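These per-example outputs can be summarized into aggregate metrics. The sketch below computes accuracy, precision, and recall from the `ground_truth` and `prediction` fields. `Output` is a hypothetical dataclass mirroring two of the attributes listed above, not the library's `ClassifierOutput`, and the data is made up.

```python
from dataclasses import dataclass


@dataclass
class Output:
    """Hypothetical mirror of two ClassifierOutput fields (not the library class)."""
    ground_truth: bool
    prediction: bool


def summarize(outputs):
    tp = sum(o.prediction and o.ground_truth for o in outputs)
    fp = sum(o.prediction and not o.ground_truth for o in outputs)
    fn = sum(not o.prediction and o.ground_truth for o in outputs)
    correct = sum(o.prediction == o.ground_truth for o in outputs)
    return {
        "accuracy": correct / len(outputs),
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }


outputs = [
    Output(ground_truth=True, prediction=True),    # true positive
    Output(ground_truth=True, prediction=False),   # false negative
    Output(ground_truth=False, prediction=True),   # false positive
    Output(ground_truth=False, prediction=False),  # true negative
]
metrics = summarize(outputs)
print(metrics)  # {'accuracy': 0.5, 'precision': 0.5, 'recall': 0.5}
```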