### Setup

In [1]:
# Detect whether the notebook is running inside Google Colab: importing
# `google.colab` succeeds there and raises ImportError otherwise.
try:
    import google.colab
    # Install the notebook's third-party dependencies into the kernel's environment.
    # NOTE(review): versions are not pinned, so results may drift between runs — consider pinning.
    %pip install sae-lens transformer-lens torcheval
    IN_COLAB = True
except ImportError:
    # Not in Colab; presumably the dependencies are already installed locally.
    IN_COLAB = False



In [2]:
# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops

# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path

# Inference-only notebook: disable autograd globally so no gradient buffers are kept.
torch.set_grad_enabled(False)

# Pick the compute device in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

# Utility that garbage-collects unreferenced variables and clears the CUDA cache
import gc
def clear_cache():
    """Release memory that is no longer referenced, on the host and on the accelerator.

    The Python garbage collector runs first so that tensors kept alive only by
    reference cycles are dropped; PyTorch is then asked to release its cached
    device memory.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # This notebook may select the "mps" device, whose allocator cache is
    # separate from CUDA's and would otherwise never be emptied.
    mps_module = getattr(torch, "mps", None)
    if mps_module is not None and hasattr(mps_module, "empty_cache") and torch.backends.mps.is_available():
        mps_module.empty_cache()

Device: cuda


### Config

In [3]:
# Select the model family to work with. Every other setting in this cell is derived from it.
MODEL = 'GPT2' # GEMMA, MISTRAL, GPT2

if MODEL == 'GEMMA':
    # Base model stuff
    BASE_MODEL = "google/gemma-2b"
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    BASE_TOKENIZER_NAME = BASE_MODEL

    # Finetuned model stuff
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    FINETUNE_PATH = None

    # SAE stuff
    RELEASE = 'gemma-2b-res-jb'
    hook_part = 'post'
    layer_num = 6
elif MODEL == 'MISTRAL':
    # Base model stuff
    BASE_MODEL = "mistral-7b"
    DATASET_NAME = "monology/pile-uncopyrighted"
    BASE_TOKENIZER_NAME = 'mistralai/Mistral-7B-v0.1'

    # Finetuned model stuff (loaded from a Drive folder instead of the hub id)
    FINETUNE_MODEL = 'meta-math/MetaMath-Mistral-7B'
    FINETUNE_PATH = '/content/drive/My Drive/Finetunes/MetaMath-Mistral-7B'

    # SAE stuff
    RELEASE = 'mistral-7b-res-wg'
    hook_part = 'pre'
    layer_num = 8
elif MODEL == 'GPT2':
    # Base model stuff
    BASE_MODEL = "gpt2-small"
    DATASET_NAME = "Skylion007/openwebtext"
    BASE_TOKENIZER_NAME = 'openai-community/gpt2'

    # Finetuned model stuff
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    FINETUNE_PATH = None

    # SAE stuff
    RELEASE = 'gpt2-small-res-jb'
    hook_part = 'pre'
    layer_num = 6
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(f"Unknown MODEL {MODEL!r}; expected one of 'GEMMA', 'MISTRAL', 'GPT2'")

# Name of the residual-stream hook point the SAE is attached to.
SAE_HOOK = f'blocks.{layer_num}.hook_resid_{hook_part}'

# Names used in saved file names: keep only the part after the last slash
# (a name without a slash is returned unchanged by split()[-1]).
saving_name_base = BASE_MODEL.split("/")[-1]
saving_name_ft = FINETUNE_MODEL.split("/")[-1]
saving_name_ds = DATASET_NAME.split("/")[-1]

In [4]:
# Choose the directory used for saved artifacts (token samples, feature densities).
if IN_COLAB:
    datapath = Path('/content/drive/My Drive/sae_data')
else:
    # NOTE(review): a later cell sets `datapath` again and extends sys.path before
    # importing saetuning; this import presumably only works here if the package
    # is already importable — confirm, or drop this cell in favour of the later one.
    from saetuning.utils import get_env_var
    _, datapath = get_env_var()

In [5]:
from enum import Enum

class Experiment(Enum):
    """The evaluation experiments run in this notebook; values are display labels."""
    SUBSTITUTION_LOSS = 'SubstitutionLoss'
    L0_LOSS = 'L0_loss'
    FEATURE_ACTS = 'FeatureActs'
    FEATURE_DENSITY = 'FeatureDensity'

# Number of token batches drawn from the dataset for each experiment.
TOTAL_BATCHES = {
    Experiment.SUBSTITUTION_LOSS: 50,
    Experiment.L0_LOSS: 50,
    Experiment.FEATURE_ACTS: 10,
    Experiment.FEATURE_DENSITY: 50
}

# Prompts per batch: 20 for MISTRAL and GPT2, 5 for every other MODEL value.
batch_size_prompts = 20 if MODEL in ('MISTRAL', 'GPT2') else 5

In [6]:
# Per-experiment store of the tokens each experiment was run on; every entry
# starts as an empty list (Experiment members iterate in definition order).
TOKENS_SAMPLE = {experiment: [] for experiment in Experiment}

def get_batch_size(key: Experiment):
    """Return the number of batches configured for `key` in TOTAL_BATCHES."""
    return TOTAL_BATCHES[key]

def get_tokens_sample(key: Experiment):
    """Return whatever is currently stored for `key` in TOKENS_SAMPLE."""
    return TOKENS_SAMPLE[key]

def set_tokens_sample(key: Experiment, token_sample):
    """Replace the stored token sample for `key` with `token_sample`."""
    TOKENS_SAMPLE[key] = token_sample

In [7]:
COLAB_BASE_PATH = '/content/drive/My Drive/sae_data'

if IN_COLAB:
    # If in Colab, mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Artifacts are read from / written to this Drive folder
    datapath = Path(COLAB_BASE_PATH)
else:
    # `sys` is not imported by the cells above this one, so import it here;
    # without it this branch raises NameError on a fresh kernel.
    import sys

    # If not in Colab, use local folder
    # Assuming this is being run from the 'notebooks' folder: add the parent
    # directory to the import path so that `saetuning` can be imported.
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

    from saetuning.utils import get_env_var
    _, datapath = get_env_var()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Utils

#### Loading finetuned model

In [8]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def adjust_state_dict(model, base_model_vocab_size):
    """Return the model's state_dict with its vocab-sized matrices cut to the base vocabulary.

    Args:
        model: Object exposing `state_dict()`; the result must contain the keys
            'model.embed_tokens.weight' and 'lm_head.weight'.
        base_model_vocab_size (int): Number of vocabulary rows to keep.

    Returns:
        The state_dict, in which each of the two matrices is replaced by its
        first `base_model_vocab_size` rows if it had more rows than that.
    """
    weights = model.state_dict()

    # Embedding matrix first, then the unembedding (lm_head) matrix.
    for key in ('model.embed_tokens.weight', 'lm_head.weight'):
        matrix = weights[key]
        if matrix.shape[0] > base_model_vocab_size:
            weights[key] = matrix[:base_model_vocab_size, :]

    return weights

def load_hf_model(path, base_model=BASE_MODEL, device='cuda', dtype=None):
    """Load a fine-tuned Hugging Face checkpoint and wrap it in a HookedSAETransformer.

    Args:
        path: Hub id or local path passed to the `from_pretrained` loaders.
        base_model (str): Base model name passed to HookedSAETransformer.from_pretrained.
        device (str): Device for the wrapped model.
        dtype: Optional dtype forwarded to HookedSAETransformer.from_pretrained.

    Returns:
        (tokenizer, finetune_model): the checkpoint's tokenizer and the wrapped model.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    hf_model = AutoModelForCausalLM.from_pretrained(path)

    if base_model == 'mistral-7b':
        # The Mistral 7B base vocabulary has 32000 entries: trim the checkpoint's
        # embedding/unembedding to that size, shrink the architecture to match,
        # then load the trimmed weights back.
        mistral_vocab_size = 32000
        trimmed_weights = adjust_state_dict(hf_model, mistral_vocab_size)
        hf_model.resize_token_embeddings(mistral_vocab_size)
        hf_model.load_state_dict(trimmed_weights, strict=False)

    # Wrap the fine-tuned weights in the hooked architecture of the base model.
    finetune_model = HookedSAETransformer.from_pretrained(
        base_model, device=device, hf_model=hf_model, dtype=dtype
    )

    # Drop the HF model: it is already wrapped into finetune_model.
    del hf_model
    clear_cache()

    return tokenizer, finetune_model

#### Outlier filtering utility

In [9]:
import json
import sys
import os

def load_outliers_config(filename='outlier_cfg.json', g_drive_folder='My Drive'):
    """
    Load the outlier-filtering configuration.

    In Google Colab (IN_COLAB is true) Google Drive is mounted and the JSON file
    `<g_drive_folder>/<filename>` is read from it; None is returned (after printing
    a message) if that file does not exist.
    Outside Colab, `filename` and `g_drive_folder` are ignored and the
    OUTLIERS_CFG object from `saetuning.utils` is returned.
    """
    if IN_COLAB:
        # Mount Google Drive and read the JSON file from it
        from google.colab import drive
        drive.mount('/content/drive')

        file_path = os.path.join('/content/drive', g_drive_folder, filename)
        print(f"Loading JSON file from Google Drive: {file_path}")

        try:
            with open(file_path, 'r') as file:
                return json.load(file)
        except FileNotFoundError:
            print(f"File not found: {file_path}")
            return None

    # Local run: assume the notebook is run from the 'notebooks' folder and make
    # the parent directory importable before importing the packaged config.
    sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

    from saetuning.utils import OUTLIERS_CFG
    return OUTLIERS_CFG

In [10]:
OUTLIERS_CFG = load_outliers_config()

def _outlier_setting(section, model_name):
    """Look up `model_name` in the given section of OUTLIERS_CFG; None if either is absent."""
    return OUTLIERS_CFG.get(section, {}).get(model_name, None)

def get_norm_scalar(model_name):
    """Return the "norm_scalar" entry for the model, or None."""
    return _outlier_setting("norm_scalar", model_name)

def get_threshold_multiplier(model_name):
    """Return the "threshhold_multiplier" entry for the model, or None."""
    return _outlier_setting("threshhold_multiplier", model_name)

def get_base_threshhold(model_name):
    """Return the "base_threshhold" entry for the model, or None."""
    return _outlier_setting("base_threshhold", model_name)

def get_absolute_threshhold(model_name):
    """Return the "absolute_threshold" entry for the model, or None."""
    return _outlier_setting("absolute_threshold", model_name)

# Display the loaded configuration as the cell output.
OUTLIERS_CFG

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Loading JSON file from Google Drive: /content/drive/My Drive/outlier_cfg.json


{'norm_scalar': {'google/gemma-2b': 0.31278620989943556,
  'gpt2-small': 0.27139524668485193,
  'mistral-7b': 14.178454680291779},
 'threshhold_multiplier': {'google/gemma-2b': 2, 'gpt2-small': 2},
 'base_threshhold': {'google/gemma-2b': 45.254833995939045,
  'gpt2-small': 27.712812921102035},
 'absolute_threshold': {'mistral-7b': 200}}

In [11]:
# Auxiliary method for getting a mask of outlier activations
def is_act_outlier(act_tensor, model_name):
    """
    Expects act_tensor of shape [*, D_MODEL]

    Returns a boolean tensor of shape [*], where for each batch position we report whether the corresponding activation
    exceeds the outlier threshold that is defined as

    threshold = threshold_multiplier * base_threshold, where
    base_threshold = sqrt(D_MODEL)
    --
    OR, if the **absolute_threshhold** is provided in the cfg
    --
    threshold = absolute_threshhold,

    The first case is meant to be used with activations in the normalized scale, i.e. scaled in such a way,
    that their average norm is equal to sqrt(D_MODEL). To do this normalization, we multiply by norm_scalar
    of the corresponding model, which is computed in the find_outlier_norms.ipynb.

    Check this blog-post for more details: https://www.lesswrong.com/posts/fmwk6qxrpW8d4jvbd/saes-usually-transfer-between-base-and-chat-models

    Raises:
        ValueError: if the config has neither an absolute threshold nor a complete
            set of relative settings (norm_scalar, threshhold_multiplier, base_threshhold)
            for `model_name`.
    """

    absolute_threshhold = get_absolute_threshhold(model_name)

    # Define the threshold value and the activations scale, depending on the threshold type
    if absolute_threshhold is not None:
        # Absolute case: compare the raw activation norms against the configured value.
        threshold = absolute_threshhold
        scaled_act = act_tensor
    else: # relative threshold case
        relative_settings = {
            "norm_scalar": get_norm_scalar(model_name),
            "threshhold_multiplier": get_threshold_multiplier(model_name),
            "base_threshhold": get_base_threshhold(model_name),
        }
        missing = [name for name, value in relative_settings.items() if value is None]
        if missing:
            # Report the gap explicitly instead of failing on `None * None` below.
            raise ValueError(
                f"No outlier threshold configured for model {model_name!r}: "
                f"missing {', '.join(missing)} and no absolute_threshold"
            )

        threshold = relative_settings["threshhold_multiplier"] * relative_settings["base_threshhold"]
        scaled_act = relative_settings["norm_scalar"] * act_tensor

    # One norm per position: reduce over the D_MODEL dimension.
    scaled_act_norms = torch.norm(scaled_act, dim=-1)

    return scaled_act_norms > threshold

# Main outlier filtering method
def filter_activations(acts, model_name=BASE_MODEL, return_mask=False):
    """
    Drop the activations whose norm marks them as outliers and return the rest.

    Args:
        acts (torch.Tensor): Activations of shape [BATCH, SEQ, D_MODEL].
        model_name (str): Model whose outlier threshold is used (see `is_act_outlier`).
        return_mask (bool): If True, also return the [BATCH, SEQ] boolean mask of retained positions.

    Returns:
        torch.Tensor: Retained activations, flattened over batch and sequence, of shape
            [N_VALID_ACTIVATIONS, D_MODEL] with N_VALID_ACTIVATIONS <= BATCH * SEQ.
        torch.Tensor (only if return_mask=True): Boolean mask of shape [BATCH, SEQ];
            True where the activation was retained, False where it was filtered out.

    Notes:
        - The batch/sequence structure is lost in the returned activations; use the
          mask to map rows back to their original positions.
    """
    # True at every [BATCH, SEQ] position that is NOT an outlier
    keep_mask = ~is_act_outlier(acts, model_name)

    # Indexing with the 2D mask selects whole D_MODEL rows: [N_VALID_ACTIVATIONS, D_MODEL]
    kept_acts = acts[keep_mask]

    if not return_mask:
        return kept_acts
    return kept_acts, keep_mask

#### Eval functions

In [12]:
# @title
import torch
import torch.nn.functional as F
from enum import Enum
import numpy as np
from scipy.stats import gamma
import os
from dotenv import load_dotenv

# Load environment variables from the .env file
# NOTE(review): no code visible in this notebook reads these variables directly;
# presumably they are consumed by imported helpers (e.g. saetuning.utils) — confirm.
load_dotenv()

#### Quantitative SAE evaluation ####
def L0_loss(x, threshold=1e-8):
    """
    Mean number of active features per token (the L0 "loss").

    Expects `x` of shape [N_TOKENS, N_SAE]. A feature counts as active on a token
    when its value is strictly greater than `threshold`. Returns a scalar tensor:
    the per-token count of active features, averaged over all tokens.
    """
    is_active = x > threshold
    active_per_token = is_active.float().sum(dim=-1)
    return active_per_token.mean()

def get_substitution_loss(tokens, model, sae, sae_layer, reconstruction_metric=None):
    '''
    Expects a tensor of input tokens of shape [N_BATCHES, N_CONTEXT].

    Args:
        tokens: Input tokens, [N_BATCHES, N_CONTEXT].
        model: Hooked model exposing `run_with_cache` and `run_with_hooks`.
        sae: SAE whose `forward` maps activations to their reconstructions.
        sae_layer (str): Name of the hook point the SAE is applied at.
        reconstruction_metric: Optional metric object; if given, its `update` is called with the
            flattened float reconstructions and the flattened float original activations.

    Returns two losses:
    1. Clean loss - loss of the normal forward pass of the model at the input tokens.
    2. Substitution loss - loss when substituting SAE reconstructions of the residual stream at the SAE layer of the model.

    Activations flagged as outliers by `filter_activations` are neither reconstructed nor replaced.
    '''
    # Run the model with cache to get the original activations and clean loss
    loss_clean, cache = model.run_with_cache(tokens, names_filter=[sae_layer], return_type="loss")

    original_activations = cache[sae_layer]
    # Remember the model's activation dtype: the reconstructions must match it when
    # they are written into the residual stream (fp16, bf16 and fp32 models alike).
    model_act_dtype = original_activations.dtype

    # Apply activation filtering
    activations_filtered, filter_mask = filter_activations(original_activations, return_mask=True)
    # Shape of activations_filtered is now [valid_activations, d_model]

    # Get the SAE reconstructed activations
    post_reconstructed = sae.forward(activations_filtered)  # shape [valid_activations, d_model]

    # Update the reconstruction quality metric, if provided. Test identity rather
    # than truthiness so the check does not depend on how the metric defines bool/len.
    if reconstruction_metric is not None:
        reconstruction_metric.update(post_reconstructed.flatten().float(), activations_filtered.flatten().float())

    # Free unused variables early to save memory
    del original_activations, activations_filtered, cache
    clear_cache()

    def hook_function(activations, hook, new_activations, filter_mask):
        # activations: [batch_size, seq_len, d_model]
        # filter_mask: [batch_size, seq_len]
        # new_activations: [valid_activations, d_model]

        # Flatten batch and sequence so the mask selects whole d_model rows
        activations_flat = activations.reshape(-1, activations.shape[-1])
        filter_mask_flat = filter_mask.reshape(-1)

        # Replace the retained (non-outlier) positions with their reconstructions
        activations_flat[filter_mask_flat] = new_activations

        # Return the activations in their original shape
        return activations_flat.view(activations.shape)

    # Cast to the model's own dtype because we'll splice it in to the model
    post_reconstructed = post_reconstructed.to(model_act_dtype)

    # Run the model again with hooks to substitute activations at the SAE layer
    loss_reconstructed = model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(sae_layer, partial(hook_function, new_activations=post_reconstructed, filter_mask=filter_mask))]
    )

    # Clean up the reconstructed activations and clear memory
    del post_reconstructed
    clear_cache()

    return loss_clean, loss_reconstructed

import plotly.graph_objs as go
from functools import partial

def plot_log10_hist(y_data, y_value, num_bins=100, first_bin_name = 'First bin value',
                    y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
    """
    Plot a histogram of log10(|y_data| + log_epsilon) with Plotly.

    The histogram is computed with PyTorch over `num_bins` equal-width bins spanning
    the min..max of the log-transformed data. The y-axis upper limit is `y_scalar`
    times the bin count at position `y_scale_bin` of the ascending-sorted counts
    (-2, the default, is the second-largest bin), so one dominant bin does not
    flatten the rest. The count of the first (lowest) bin is shown in the title
    and in an annotation, labelled `first_bin_name`.

    Args:
        y_data (torch.Tensor): Values to plot; any shape (flattened internally).
        y_value (str): Name of the plotted quantity, used in the title and x-axis label.
    """
    # Log-transform the flattened data on the CPU
    log_values = torch.log10(torch.abs(torch.flatten(y_data)) + log_epsilon).detach().cpu()

    # Equal-width bins over the observed range
    lowest = torch.min(log_values).item()
    highest = torch.max(log_values).item()
    value_span = highest - lowest
    edges = torch.linspace(lowest, highest, num_bins + 1)
    counts, _ = torch.histogram(log_values, bins=edges)

    # Plotly works with NumPy arrays
    edges_np = edges.detach().cpu().numpy()
    counts_np = counts.detach().cpu().numpy()

    first_bin_value = counts_np[0]
    # Bin count used to scale the y-axis (second largest by default)
    scale_bin_value = sorted(counts_np)[y_scale_bin]

    # One bar per bin, placed at the bin's left edge (the last edge is dropped)
    bars = go.Bar(
        x=edges_np[:-1],
        y=counts_np,
        width=value_span / num_bins,
    )
    fig = go.Figure(data=[bars])

    layout_options = dict(
        title=f"SAE Features {y_value} histogram ({first_bin_name}: {first_bin_value:.2e})",
        xaxis_title=f"Log10 of {y_value}",
        yaxis_title="Density",
        yaxis_range=[0, scale_bin_value * y_scalar],
        bargap=0.2,
        bargroupgap=0.1,
    )
    fig.update_layout(**layout_options)

    # Boxed annotation in the top-right corner showing the first bin's count
    annotation_options = dict(
        text=f"{first_bin_name}: {first_bin_value:.2e}",
        xref="paper", yref="paper",
        x=0.95, y=0.95,
        showarrow=False,
        font=dict(size=12, color="red"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )
    fig.add_annotation(**annotation_options)

    fig.show()

class FeatureDensityPlotter:
    """Accumulates per-feature activation densities over batches and plots their histogram.

    A feature's density is the fraction of tokens on which its activation exceeds
    `activation_threshold`.
    """

    def __init__(self, n_features, n_tokens, activation_threshold=1e-10, num_bins=100):
        """
        Args:
            n_features (int): Number of SAE features tracked.
            n_tokens (int): Total number of tokens that will be passed to `update`
                over all calls; used as the denominator of the densities.
            activation_threshold (float): A feature is active when its value is above this.
            num_bins (int): Default number of histogram bins used by `plot`.
        """
        self.num_bins = num_bins
        self.activation_threshold = activation_threshold

        self.n_tokens = n_tokens
        self.n_features = n_features

        # Running feature densities for all features (kept on the CPU),
        # where feature density is defined as the fraction of tokens on which the feature has a nonzero value.
        self.feature_densities = torch.zeros(n_features, dtype=torch.float32)

    def update(self, feature_acts):
        """
        Expects a tensor feature_acts of shape [N_TOKENS, N_FEATURES].

        Updates the feature_densities buffer:
        1. For each feature, count the number of tokens that the feature activated on (i.e. had an activation greater than the activation_threshold)
        2. Add this count at the feature's position in the feature_densities tensor, divided by the total number of tokens (to compute the fraction)
        """

        activating_tokens_count = (feature_acts > self.activation_threshold).float().sum(0)
        # Move the counts to the buffer's device so that GPU activations can be passed in directly.
        activating_tokens_count = activating_tokens_count.to(self.feature_densities.device)
        self.feature_densities += activating_tokens_count / self.n_tokens

    def plot(self, num_bins=None, y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
        """Plot the log10 histogram of the accumulated densities.

        `num_bins` defaults to the value given to the constructor; the remaining
        arguments are forwarded unchanged to `plot_log10_hist`.
        """
        if num_bins is None:
            num_bins = self.num_bins
        plot_log10_hist(self.feature_densities, 'Density', num_bins=num_bins, first_bin_name='Dead features density',
                        y_scalar=y_scalar, y_scale_bin=y_scale_bin, log_epsilon=log_epsilon)

### Task 4.1 Pretrained case

In [13]:
# Load the base (not fine-tuned) model as a HookedSAETransformer in float16.
base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Loaded pretrained model gpt2-small into HookedTransformer


In [14]:
# import the required libraries
from sae_lens import SAE

# Hook point the SAE is attached to (built from the same parts as SAE_HOOK in the config cell).
sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}'

# Load the pretrained SAE for that hook point; from_pretrained also returns the
# SAE's config dictionary and a sparsity object.
sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)
# Display the SAE config as the cell output.
cfg_dict

blocks.6.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


{'model_name': 'gpt2-small',
 'hook_point': 'blocks.6.hook_resid_pre',
 'hook_point_layer': 6,
 'hook_point_head_index': None,
 'dataset_path': 'Skylion007/openwebtext',
 'is_dataset_tokenized': False,
 'context_size': 128,
 'use_cached_activations': False,
 'cached_activations_path': 'activations/Skylion007_openwebtext/gpt2-small/blocks.6.hook_resid_pre',
 'd_in': 768,
 'n_batches_in_buffer': 128,
 'total_training_tokens': 300000000,
 'store_batch_size': 32,
 'device': 'cuda',
 'seed': 42,
 'dtype': 'torch.float32',
 'b_dec_init_method': 'geometric_median',
 'expansion_factor': 32,
 'from_pretrained_path': None,
 'l1_coefficient': 8e-05,
 'lr': 0.0004,
 'lr_scheduler_name': None,
 'lr_warm_up_steps': 5000,
 'train_batch_size': 4096,
 'use_ghost_grads': False,
 'feature_sampling_window': 1000,
 'feature_sampling_method': None,
 'resample_batches': 1028,
 'feature_reinit_scale': 0.2,
 'dead_feature_window': 5000,
 'dead_feature_estimation_method': 'no_fire',
 'dead_feature_threshold': 1

In [15]:
# The SAE's activation function determines which sae.encode_xxx method applies;
# the loops below call sae.encode_standard, so check the value here.
cfg_dict["activation_fn_str"]

'relu'

In [16]:
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=batch_size_prompts,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

# Number of tokens in one batch of prompts: context length times prompts per batch.
batch_size_tokens = activation_store.context_size * batch_size_prompts

# Display both batch sizes as the cell output.
batch_size_prompts, batch_size_tokens

Downloading builder script:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.35k [00:00<?, ?B/s]



(20, 2560)

#### 4.1.1 L0 loss

In [17]:
from tqdm import tqdm

# Per-batch L0 values of the base model's SAE features.
all_L0 = []

total_batches = get_batch_size(Experiment.L0_LOSS)
# Start from a fresh list rather than the stored sample: after the first run the
# store holds a tensor (see the end of this cell), which has no `.append`.
all_tokens_L0 = []

for k in tqdm(range(total_batches)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]

    # Store tokens for later reuse
    all_tokens_L0.append(tokens)

    # Run the model only up to the SAE layer and cache the activations at the hook point
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1,
                                         names_filter=[sae_id])

    # Get the activations from the cache at the sae_id
    activations_original = cache[sae_id]  # [N_BATCH, N_CONTEXT, D_MODEL]
    # NOTE: outlier filtering (filter_activations) is not applied in this experiment.

    # Encode the model activations at the hook point into SAE feature activations
    feature_activations = sae.encode_standard(activations_original)

    # Store this batch's L0 value
    all_L0.append(L0_loss(feature_activations))

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del activations_original
    del feature_activations
    clear_cache()

# Concatenate all tokens into a single tensor for reuse
set_tokens_sample(Experiment.L0_LOSS, torch.cat(all_tokens_L0))  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

  0%|          | 0/50 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 50/50 [00:13<00:00,  3.73it/s]


In [18]:
# Average the per-batch L0 values over all batches.
torch.tensor(all_L0).mean()

tensor(49.4706)

#### 4.1.2 Substitution Loss

In [17]:
from tqdm import tqdm
from torcheval.metrics import R2Score
# R2 score accumulated over all batches between SAE reconstructions and original activations.
sae_reconstruction_metric = R2Score().to(device)

# Per-batch losses of the clean forward pass and of the pass with SAE reconstructions spliced in.
all_SL_clean = []
all_SL_reconstructed = []

total_batches = get_batch_size(Experiment.SUBSTITUTION_LOSS)
# Start from a fresh list rather than the stored sample: after the first run the
# store holds a tensor (see the end of this cell), which has no `.append`.
all_tokens_SL = []

for k in tqdm(range(total_batches)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    # Store tokens for later reuse
    all_tokens_SL.append(tokens)

    clean_loss, reconstructed_loss = get_substitution_loss(tokens, base_model, sae, sae_id, sae_reconstruction_metric)

    all_SL_clean.append(clean_loss)
    all_SL_reconstructed.append(reconstructed_loss)

# Concatenate all tokens into a single tensor for reuse
set_tokens_sample(Experiment.SUBSTITUTION_LOSS, torch.cat(all_tokens_SL))  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 50/50 [00:38<00:00,  1.31it/s]


In [18]:
# Mean clean loss and mean substitution loss over all batches, shown as a tuple.
print('Clean vs substitution loss:')
torch.tensor(all_SL_clean).mean().item(), torch.tensor(all_SL_reconstructed).mean().item()

Clean vs substitution loss:


(3.552734375, 3.6953125)

In [19]:
# Final value of the R2 score accumulated in the substitution-loss loop,
# reported here as the variance explained by the SAE.
print('Varience explained by SAE: ')
sae_reconstruction_metric.compute().item()

Varience explained by SAE: 


0.93325275182724

#### 4.1.3 Feature activations histogram

In [20]:
# SAE feature activations of every batch, kept on the CPU.
all_feature_acts = []

total_batches = get_batch_size(Experiment.FEATURE_ACTS)
# Start from a fresh list rather than the stored sample: after the first run the
# store holds a tensor (see the end of this cell), which has no `.append`.
all_histogram_tokens = []

for k in tqdm(range(total_batches)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    all_histogram_tokens.append(tokens)

    # Run the model only up to the SAE layer and cache the activations at the hook point
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1,
                                         names_filter=[sae_id])

    # Get the activations from the cache at the sae_id
    activations_original = cache[sae_id]  # [N_BATCH, N_CONTEXT, D_MODEL]
    # NOTE: outlier filtering (filter_activations) is not applied in this experiment.

    # Encode the model activations at the hook point into SAE feature activations
    feature_activations = sae.encode_standard(activations_original)
    # Move them off the accelerator so the accumulated list does not fill device memory
    feature_activations = feature_activations.to('cpu')

    # Store the encoded activations
    all_feature_acts.append(feature_activations)

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del activations_original
    del feature_activations
    clear_cache()

set_tokens_sample(Experiment.FEATURE_ACTS, torch.cat(all_histogram_tokens))  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 10/10 [00:10<00:00,  1.08s/it]


In [21]:
# Merge the per-batch activations into one tensor and plot the log10 histogram of their values.
# NOTE(review): this rebinds `all_feature_acts` from a list to a tensor, so the cell
# presumably cannot be re-run without re-running the collection loop first — confirm.
all_feature_acts = torch.cat(all_feature_acts)
plot_log10_hist(all_feature_acts, 'activations')

In [22]:
# The activations are no longer needed once plotted: drop them and reclaim memory.
del all_feature_acts
clear_cache()

#### 4.1.4 Feature density histogram

In [23]:
# Start from a fresh list rather than the stored sample: after the first run the
# store holds a tensor (see the end of this cell), which has no `.append`.
all_histogram_tokens = []
total_batches = get_batch_size(Experiment.FEATURE_DENSITY)

# Denominator of the densities: total number of tokens processed over all batches.
total_tokens = total_batches * batch_size_tokens
n_features = sae.cfg.d_sae

density_plotter = FeatureDensityPlotter(n_features, total_tokens)

for k in tqdm(range(total_batches)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    all_histogram_tokens.append(tokens)

    # Run the model only up to the SAE layer and cache the activations at the hook point
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1,
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Flatten batch and sequence into one token axis and convert to float32
    # for more accurate density computation
    activations_original = cache[sae_id].flatten(0, 1).float()  # [N_BATCH * N_CONTEXT, D_MODEL]
    # NOTE: outlier filtering (filter_activations) is not applied in this experiment.

    # Encode the model activations at the hook point into SAE feature activations
    feature_activations = sae.encode_standard(activations_original)
    feature_activations = feature_activations.to('cpu')

    # Update the density histogram data
    density_plotter.update(feature_activations)

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del activations_original
    del feature_activations
    clear_cache()

set_tokens_sample(Experiment.FEATURE_DENSITY, torch.cat(all_histogram_tokens))  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 50/50 [00:49<00:00,  1.00it/s]


In [24]:
# Plot the feature-density histogram; the y-axis is capped at twice the second-largest bin.
density_plotter.plot(y_scalar=2, y_scale_bin=-2)

In [25]:
from pathlib import Path
# Save the computed feature densities
base_feature_densities = density_plotter.feature_densities
# File name encodes the base model and the dataset the densities were computed on.
base_feature_densities_fname = f'Feature_densities_{saving_name_base}_on_{saving_name_ds}.pt'

saving_path = datapath / base_feature_densities_fname
torch.save(base_feature_densities, saving_path)

In [21]:
# Save the token samples

def save_tokens_sample(tokens_dict, save_dir):
    """Save every tensor-valued entry of `tokens_dict` to its own file in `save_dir`.

    Keys are expected to have a `.name` (e.g. Experiment members). The file name is
    built from the base-model name, the key name and the dataset name. Entries that
    are not tensors (e.g. still-empty lists) are skipped with a printed warning.
    """
    for experiment, sample in tokens_dict.items():
        if not isinstance(sample, torch.Tensor):
            print(f"WARNING: Unsupported saving type: {type(sample)} for key {experiment.name}; Skipping this object.")
            continue

        filename = f"{saving_name_base}_{experiment.name}_tokens_{saving_name_ds}.pt"
        path = os.path.join(save_dir, filename)

        torch.save(sample, path)
        print(f"Saved {experiment.name} tensor to {path}")

save_tokens_sample(TOKENS_SAMPLE, datapath)

Saved L0_LOSS tensor to /content/drive/My Drive/sae_data/gpt2-small_L0_LOSS_tokens_openwebtext.pt


In [20]:
# The base model and its activation store are no longer needed: drop them
# and reclaim memory before the fine-tuned model is loaded.
del base_model, activation_store
clear_cache()

### Task 4.2 FineTuned case

In [13]:
# Load the finetune model and its tokenizer.
# A configured FINETUNE_PATH takes precedence over the FINETUNE_MODEL id;
# the model is wrapped in float16, like the base model.
finetune_tokenizer, finetune_model = load_hf_model(FINETUNE_PATH if FINETUNE_PATH is not None else FINETUNE_MODEL,
                                                   device=device, dtype=torch.float16)

tokenizer_config.json:   0%|          | 0.00/92.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/850k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/508k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/120 [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/510M [00:00<?, ?B/s]



Loaded pretrained model gpt2-small into HookedTransformer


In [14]:
# import the required libraries
from sae_lens import SAE

# Set to False to skip loading and keep an SAE that is already in memory.
LOAD_SAE = True

if LOAD_SAE:
  # Hook point of the SAE; hook_part comes from the config cell ('post' for Gemma, 'pre' otherwise).
  sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}' # Gemma is post,
  sae, cfg_dict, sparsity = SAE.from_pretrained(
                              release = RELEASE,
                              sae_id = sae_id,
                              device = device
  )
  # NOTE(review): inside the `if` block this bare expression is not displayed as cell output.
  cfg_dict

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [15]:
# Set to False to keep the token samples that are already in memory.
LOAD_TOKENS = True

if LOAD_TOKENS:
    # Load the stored token samples
    def load_tokens_sample(tokens_dict, save_dir):
        """Fill `tokens_dict` in place with the token tensors previously saved in `save_dir`.

        For each key, the file named after the base model, the key name and the
        dataset is loaded if it exists; otherwise the entry is left untouched and a
        message is printed.
        """
        for key in tokens_dict.keys():
            filename = f"{saving_name_base}_{key.name}_tokens_{saving_name_ds}.pt"
            path = os.path.join(save_dir, filename)

            if os.path.exists(path):
                # weights_only=True restricts unpickling to tensors and plain containers,
                # which is all these files hold; map_location places the tensor on the
                # current device even if it was saved from a different one.
                tokens_dict[key] = torch.load(path, map_location=device, weights_only=True)
                print(f"Loaded {key.name} tensor from {path}")
            else:
                print(f"File {path} does not exist. Skipping...")

    # Reset the store, then fill in whatever samples exist on disk.
    TOKENS_SAMPLE = {
        Experiment.SUBSTITUTION_LOSS: [],
        Experiment.L0_LOSS: [],
        Experiment.FEATURE_ACTS: [],
        Experiment.FEATURE_DENSITY: []
    }
    load_tokens_sample(TOKENS_SAMPLE, datapath)

    # Accessors for the token store (same definitions as in the config section).
    def get_tokens_sample(key: Experiment):
        """Return whatever is currently stored for `key` in TOKENS_SAMPLE."""
        return TOKENS_SAMPLE[key]

    def set_tokens_sample(key: Experiment, token_sample):
        """Replace the stored token sample for `key` with `token_sample`."""
        TOKENS_SAMPLE[key] = token_sample

# Display the store as the cell output.
TOKENS_SAMPLE

File /content/drive/My Drive/sae_data/gpt2-small_SUBSTITUTION_LOSS_tokens_openwebtext.pt does not exist. Skipping...
Loaded L0_LOSS tensor from /content/drive/My Drive/sae_data/gpt2-small_L0_LOSS_tokens_openwebtext.pt
File /content/drive/My Drive/sae_data/gpt2-small_FEATURE_ACTS_tokens_openwebtext.pt does not exist. Skipping...
File /content/drive/My Drive/sae_data/gpt2-small_FEATURE_DENSITY_tokens_openwebtext.pt does not exist. Skipping...


  tokens_dict[key] = torch.load(path)


{<Experiment.SUBSTITUTION_LOSS: 'SubstitutionLoss'>: [],
 <Experiment.L0_LOSS: 'L0_loss'>: tensor([[50256, 13924,    12,  ...,   379,   262,  4436],
         [50256,   351,  3126,  ...,  1074,   468,  9167],
         [50256,   262, 21402,  ..., 34064,  4436, 21596],
         ...,
         [50256,   447,   251,  ...,  7082,   284,   262],
         [50256,  2683,    11,  ...,   621,   674,  2717],
         [50256,  3259,    13,  ...,   564,   250,    40]], device='cuda:0'),
 <Experiment.FEATURE_ACTS: 'FeatureActs'>: [],
 <Experiment.FEATURE_DENSITY: 'FeatureDensity'>: []}

#### 4.2.1 L0 loss

In [16]:
from tqdm import tqdm

# One L0 value per batch, computed on the SAE encoding of the finetuned model's activations
all_L0 = []

total_batches = get_batch_size(Experiment.L0_LOSS)
all_tokens_L0 = get_tokens_sample(Experiment.L0_LOSS)
print('Tokens count: ', all_tokens_L0.numel())

for k in tqdm(range(total_batches)):
    # Batch k covers prompt rows [start_idx, end_idx) of the stored sample
    start_idx = k * batch_size_prompts
    end_idx = start_idx + batch_size_prompts
    tokens = all_tokens_L0[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Forward pass that only caches the hook point the SAE reads from
    _, cache = finetune_model.run_with_cache(
        tokens,
        stop_at_layer=layer_num + 1,
        names_filter=[sae_id],
    )
    activations_original = cache[sae_id]  # [N_BATCH, N_CONTEXT, D_MODEL]

    # SAE encoding of the cached activations, reduced to this batch's L0 value
    feature_activations = sae.encode_standard(activations_original)
    all_L0.append(L0_loss(feature_activations))

    # Drop the per-batch tensors before the next iteration and release cached memory
    del cache, activations_original, feature_activations
    clear_cache()

Tokens count:  128000


100%|██████████| 50/50 [00:13<00:00,  3.76it/s]


In [17]:
# Average of the per-batch L0 values
torch.mean(torch.tensor(all_L0))

tensor(74.1430)

In [18]:
# Run the garbage collector and empty the CUDA cache before the next experiment
clear_cache()

#### 4.2.2 Substitution Loss

In [19]:
from tqdm import tqdm
from torcheval.metrics import R2Score

# R^2 metric handed to get_substitution_loss on every batch
# (presumably updated there with original vs. reconstructed activations - TODO confirm)
sae_reconstruction_metric = R2Score().to(device)

# Per-batch losses: plain model vs. model with the SAE spliced in
all_SL_clean = []
all_SL_reconstructed = []

total_batches = get_batch_size(Experiment.SUBSTITUTION_LOSS)
all_tokens_SL = get_tokens_sample(Experiment.SUBSTITUTION_LOSS)
print('Tokens count: ', all_tokens_SL.numel())

for k in tqdm(range(total_batches)):
    # Batch k covers prompt rows [start_idx, end_idx) of the stored sample
    start_idx = k * batch_size_prompts
    end_idx = start_idx + batch_size_prompts
    tokens = all_tokens_SL[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Both losses for this batch come from a single helper call
    clean_loss, reconstructed_loss = get_substitution_loss(
        tokens, finetune_model, sae, sae_id, sae_reconstruction_metric
    )
    all_SL_clean.append(clean_loss)
    all_SL_reconstructed.append(reconstructed_loss)

Tokens count:  128000


100%|██████████| 50/50 [00:31<00:00,  1.58it/s]


In [20]:
print('Clean vs substitution loss:')
# (mean clean loss, mean reconstructed loss) over all batches
tuple(torch.tensor(losses).mean().item() for losses in (all_SL_clean, all_SL_reconstructed))

Clean vs substitution loss:


(10.203125, 10.078125)

In [21]:
# Final value of the R2Score metric that was passed to get_substitution_loss for every batch
print('Variance explained by SAE: ')
sae_reconstruction_metric.compute().item()

Varience explained by SAE: 


0.8161449432373047

In [22]:
# Per-batch substitution losses as one tensor; display them in ascending order with batch indices
loss_reconstructed_tensor = torch.tensor(all_SL_reconstructed)
torch.sort(loss_reconstructed_tensor)

torch.return_types.sort(
values=tensor([ 8.7266,  9.6094,  9.8047,  9.8047,  9.8203,  9.8672,  9.9062,  9.9297,
         9.9531,  9.9688,  9.9688,  9.9922, 10.0312, 10.0312, 10.0312, 10.0312,
        10.0312, 10.0469, 10.0469, 10.0547, 10.0625, 10.0781, 10.0859, 10.0938,
        10.1172, 10.1328, 10.1328, 10.1328, 10.1328, 10.1406, 10.1562, 10.1562,
        10.1562, 10.1641, 10.1641, 10.1719, 10.1953, 10.1953, 10.1953, 10.2266,
        10.2422, 10.2500, 10.2578, 10.2734, 10.2891, 10.3281, 10.3438, 10.3828,
        10.4062, 10.4688], dtype=torch.float16),
indices=tensor([33, 31, 49, 30, 23, 11,  2, 35, 10, 41, 37,  3, 47, 13, 26, 36, 18, 42,
        12, 29, 43, 48, 32, 28,  5, 38, 40, 25, 19, 20, 44, 39, 16,  8, 14,  6,
         0, 22, 24, 34, 15,  1, 27, 17, 45,  4, 46,  7,  9, 21]))

In [23]:
# Keep only finite losses (drops NaN as well as +/-inf) so a single
# degenerate batch cannot turn the reported mean into NaN or inf
filtered_loss_reconstructed = loss_reconstructed_tensor[torch.isfinite(loss_reconstructed_tensor)]
print(f'Filtered substitution loss = {filtered_loss_reconstructed.mean().item()}')

Filtered substitution loss = 10.078125


#### 4.2.3 Feature activations histogram

In [24]:
from tqdm import tqdm

# SAE feature activations of every batch, kept on the CPU for the histogram
all_feature_acts = []

total_batches = get_batch_size(Experiment.FEATURE_ACTS)
all_histogram_tokens = get_tokens_sample(Experiment.FEATURE_ACTS)
print('Tokens count: ', all_histogram_tokens.numel())

for k in tqdm(range(total_batches)):
    # Batch k covers prompt rows [start_idx, end_idx) of the stored sample
    start_idx = k * batch_size_prompts
    end_idx = start_idx + batch_size_prompts
    tokens = all_histogram_tokens[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Forward pass that only caches the hook point the SAE reads from
    _, cache = finetune_model.run_with_cache(
        tokens,
        stop_at_layer=layer_num + 1,
        names_filter=[sae_id],
    )
    activations_original = cache[sae_id]  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Encode with the SAE, then move the result off the accelerator before storing it
    feature_activations = sae.encode_standard(activations_original)
    feature_activations = feature_activations.to('cpu')
    all_feature_acts.append(feature_activations)

    # Drop the per-batch names (the stored CPU tensor stays in the list) and release cached memory
    del cache, activations_original, feature_activations
    clear_cache()

Tokens count:  25600


100%|██████████| 10/10 [00:04<00:00,  2.05it/s]


In [25]:
# Concatenate the per-batch activations only for the plot. `all_feature_acts` stays the
# list built by the loop, so this cell can be re-run (rebinding the name to the
# concatenated tensor made a second run fail inside torch.cat).
plot_log10_hist(torch.cat(all_feature_acts), 'activations')

In [26]:
# The collected feature activations are no longer needed once plotted;
# drop the reference and reclaim memory
del all_feature_acts
clear_cache()

In [27]:
# Sanity check: number of token ids in the sample used for the activations histogram
all_histogram_tokens.numel()

25600

#### 4.2.4 Feature density histogram

In [28]:
all_histogram_tokens = get_tokens_sample(Experiment.FEATURE_DENSITY)
total_batches = get_batch_size(Experiment.FEATURE_DENSITY)

print('Tokens count: ', all_histogram_tokens.numel())

# The plotter is sized by the SAE width and the number of tokens in the sample
total_tokens = all_histogram_tokens.numel()
n_features = sae.cfg.d_sae

density_plotter = FeatureDensityPlotter(n_features, total_tokens)

for k in tqdm(range(total_batches)):
    # Batch k covers prompt rows [start_idx, end_idx) of the stored sample
    start_idx = k * batch_size_prompts
    end_idx = start_idx + batch_size_prompts
    tokens = all_histogram_tokens[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Forward pass that only caches the hook point the SAE reads from
    _, cache = finetune_model.run_with_cache(
        tokens,
        stop_at_layer=layer_num + 1,
        names_filter=[sae_id],
    )  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Merge the batch and context dimensions into one token axis and cast to float32
    activations_original = cache[sae_id].flatten(0, 1).float()

    # Encode with the SAE, then move the result off the accelerator
    feature_activations = sae.encode_standard(activations_original)
    feature_activations = feature_activations.to('cpu')

    # Feed this batch into the running density statistics
    density_plotter.update(feature_activations)

    # Drop the per-batch tensors before the next iteration and release cached memory
    del cache, activations_original, feature_activations
    clear_cache()

Tokens count:  128000


100%|██████████| 50/50 [00:55<00:00,  1.10s/it]


In [29]:
# Plot the densities gathered through the density_plotter.update(...) calls above.
# NOTE(review): y_scalar / y_scale_bin are FeatureDensityPlotter.plot options whose exact
# meaning is defined with that class - confirm there before changing these values.
density_plotter.plot(y_scalar=1.5, y_scale_bin=-1)

In [30]:
# File name kept consistent with saetuning/get_scores.py
finetune_feature_densities_fname = f'Feature_densities_{saving_name_ft}_on_{saving_name_ds}.pt'
saving_path = datapath / finetune_feature_densities_fname

# Persist the feature densities computed by the plotter for the finetuned model
finetune_feature_densities = density_plotter.feature_densities
torch.save(finetune_feature_densities, saving_path)