In [16]:
# try:
#     import google.colab # type: ignore
#     from google.colab import output
#     COLAB = True
#     %pip install sae-lens transformer-lens
# except:
#     COLAB = False
from IPython import get_ipython # type: ignore
ipython = get_ipython()
# get_ipython() gives None when not running under IPython; fail with a readable message instead of a bare AssertionError
assert ipython is not None, "This notebook must be run in an IPython/Jupyter kernel (get_ipython() returned None)."
# Programmatic form of the `%load_ext autoreload` and `%autoreload 2` magics
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
from jaxtyping import Float, Int
from torch import Tensor
from pathlib import Path
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from functools import partial

# Everything in this notebook is inference only, so autograd is switched off globally
torch.set_grad_enabled(False)

# Device setup
GPU_TO_USE = 2  # CUDA device index; only used when MPS is unavailable and CUDA is available

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    # Fail early with a clear message instead of a confusing error at model-load time
    assert 0 <= GPU_TO_USE < torch.cuda.device_count(), (
        f"GPU_TO_USE={GPU_TO_USE} but only {torch.cuda.device_count()} CUDA device(s) are visible."
    )
    device = f"cuda:{GPU_TO_USE}"
else:
    device = "cpu"

print(f"Device: {device}")

# Utility to free unreferenced Python objects and release cached accelerator memory
import gc
def clear_cache():
    """
    Run Python's garbage collector, then release cached allocator memory on the accelerator.

    The CUDA cache is emptied as before. The MPS cache is emptied too when MPS is available,
    since the device setup may select "mps", where the CUDA call alone frees nothing.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # torch.mps is only present on newer torch builds, hence the hasattr guard
    if torch.backends.mps.is_available() and hasattr(torch, "mps"):
        torch.mps.empty_cache()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Device: cuda:2


## Config

In [11]:
BASE_MODEL = "gemma-2-9b"
INSTRUCT_MODEL = "gemma-2-9b-it"

SAE_BASE_RELEASE = 'gemma-scope-9b-pt-res'
SAE_INSTRUCT_RELEASE = 'gemma-scope-9b-it-res'

# '-canonical' releases are addressed by 'canonical' SAE ids; the plain releases need an
# explicit average L0 in the id (both id formats are built below).
USE_CANONICAL = True
if USE_CANONICAL:
    SAE_BASE_RELEASE = SAE_BASE_RELEASE + '-canonical'
    SAE_INSTRUCT_RELEASE = SAE_INSTRUCT_RELEASE + '-canonical'

# Layers at which SAEs are loaded; used for both the base and the instruct SAE ids
SAE_INSTRUCT_LAYERS = [9, 20, 31]
SAE_WIDTH = '16k'
SAE_DTYPE = torch.bfloat16

# Specify L0 values when using non-canonical SAEs: one value per entry of SAE_INSTRUCT_LAYERS
base_sae_l0_values = {
    '16k': [51, 57, 63],
    # '131k': []
}

instruct_sae_l0_values = {
    '16k': [47, 47, 76],
    # '131k': []
}

if USE_CANONICAL:
    base_sae_ids = [f'layer_{layer}/width_{SAE_WIDTH}/canonical' for layer in SAE_INSTRUCT_LAYERS]
    instruct_sae_ids = base_sae_ids[:]
else:
    # zip() below would silently drop layers if these lists were shorter/longer than SAE_INSTRUCT_LAYERS
    assert len(base_sae_l0_values[SAE_WIDTH]) == len(SAE_INSTRUCT_LAYERS), \
        "base_sae_l0_values needs exactly one L0 value per layer in SAE_INSTRUCT_LAYERS"
    assert len(instruct_sae_l0_values[SAE_WIDTH]) == len(SAE_INSTRUCT_LAYERS), \
        "instruct_sae_l0_values needs exactly one L0 value per layer in SAE_INSTRUCT_LAYERS"
    base_sae_ids = [f'layer_{layer}/width_{SAE_WIDTH}/average_l0_{l0}' for layer, l0 in zip(SAE_INSTRUCT_LAYERS, base_sae_l0_values[SAE_WIDTH])]
    instruct_sae_ids = [f'layer_{layer}/width_{SAE_WIDTH}/average_l0_{l0}' for layer, l0 in zip(SAE_INSTRUCT_LAYERS, instruct_sae_l0_values[SAE_WIDTH])]

base_sae_ids, instruct_sae_ids

(['layer_9/width_16k/canonical',
  'layer_20/width_16k/canonical',
  'layer_31/width_16k/canonical'],
 ['layer_9/width_16k/canonical',
  'layer_20/width_16k/canonical',
  'layer_31/width_16k/canonical'])

In [4]:
# Data directory, given relative to the kernel's current working directory
DATAPATH_STR = '../data'
datapath: Path = Path(DATAPATH_STR)

datapath

PosixPath('../data')

In [3]:
from enum import Enum

class Experiment(Enum):
    """Identifiers of the evaluation experiments in this notebook."""
    SUBSTITUTION_LOSS = 'SubstitutionLoss'
    L0_LOSS = 'L0_loss'

# One independent (initially empty) token sample per experiment; replaced via set_tokens_sample
TOKENS_SAMPLE = {experiment: [] for experiment in Experiment}

# Same value (50) for every experiment
TOTAL_BATCHES = dict.fromkeys(Experiment, 50)

def get_batch_size(key: Experiment):
    """Return the TOTAL_BATCHES entry for `key` (NOTE(review): the dict is named like a batch count, not a batch size — confirm usage)."""
    batches = TOTAL_BATCHES[key]
    return batches

def get_tokens_sample(key: Experiment):
    """Return the token sample currently stored for `key`."""
    sample = TOKENS_SAMPLE[key]
    return sample

def set_tokens_sample(key: Experiment, token_sample):
    """Store `token_sample` as the token sample for `key`, replacing the previous one."""
    TOKENS_SAMPLE.update({key: token_sample})

## Evaluation utilities

In [None]:
def L0_loss(x, threshold=1e-8):
    """
    Mean number of active features per row, a.k.a. the L0 loss.

    `x` is expected to have shape [N_TOKENS, N_SAE]. A feature counts as active when its value is
    strictly greater than `threshold`. Active features are counted along the last dimension and
    the counts are averaged over all remaining dimensions, giving a scalar float tensor.
    """
    is_active = x.gt(threshold)
    active_per_row = is_active.float().sum(dim=-1)
    return active_per_row.mean()

def compute_substitution_loss(tokens, tokens_mask, model, sae, sae_layer, reconstruction_metric=None):
    '''
    Compare the model's loss with and without SAE reconstructions spliced in at `sae_layer`.

    Args:
        tokens: input tokens of shape [N_BATCHES, N_CONTEXT].
        tokens_mask: tensor of shape [N_BATCHES, N_CONTEXT]; positions equal to 0 are substituted
            with the SAE reconstruction, any other value (e.g. 1) excludes the position from substitution.
        model: model exposing `run_with_cache` and `run_with_hooks` (e.g. a HookedSAETransformer).
        sae: SAE module, called on a [n_substituted, d_model] tensor of activations.
        sae_layer: name of the hook point whose activations are reconstructed and replaced.
        reconstruction_metric: optional object whose `update(preds, target)` is called with the
            flattened float32 reconstructions and original activations of the substituted positions.

    Returns two losses (loss_clean, loss_reconstructed):
    1. Clean loss - loss of the normal forward pass of the model at the input tokens.
    2. Substitution loss - loss when substituting SAE reconstructions of the residual stream at the SAE layer.
    The mask only controls substitution; it is not passed to the loss computation.
    '''
    assert tokens_mask.shape == tokens.shape, "tokens_mask shape must match tokens shape."

    # Clean forward pass, caching only the hook point that will be substituted
    loss_clean, cache = model.run_with_cache(tokens, names_filter=[sae_layer], return_type="loss")
    original_activations = cache[sae_layer]  # shape: [batch_size, seq_len, d_model]
    d_model = original_activations.shape[-1]

    # Flat boolean selector of positions to substitute (mask == 0). reshape copes with non-contiguous
    # masks, and the selector is moved to the activations' device so it can index them directly.
    substitute_positions = (tokens_mask.reshape(-1) == 0).to(original_activations.device)

    # [batch_size * seq_len, d_model] -> [n_substituted, d_model]
    valid_activations = original_activations.reshape(-1, d_model)[substitute_positions]

    # SAE reconstructions of the selected activations, shape [n_substituted, d_model]
    post_reconstructed = sae(valid_activations)

    # Explicit None check: the metric object's truthiness is not a reliable "was it provided" test
    if reconstruction_metric is not None:
        reconstruction_metric.update(post_reconstructed.flatten().float(), valid_activations.flatten().float())

    # Drop references to the cached activations before the second forward pass
    del original_activations, valid_activations, cache
    clear_cache()

    def hook_function(activations, hook, new_activations, positions):
        """Overwrite the selected positions of `activations` with `new_activations` and return the result."""
        flat = activations.reshape(-1, activations.shape[-1])
        # Index assignment requires matching dtypes, so cast in case the SAE and model dtypes differ
        flat[positions] = new_activations.to(flat.dtype)
        return flat.view(activations.shape)

    # Second forward pass with the reconstructions substituted at the SAE layer
    loss_reconstructed = model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(sae_layer, partial(hook_function, new_activations=post_reconstructed, positions=substitute_positions))],
    )

    del post_reconstructed
    clear_cache()

    return loss_clean, loss_reconstructed


## Loading the models

In [17]:
# Load the instruction-tuned model in bfloat16 on the device chosen in the setup cell
model = HookedSAETransformer.from_pretrained(
    INSTRUCT_MODEL,
    device=device,
    dtype=torch.bfloat16,
)
model



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-9b-it into HookedTransformer


HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-41): 42 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
      

In [14]:
# One base (PT) SAE per id, cast to SAE_DTYPE; element [0] of from_pretrained's result is the SAE itself
base_saes = [
    SAE.from_pretrained(release=SAE_BASE_RELEASE, sae_id=sae_id, device=device)[0].to(SAE_DTYPE)
    for sae_id in base_sae_ids
]

base_saes[0].cfg.__dict__

params.npz:   0%|          | 0.00/470M [00:00<?, ?B/s]

params.npz:   0%|          | 0.00/470M [00:00<?, ?B/s]

SAEConfig(architecture='jumprelu', d_in=3584, d_sae=16384, activation_fn_str='relu', apply_b_dec_to_input=False, finetuning_scaling_factor=False, context_size=1024, model_name='gemma-2-9b', hook_name='blocks.9.hook_resid_post', hook_layer=9, hook_head_index=None, prepend_bos=True, dataset_path='monology/pile-uncopyrighted', dataset_trust_remote_code=True, normalize_activations=None, dtype='torch.bfloat16', device='cuda:2', sae_lens_training_version=None, activation_fn_kwargs={}, neuronpedia_id='gemma-2-9b/9-gemmascope-res-16k', model_from_pretrained_kwargs={})

## PT SAEs on the Instruct model

### L0 Loss

### Substitution/Delta loss