In [None]:
%pip install sae-lens transformer-lens

Collecting sae-lens
  Downloading sae_lens-3.21.1-py3-none-any.whl.metadata (5.1 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.6.0-py3-none-any.whl.metadata (12 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.6-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 (from sae-lens)
  Downloading pytest_profiling-1.7.0-py2.py3-none-any.whl.metadata (12 kB)
Collecting python-dot

In [None]:
# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops

# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Device preference order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

# Utility to drop unreachable Python objects and release cached accelerator memory.
import gc
def clear_cache():
    """Run the garbage collector, then release PyTorch's cached accelerator memory.

    The CUDA cache is emptied when CUDA is available. Because this notebook may
    also select "mps" as its device, the MPS cache is emptied as well when that
    backend is available and the installed torch exposes `torch.mps.empty_cache`.

    Returns:
        None
    """
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Older torch builds have no `torch.mps` module, hence the getattr guards.
    mps_module = getattr(torch, "mps", None)
    if (
        mps_module is not None
        and hasattr(mps_module, "empty_cache")
        and torch.backends.mps.is_available()
    ):
        mps_module.empty_cache()

Device: cuda


In [None]:
# define the model to work with
MODEL = 'MISTRAL' # GEMMA, GPT2

# Each branch sets the same group of names: the SAE release, the base and
# fine-tuned model identifiers, the dataset, the tokenizer name, and the
# residual-stream hook (layer index + 'pre'/'post') the SAE is attached to.
if MODEL == 'GEMMA':
    RELEASE = 'gemma-2b-res-jb'
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    FINETUNE_PATH = None
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'post'
    layer_num = 6
elif MODEL == 'GPT2':
    RELEASE = 'gpt2-small-res-jb'
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    FINETUNE_PATH = None
    DATASET_NAME = "Skylion007/openwebtext"
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'pre'
    layer_num = 6
elif MODEL == 'MISTRAL':
    RELEASE = 'mistral-7b-res-wg'
    BASE_MODEL = "mistral-7b"
    DATASET_NAME = "monology/pile-uncopyrighted"
    BASE_TOKENIZER_NAME = 'mistralai/Mistral-7B-v0.1'

    FINETUNE_MODEL = 'meta-math/MetaMath-Mistral-7B' #DeepMount00/Mistral-Ita-7b
    # NOTE(review): absolute Colab Drive path; adjust when running elsewhere.
    FINETUNE_PATH = '/content/drive/My Drive/Finetunes/MetaMath-Mistral-7B'

    hook_part = 'pre'
    layer_num = 8
else:
    # Fail here with a clear message instead of a NameError on the line below.
    raise ValueError(f"Unknown MODEL {MODEL!r}; expected 'GEMMA', 'GPT2' or 'MISTRAL'")

# Name of the hook point whose activations the SAE reads.
SAE_HOOK = f'blocks.{layer_num}.hook_resid_{hook_part}'

## Computing the normalization scalar

As per [Kissane et al.](https://www.lesswrong.com/posts/fmwk6qxrpW8d4jvbd/saes-usually-transfer-between-base-and-chat-models), our goal is to find a scalar $S$ such that $\mathbb{E}_{x \in X} \left[ S \|x\|_2 \right] = \sqrt{d_{model}}$, where $X$ is our activations dataset for the SAE. This will allow us to filter out the activations with a norm above the $\sqrt{d_{model}}$ threshold, after we scale the activations by multiplying with $S$.

In [None]:
# Load the selected base model in half precision onto the chosen device.
base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)
# Show the model configuration as the cell output.
base_model.cfg

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]



Loaded pretrained model mistral-7b into HookedTransformer


HookedTransformerConfig:
{'act_fn': 'silu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 11.313708498984761,
 'attn_scores_soft_cap': -1.0,
 'attn_types': ['local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local',
                'local'],
 'checkpoint_index': None,
 'checkpoint_

In [None]:
from math import sqrt
# Model width, read from the loaded model's config.
d_model = base_model.cfg.d_model

# Baseline outlier threshold for *scaled* activation norms: sqrt(d_model).
THRESHOLD = sqrt(d_model)
print(f'Outlier threshold = {THRESHOLD}')

Outlier threshold = 64.0


In [None]:
# import the required libraries
from sae_lens import SAE
# Reuse the hook name built in the configuration cell so the SAE id cannot
# drift from SAE_HOOK (both were previously derived from the same f-string).
sae_id = SAE_HOOK

# from_pretrained returns a 3-tuple, unpacked here as (sae, cfg_dict, sparsity).
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=RELEASE,
    sae_id=sae_id,
    device=device,
)

mistral_7b_layer_8/cfg.json:   0%|          | 0.00/430 [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

In [None]:
from sae_lens import ActivationsStore

# Default number of token batches the analysis functions below iterate over,
# and the number of prompts the store returns per batch.
total_batches = 50
batch_size_prompts = 5

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=batch_size_prompts,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

# Tokens per batch = context length x prompts per batch; total_tokens is the
# number of tokens covered by `total_batches` such batches.
batch_size_tokens = activation_store.context_size * batch_size_prompts
total_tokens = total_batches * batch_size_tokens

# Displayed as the cell output.
batch_size_tokens, total_tokens

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]



(1280, 64000)

### Computing the average activation norm

In [None]:
from tqdm import tqdm

def get_average_norm(model=base_model, activation_store=activation_store, total_batches=total_batches, sae_id=sae_id):
    """Estimate the mean L2 norm of the activations at `sae_id` over sampled token batches.

    Args:
        model: object exposing `run_with_cache`; defaults to the notebook's base model.
        activation_store: source of token batches via `get_batch_tokens()`.
        total_batches: number of batches to average over.
        sae_id: hook name whose cached activations are measured.

    Returns:
        float: the mean, over batches, of each batch's mean per-token norm.
    """
    total_norm = 0.0

    for _ in tqdm(range(total_batches)):
        # Get a batch of tokens from the dataset
        tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]

        # Run the model passed in by the caller (previously the global
        # `base_model` was always used, so the `model` argument had no effect).
        # The forward pass stops right after the notebook-level `layer_num`
        # block, and only the `sae_id` hook is cached.
        _, cache = model.run_with_cache(tokens, stop_at_layer=layer_num + 1,
                                        names_filter=[sae_id])

        # Upcast to float32 before taking norms
        activations = cache[sae_id].float()  # [N_BATCH, N_CONTEXT, D_MODEL]

        # Per-token L2 norm over the last (feature) axis, then the batch mean
        batch_norm = torch.norm(activations, dim=-1)  # [N_BATCH, N_CONTEXT]
        total_norm += batch_norm.mean().item()

        # Explicitly free up memory by deleting the cache and emptying the CUDA cache
        del cache, activations
        clear_cache()

    # Average of the per-batch means
    return total_norm / total_batches

In [None]:
# Average activation norm at the SAE hook, using the default model, store and batch count.
norm_average = get_average_norm()
norm_average

100%|██████████| 50/50 [00:16<00:00,  3.01it/s]


4.513891072273254

### Solving for the scalar

Solving the equation above, we get $S = \frac{\sqrt{d_{model}}}{\mathbb{E}_{x \in X} \left[ \|x\|_2 \right]}$, where the denominator is the average activation norm that was just computed in the cell above.

In [None]:
# Scalar S that rescales activations so their average norm equals sqrt(d_model).
NORM_SCALAR = sqrt(d_model) / norm_average
NORM_SCALAR

14.178454680291779

## Filtering out outlier activations

In [None]:
def is_act_outlier(act_tensor, threshold, norm_scalar=None):
    """
    Flag activations whose scaled L2 norm exceeds `threshold`.

    Expects act_tensor of shape [*, D_MODEL].
    Returns a boolean tensor of shape [*]: True where the norm of
    `norm_scalar * activation` (taken over the last axis) is strictly greater
    than `threshold`.

    Args:
        act_tensor: activations, with the feature axis last.
        threshold: cutoff expressed in *scaled*-norm units. It is required;
            there is no default value.
        norm_scalar: multiplier applied before taking norms. When None, the
            notebook-level NORM_SCALAR is used.
    """
    if norm_scalar is None:
        norm_scalar = NORM_SCALAR

    scaled_act = norm_scalar * act_tensor
    scaled_act_norms = torch.norm(scaled_act, dim=-1)

    return scaled_act_norms > threshold

In [None]:
def list_flatten(nested_list):
    """Concatenate the items of every inner iterable into one flat list (one level deep)."""
    flat = []
    for inner in nested_list:
        flat.extend(inner)
    return flat

# A very handy function Neel wrote to get context around a feature activation
def make_token_df(tokens, len_prefix=5, len_suffix=3, model=base_model):
    """Build a per-token DataFrame with a short text context around each position.

    Args:
        tokens: 2-D batch of token ids, indexed as [prompt, position].
        len_prefix: maximum number of preceding tokens shown in `context`.
        len_suffix: maximum number of following tokens shown in `context`.
        model: object providing `to_str_tokens`; defaults to the notebook's base model.
            Assumes `to_str_tokens` yields one string per token id (TODO confirm).

    Returns:
        pd.DataFrame with one row per (prompt, position) in prompt-major order and
        the columns str_tokens, unique_token, context, prompt, pos and label.
        `context` is formatted as "prefix|token|suffix".
    """
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    n_prompts, n_positions = tokens.shape[0], tokens.shape[1]

    context = []
    prompt = []
    pos = []
    label = []
    for b in range(n_prompts):
        row = str_tokens[b]
        for p in range(n_positions):
            prefix = "".join(row[max(0, p - len_prefix):p])
            # Slicing clamps at the end of the row, so no special case is needed
            # for the last position. The previous upper bound of
            # `n_positions - 1` dropped the final token from every suffix.
            suffix = "".join(row[p + 1:p + 1 + len_suffix])
            context.append(f"{prefix}|{row[p]}|{suffix}")
            prompt.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")

    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        prompt=prompt,
        pos=pos,
        label=label,
    ))

def get_outliers(threshold, model=base_model, activation_store=activation_store, total_batches=total_batches, sae_id=sae_id,
                 compute_norms=True): # threshold should be a value for **scaled norms**
    """Scan token batches and collect the tokens whose activation at `sae_id` is an outlier.

    Args:
        threshold: cutoff in *scaled*-norm units, forwarded to `is_act_outlier`.
        model: object exposing `run_with_cache` and `to_str_tokens`; defaults to the
            notebook's base model.
        activation_store: source of token batches via `get_batch_tokens()`.
        total_batches: number of batches to scan.
        sae_id: hook name whose cached activations are inspected.
        compute_norms: when True, add a `norm` column and sort the result by it,
            largest first. The column holds the raw (unscaled) activation norm.

    Returns:
        (outlier_fraction, outlier_df): the fraction of scanned tokens flagged as
        outliers, and a DataFrame with one row per flagged token (the columns
        produced by `make_token_df`, plus `norm` when requested).
    """
    total_outliers_count = 0.0
    total_tokens = 0
    all_outlier_token_dfs = []

    for _ in tqdm(range(total_batches)):
        # Get a batch of tokens from the dataset
        tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
        # Decode with the same model that produces the activations.
        token_df = make_token_df(tokens, model=model)

        # Run the model passed in by the caller (previously the global
        # `base_model` was always used, so the `model` argument had no effect).
        _, cache = model.run_with_cache(tokens, stop_at_layer=layer_num + 1,
                                        names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

        # Get the activations from the cache at the sae_id
        activations = cache[sae_id].float()

        # Determine outliers
        outlier_mask = is_act_outlier(activations, threshold=threshold)  # [N_BATCH, N_CONTEXT]
        outliers_count = outlier_mask.float().sum().item()

        # Flattened row-major, which matches the prompt-major row order of token_df
        outlier_mask_np = outlier_mask.cpu().numpy().flatten().astype(bool)

        # Select outlier tokens using the boolean mask
        batch_outliers = token_df.iloc[outlier_mask_np].copy()

        if compute_norms:
            # Raw (unscaled) norms of the flagged activations, in the same order as the rows
            norms = torch.norm(activations[outlier_mask], dim=-1).flatten().cpu().numpy()
            batch_outliers['norm'] = norms

        all_outlier_token_dfs.append(batch_outliers)

        total_outliers_count += outliers_count
        total_tokens += tokens.shape[0] * tokens.shape[1]

        # Free memory the same way get_average_norm does
        del cache, activations
        clear_cache()

    # Concatenate all outlier DataFrames
    all_outlier_token_dfs = pd.concat(all_outlier_token_dfs, ignore_index=True)
    if compute_norms:
        all_outlier_token_dfs = all_outlier_token_dfs.sort_values(by='norm', ascending=False)

    # Calculate the fraction of outliers
    outlier_fraction = total_outliers_count / total_tokens

    return outlier_fraction, all_outlier_token_dfs

In [None]:
# These two values are not used by the active threshold below; they only feed
# the disabled alternative and are displayed again at the end of the notebook.
threshold_multiplier = 2
base_threshold = THRESHOLD

# Disabled alternative: a multiple of the sqrt(d_model) baseline.
# threshold = threshold_multiplier * base_threshold
# is_act_outlier compares NORM_SCALAR * ||x|| against the threshold, so this
# value flags activations whose raw (unscaled) norm exceeds 200.
threshold = 200 * NORM_SCALAR

threshold

2835.690936058356

In [None]:
# Re-display the scaling factor for reference.
NORM_SCALAR

14.178454680291779

In [None]:
# Scan 100 batches (overriding the default total_batches) for activations above the scaled threshold.
outlier_fraction, outlier_tokens = get_outliers(threshold, total_batches=100)

100%|██████████| 100/100 [00:20<00:00,  4.91it/s]


In [None]:
# Fraction of scanned tokens that were flagged as outliers.
outlier_fraction

0.0069140625

In [None]:
# Flagged tokens, sorted by raw activation norm in descending order.
outlier_tokens

Unnamed: 0,str_tokens,unique_token,context,prompt,pos,label,norm
561,\n,\n/20,1999].|\n|\nWenote,4,20,4/20,384.656311
464,\n,\n/13,superintervening|\n| neglig,1,13,1/13,384.551392
385,\n,\n/26,pt]{minimal}|\n| \,4,26,4/26,384.440674
550,\n,\n/35,~810$\|\n|\n--------------------------------,3,35,3/35,384.435150
535,\n,\n/11,label{DefZ}|\n|z=z,4,11,4/11,384.352081
...,...,...,...,...,...,...,...
686,as,as/0,|as|othercaredelivery,1,0,1/0,284.030548
609,as,as/0,|as|$\pi+,0,0,0/0,284.030548
176,\n,\n/7,EnumValueDescriptor){|\n| return(,4,7,4/7,283.381012
502,\n,\n/3,bones and|\n|\ninability,4,3,4/3,282.907623


In [None]:
import plotly.express as px

def plot_norm_histogram(outlier_df):
    """Show a probability-normalised histogram of the `norm` column, if the frame has one.

    Frames without a `norm` column are ignored and nothing is drawn.
    """
    if 'norm' not in outlier_df.columns:
        return

    histogram = px.histogram(
        outlier_df,
        x='norm',
        nbins=50,
        histnorm='probability',
        title='Histogram of Outlier Norms',
        labels={'norm': 'Outlier norm Values'},
    )
    histogram.show()

# Distribution of raw activation norms among the flagged tokens.
plot_norm_histogram(outlier_tokens)

In [None]:
# Add a new column 'is_bos': 1 if the token is the beginning-of-sequence marker, else 0.
# The marker string is model-specific, so it is read from the loaded tokenizer;
# a hard-coded '<bos>' never matches for tokenizers that spell it differently.
bos_token_str = base_model.tokenizer.bos_token
outlier_tokens['is_bos'] = (outlier_tokens['str_tokens'] == bos_token_str).astype(int)
outlier_tokens

Unnamed: 0,str_tokens,unique_token,context,prompt,pos,label,norm,is_bos
561,\n,\n/20,1999].|\n|\nWenote,4,20,4/20,384.656311,0
464,\n,\n/13,superintervening|\n| neglig,1,13,1/13,384.551392,0
385,\n,\n/26,pt]{minimal}|\n| \,4,26,4/26,384.440674,0
550,\n,\n/35,~810$\|\n|\n--------------------------------,3,35,3/35,384.435150,0
535,\n,\n/11,label{DefZ}|\n|z=z,4,11,4/11,384.352081,0
...,...,...,...,...,...,...,...,...
686,as,as/0,|as|othercaredelivery,1,0,1/0,284.030548,0
609,as,as/0,|as|$\pi+,0,0,0/0,284.030548,0
176,\n,\n/7,EnumValueDescriptor){|\n| return(,4,7,4/7,283.381012,0
502,\n,\n/3,bones and|\n|\ninability,4,3,4/3,282.907623,0


In [None]:
# Create a histogram using Plotly for the 'is_bos' column
# (probability-normalised, two bins for the two possible values 0 and 1)
fig = px.histogram(outlier_tokens, x='is_bos', nbins=2, title='Histogram of is_bos Column',
                   labels={'is_bos': 'is_bos'}, histnorm='probability')
# Axis titles and spacing between the bars
fig.update_layout(
    xaxis_title='is_bos',
    yaxis_title='Frequency',
    bargap=0.1
)

# Display the plot
fig.show()

In [None]:
# Summary of the scaling factor and the baseline-threshold settings defined above.
NORM_SCALAR, threshold_multiplier, base_threshold

(14.178454680291779, 2, 64.0)