In [None]:
%pip install sae-lens transformer-lens torcheval



In [None]:
# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops

# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# Utility to garbage-collect unreferenced variables and clear the CUDA cache
import gc
def clear_cache():
    """Free memory that is no longer referenced.

    Runs the Python garbage collector first (so tensors without remaining
    references are actually destroyed) and then asks the accelerator backend
    to release its cached blocks.

    Returns:
        None
    """
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # The device-selection code in this notebook prefers "mps" when available,
    # so release the MPS allocator cache as well. `torch.mps.empty_cache` only
    # exists in newer torch releases, hence the feature check.
    mps_module = getattr(torch, "mps", None)
    if torch.backends.mps.is_available() and hasattr(mps_module, "empty_cache"):
        mps_module.empty_cache()

Device: cuda


In [None]:
# Select the model family to analyse; every constant below is derived from this switch.
MODEL = 'MISTRAL' # supported values: GEMMA, GPT2, MISTRAL

# Per-family settings:
#   RELEASE             - SAE release name passed to SAE.from_pretrained
#   BASE_MODEL          - base model name passed to HookedSAETransformer.from_pretrained
#   FINETUNE_MODEL      - identifier of the fine-tuned checkpoint
#   FINETUNE_PATH       - optional local path; when not None it is used instead of FINETUNE_MODEL
#   DATASET_NAME        - dataset identifier (used in the saved file names)
#   BASE_TOKENIZER_NAME - tokenizer identifier of the base model
#   hook_part/layer_num - which residual-stream hook the SAE is attached to
if MODEL == 'GEMMA':
    RELEASE = 'gemma-2b-res-jb'
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    FINETUNE_PATH = None
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'post'
    layer_num = 6
elif MODEL == 'GPT2':
    RELEASE = 'gpt2-small-res-jb'
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    FINETUNE_PATH = None
    DATASET_NAME = "Skylion007/openwebtext"
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'pre'
    layer_num = 6
elif MODEL == 'MISTRAL':
    RELEASE = 'mistral-7b-res-wg'
    BASE_MODEL = "mistral-7b"
    DATASET_NAME = "monology/pile-uncopyrighted"
    BASE_TOKENIZER_NAME = 'mistralai/Mistral-7B-v0.1'

    FINETUNE_MODEL = 'meta-math/MetaMath-Mistral-7B' #DeepMount00/Mistral-Ita-7b
    FINETUNE_PATH = '/content/drive/My Drive/Finetunes/MetaMath-Mistral-7B'

    hook_part = 'pre'
    layer_num = 8
else:
    # Fail here with a clear message instead of a NameError further down the notebook.
    raise ValueError(f"Unsupported MODEL {MODEL!r}; expected 'GEMMA', 'GPT2' or 'MISTRAL'")

SAE_HOOK = f'blocks.{layer_num}.hook_resid_{hook_part}'

In [None]:
from enum import Enum

class Experiment(Enum):
    """Identifiers of the evaluation experiments run in this notebook."""
    SUBSTITUTION_LOSS = 'SubstitutionLoss'
    L0_LOSS = 'L0_loss'
    FEATURE_ACTS = 'FeatureActs'
    FEATURE_DENSITY = 'FeatureDensity'

# Number of token batches processed for each experiment.
TOTAL_BATCHES = {
    Experiment.SUBSTITUTION_LOSS: 25,
    Experiment.L0_LOSS: 50,
    Experiment.FEATURE_ACTS: 25,
    Experiment.FEATURE_DENSITY: 50,
}

# Token sample recorded per experiment; every experiment starts with its own empty list.
TOKENS_SAMPLE = {experiment: [] for experiment in Experiment}

def get_batch_size(key: Experiment):
    """Return the TOTAL_BATCHES entry (number of batches) registered for `key`."""
    n_batches = TOTAL_BATCHES[key]
    return n_batches

def get_tokens_sample(key: Experiment):
    """Return whatever token sample is currently stored for `key`."""
    stored_sample = TOKENS_SAMPLE[key]
    return stored_sample

def set_tokens_sample(key: Experiment, token_sample):
    """Replace the stored token sample of `key` with `token_sample`."""
    TOKENS_SAMPLE.update({key: token_sample})

### Utils

#### Loading finetune model

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def adjust_state_dict(model, base_model_vocab_size):
    """Truncate the vocabulary dimension of a model's (un)embedding weights.

    Args:
        model: Any object exposing ``state_dict()`` that returns a mapping of
            parameter names to tensors.
        base_model_vocab_size (int): Maximum number of vocabulary rows to keep.

    Returns:
        The mapping returned by ``model.state_dict()``, in which
        ``'model.embed_tokens.weight'`` and ``'lm_head.weight'`` are cut down
        to their first ``base_model_vocab_size`` rows whenever they have more.
        Entries that are missing or already small enough are left untouched.
    """
    state_dict = model.state_dict()

    # Embedding and unembedding (lm_head) matrices are the only tensors whose
    # first dimension is the vocabulary size.
    for weight_name in ('model.embed_tokens.weight', 'lm_head.weight'):
        weight = state_dict.get(weight_name)
        if weight is None:
            # Architectures that name these tensors differently are skipped
            # instead of failing with a KeyError.
            continue
        if weight.shape[0] > base_model_vocab_size:
            state_dict[weight_name] = weight[:base_model_vocab_size, :]

    return state_dict

def load_hf_model(path, base_model=BASE_MODEL, device='cuda', dtype=None, base_model_vocab_size=32000):
    """Load a fine-tuned Hugging Face checkpoint and wrap it in a HookedSAETransformer.

    Args:
        path: Model identifier or local directory passed to ``from_pretrained``
            for both the tokenizer and the causal LM.
        base_model: Name of the base architecture handed to
            ``HookedSAETransformer.from_pretrained``.
        device: Device on which the wrapped model is placed.
        dtype: Optional dtype forwarded to ``HookedSAETransformer.from_pretrained``.
        base_model_vocab_size (int): Vocabulary size the fine-tuned model is
            resized to. The default keeps the value that used to be hard-coded
            here (presumably the Mistral vocabulary - confirm before using this
            with the GEMMA or GPT2 configurations).

    Returns:
        tuple: ``(tokenizer, finetune_model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)

    # Snapshot the weights with (un)embedding rows beyond the base vocabulary cut off
    adjusted_state_dict = adjust_state_dict(model, base_model_vocab_size)

    # Adjust model architecture to match the new vocab size
    model.resize_token_embeddings(base_model_vocab_size)

    # Load the adjusted state dict back into the resized model
    model.load_state_dict(adjusted_state_dict, strict=False)

    # Now load the fine-tuned weights into the HookedSAETransformer
    finetune_model = HookedSAETransformer.from_pretrained(
        base_model, device=device, hf_model=model, dtype=dtype
    )

    # The HF model is no longer needed once it is wrapped into finetune_model
    del model
    clear_cache()

    return tokenizer, finetune_model

#### Activations filtering utility

In [None]:
import json
import sys
import os

try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

def load_outliers_config(filename='outlier_cfg.json'):
    """
    Load the outlier-filtering configuration.

    In Google Colab, Google Drive is mounted and `filename` is read as JSON from
    '/content/drive/My Drive'; if the file does not exist, a message is printed
    and None is returned.
    Outside Colab, the parent of the current working directory is appended to
    sys.path (the notebook is assumed to run from the 'notebooks' folder) and
    OUTLIERS_CFG is imported from saetuning.utils.
    """
    if IN_COLAB:
        # Colab: the config lives on the user's Google Drive
        from google.colab import drive
        drive.mount('/content/drive')

        file_path = os.path.join('/content/drive/My Drive', filename)
        print(f"Loading JSON file from Google Drive: {file_path}")

        try:
            with open(file_path, 'r') as file:
                return json.load(file)
        except FileNotFoundError:
            print(f"File not found: {file_path}")
            return None

    # Local run: make the repository root importable and use the packaged config
    repo_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
    sys.path.append(repo_root)

    from saetuning.utils import OUTLIERS_CFG
    return OUTLIERS_CFG

In [None]:
OUTLIERS_CFG = load_outliers_config()

def get_norm_scalar(model_name):
    """Return the "norm_scalar" entry of OUTLIERS_CFG for `model_name`, or None if absent."""
    per_model = OUTLIERS_CFG.get("norm_scalar", {})
    return per_model.get(model_name)

def get_threshold_multiplier(model_name):
    """Return the "threshhold_multiplier" entry of OUTLIERS_CFG for `model_name`, or None if absent."""
    per_model = OUTLIERS_CFG.get("threshhold_multiplier", {})
    return per_model.get(model_name)

def get_base_threshhold(model_name):
    """Return the "base_threshhold" entry of OUTLIERS_CFG for `model_name`, or None if absent."""
    per_model = OUTLIERS_CFG.get("base_threshhold", {})
    return per_model.get(model_name)

def get_absolute_threshhold(model_name):
    """Return the "absolute_threshold" entry of OUTLIERS_CFG for `model_name`, or None if absent."""
    per_model = OUTLIERS_CFG.get("absolute_threshold", {})
    return per_model.get(model_name)

# Auxiliary function that returns a mask of outlier activations
def is_act_outlier(act_tensor, model_name):
    """
    Expects act_tensor of shape [*, D_MODEL]

    Returns a boolean tensor of shape [*], where for each batch position we report whether the corresponding activation
    exceeds the outlier threshold. The threshold is defined in one of two ways:

    1. If an absolute threshold is configured for the model:
       threshold = norm_scalar * absolute_threshold
       (since the activations are scaled by norm_scalar as well, this compares the raw activation norm
       with absolute_threshold for a positive norm_scalar).
    2. Otherwise:
       threshold = threshold_multiplier * base_threshold, where base_threshold = sqrt(D_MODEL)

    Important! The threshold value is in the normalized scale, i.e. is meant to be used for activations that are scaled
    in such a way, that their average norm is equal to sqrt(D_MODEL). To do this normalization, we multiply by norm_scalar
    of the corresponding model.

    Raises:
        ValueError: if the configuration has no norm_scalar for the model, or has neither an absolute threshold
        nor a (threshold_multiplier, base_threshold) pair for it.

    Check this blog-post for more details: https://www.lesswrong.com/posts/fmwk6qxrpW8d4jvbd/saes-usually-transfer-between-base-and-chat-models
    """
    norm_scalar = get_norm_scalar(model_name)
    if norm_scalar is None:
        # Without this check the failure would be an opaque `None * tensor` TypeError below
        raise ValueError(f"No norm_scalar configured for model {model_name!r}")

    absolute_threshhold = get_absolute_threshhold(model_name)

    if absolute_threshhold:
        threshold = norm_scalar * absolute_threshhold
    else:
        threshold_multiplier = get_threshold_multiplier(model_name)
        base_threshold = get_base_threshhold(model_name)
        if threshold_multiplier is None or base_threshold is None:
            raise ValueError(f"No outlier threshold configured for model {model_name!r}")
        threshold = threshold_multiplier * base_threshold

    # Compare norms in the normalized scale
    scaled_act_norms = torch.norm(norm_scalar * act_tensor, dim=-1)

    return scaled_act_norms > threshold

OUTLIERS_CFG

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Loading JSON file from Google Drive: /content/drive/My Drive/outlier_cfg.json


{'norm_scalar': {'google/gemma-2b': 0.31278620989943556,
  'gpt2-small': 0.27139524668485193,
  'mistral-7b': 14.178454680291779},
 'threshhold_multiplier': {'google/gemma-2b': 2, 'gpt2-small': 2},
 'base_threshhold': {'google/gemma-2b': 45.254833995939045,
  'gpt2-small': 27.712812921102035},
 'absolute_threshold': {'mistral-7b': 200}}

In [None]:
def filter_activations(acts, model_name=BASE_MODEL, return_mask=False):
    """
    Filters out activations based on outlier norms and returns the filtered activations.

    Args:
        acts (torch.Tensor): A tensor of activations with shape [BATCH, SEQ, D_MODEL].
        model_name (str): The name of the model used to determine the threshold for filtering out outlier activations.
        return_mask (bool): If True, also returns the 2D boolean mask indicating which activations were retained.

    Returns:
        torch.Tensor: A tensor of filtered activations with shape [N_VALID_ACTIVATIONS, D_MODEL], where N_VALID_ACTIVATIONS <= BATCH * SEQ.
        torch.Tensor (optional): A 2D boolean tensor of shape [BATCH, SEQ], True where the activation was retained and False where it was filtered out.

    Notes:
        - Outliers are identified by `is_act_outlier`.
        - The returned activations are flattened across the batch and sequence dimensions. If the original
          [BATCH, SEQ] structure is needed afterwards, reconstruct it outside this function using the mask.
    """
    # True for every position whose activation is NOT an outlier
    filter_mask = ~is_act_outlier(acts, model_name)  # [BATCH, SEQ]

    # Indexing with a mask over the leading dimensions keeps whole D_MODEL rows,
    # which directly yields the flattened [N_VALID_ACTIVATIONS, D_MODEL] tensor.
    filtered_acts = acts[filter_mask]

    if return_mask:
        return filtered_acts, filter_mask
    return filtered_acts

#### Score functions definition (copy from saetuning/utils.py)

In [None]:
# @title
import torch
import torch.nn.functional as F
from enum import Enum
import numpy as np
from scipy.stats import gamma
import os
from dotenv import load_dotenv

# Load environment variables from the .env file
load_dotenv()

#### Quantitative SAE evaluation ####
def L0_loss(x, threshold=1e-8):
    """
    Average number of active features per token (L0 loss).

    Expects a tensor x of shape [N_TOKENS, N_SAE]. A feature counts as active
    on a token when its value is strictly greater than `threshold`.

    Returns a scalar tensor: the per-token count of active features along the
    last dimension, averaged over all remaining positions.
    """
    is_active = x > threshold
    active_per_token = is_active.float().sum(dim=-1)
    return active_per_token.mean()

def get_substitution_loss(tokens, model, sae, sae_layer, reconstruction_metric=None):
    '''
    Expects a tensor of input tokens of shape [N_BATCHES, N_CONTEXT].

    Args:
        tokens: Input token ids.
        model: Model exposing `run_with_cache` and `run_with_hooks`.
        sae: Sparse autoencoder; `sae.forward` is applied to the filtered activations.
        sae_layer: Name of the hook point whose activations are reconstructed.
        reconstruction_metric: Optional metric object; when given, its `update` method is called with the
            flattened reconstructions and the flattened (filtered) original activations.

    Returns two losses:
    1. Clean loss - loss of the normal forward pass of the model at the input tokens.
    2. Substitution loss - loss when substituting SAE reconstructions of the residual stream at the SAE layer of the model.

    NOTE(review): outlier positions are dropped by `filter_activations`, and the second pass is run on the
    remaining tokens as a single flat 1-D tensor, so the two losses are computed on differently shaped inputs -
    confirm that this is the intended comparison.
    '''
    # Imported locally so the function does not rely on a module-level import placed after its definition
    from functools import partial

    # Run the model with cache to get the original activations and clean loss
    loss_clean, cache = model.run_with_cache(tokens, names_filter=[sae_layer], return_type="loss")

    # Fetch the original activations
    original_activations = cache[sae_layer]

    # Apply activation filtering
    activations_filtered, filter_mask = filter_activations(original_activations, return_mask=True)
    # Shape of activations_filtered is now [valid_activations, d_model]

    # Filter the tokens using the same mask
    tokens_filtered = tokens[filter_mask].reshape(activations_filtered.shape[0]) # shape [valid_activations]

    # Get the SAE reconstructed activations
    post_reconstructed = sae.forward(activations_filtered) # shape [valid_activations, d_model]

    # Update the reconstruction quality metric, if provided.
    # Compared against None explicitly so the check does not depend on the metric's truthiness.
    if reconstruction_metric is not None:
        reconstruction_metric.update(post_reconstructed.flatten().float(), activations_filtered.flatten().float())

    # Free unused variables early to save memory
    del original_activations, activations_filtered, cache
    clear_cache()

    # Hook function to substitute activations with SAE reconstructions
    def hook_function(activations, hook, new_activations):
        activations.copy_(new_activations)  # Perform in-place substitution of activations
        return activations

    # Run the model again with hooks to substitute activations at the SAE layer
    loss_reconstructed = model.run_with_hooks(
        tokens_filtered,
        return_type="loss",
        fwd_hooks=[(sae_layer, partial(hook_function, new_activations=post_reconstructed))]
    )

    # Clean up the reconstructed activations and clear memory
    del post_reconstructed
    clear_cache()

    return loss_clean, loss_reconstructed

import plotly.graph_objs as go
from functools import partial

def plot_log10_hist(y_data, y_value, num_bins=100, first_bin_name = 'First bin value',
                    y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
    """
    Computes a histogram of log10(|y_data| + log_epsilon) with PyTorch and plots it with Plotly.

    Args:
        y_data (torch.Tensor): Values to plot; flattened before use.
        y_value (str): Name of the plotted quantity, used in the title and x-axis label.
        num_bins (int): Number of equal-width histogram bins between the smallest and largest log value.
        first_bin_name (str): Label used when reporting the count of the first (lowest) bin.
        y_scalar (float): Factor applied to the reference bin count to obtain the upper y-axis limit.
        y_scale_bin (int): Index into the ascending-sorted bin counts selecting the reference bin
            (-2, the default, is the second-largest bin). Out-of-range values are clamped.
        log_epsilon (float): Added before taking the logarithm so that zeros stay finite.

    The y-axis is clipped relative to the reference bin so that one dominant bin does not suppress
    the smaller ones; the count of the first bin is reported in the title and in an annotation.
    """
    # Flatten and move to log scale
    log_y_data_flat = torch.log10(torch.abs(torch.flatten(y_data)) + log_epsilon).detach().cpu()

    # Histogram range; widen it when all values coincide so the bin edges stay strictly increasing
    hist_min = torch.min(log_y_data_flat).item()
    hist_max = torch.max(log_y_data_flat).item()
    if hist_max <= hist_min:
        hist_max = hist_min + 1.0
    hist_range = hist_max - hist_min
    bin_edges = torch.linspace(hist_min, hist_max, num_bins + 1)
    hist_counts, _ = torch.histogram(log_y_data_flat, bins=bin_edges)

    # Convert data to NumPy for Plotly
    bin_edges_np = bin_edges.detach().cpu().numpy()
    hist_counts_np = hist_counts.detach().cpu().numpy()

    first_bin_value = hist_counts_np[0]

    # Reference count for the y-axis limit (second-largest bin by default)
    sorted_counts = sorted(hist_counts_np)
    scale_index = max(-len(sorted_counts), min(y_scale_bin, len(sorted_counts) - 1))
    scale_bin_value = sorted_counts[scale_index]
    if scale_bin_value <= 0:
        # An empty reference bin would collapse the y-axis to [0, 0]; use the largest bin instead
        scale_bin_value = sorted_counts[-1]

    # Prepare the Plotly plot
    fig = go.Figure(
        data=[go.Bar(
            x=bin_edges_np[:-1],  # Left edge of every bin (the last edge is excluded)
            y=hist_counts_np,
            width=hist_range / num_bins,
        )]
    )

    # Update the layout for the plot, clipping the y-axis relative to the reference bin
    fig.update_layout(
        title=f"SAE Features {y_value} histogram ({first_bin_name}: {first_bin_value:.2e})",
        xaxis_title=f"Log10 of {y_value}",
        yaxis_title="Density",
        yaxis_range=[0, scale_bin_value * y_scalar],
        bargap=0.2,
        bargroupgap=0.1,
    )

    # Add an annotation to display the value of the first bin
    fig.add_annotation(
        text=f"{first_bin_name}: {first_bin_value:.2e}",
        xref="paper", yref="paper",
        x=0.95, y=0.95,
        showarrow=False,
        font=dict(size=12, color="red"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    # Show the plot
    fig.show()

class FeatureDensityPlotter:
    """Accumulates per-feature activation densities over batches and plots their histogram.

    The density of a feature is the fraction of tokens on which its activation
    exceeds `activation_threshold`.
    """

    def __init__(self, n_features, n_tokens, activation_threshold=1e-10, num_bins=100):
        """
        Args:
            n_features (int): Number of SAE features tracked.
            n_tokens (int): Total number of tokens that will be passed through `update`
                over all calls; used as the denominator of the density.
            activation_threshold (float): A feature is active on a token when its value is above this.
            num_bins (int): Default number of histogram bins used by `plot`.
        """
        self.num_bins = num_bins
        self.activation_threshold = activation_threshold

        self.n_tokens = n_tokens
        self.n_features = n_features

        # Running density of every feature, accumulated batch by batch
        self.feature_densities = torch.zeros(n_features, dtype=torch.float32)

    def update(self, feature_acts):
        """
        Expects a tensor feature_acts of shape [N_TOKENS, N_FEATURES]; extra leading
        dimensions are flattened into the token dimension.

        Updates the feature_densities buffer:
        1. For each feature, count the tokens on which its activation is greater than activation_threshold
        2. Add this count, divided by the total number of tokens, at the feature's position in feature_densities
        """
        token_major = feature_acts.reshape(-1, feature_acts.shape[-1])
        activating_tokens_count = (token_major > self.activation_threshold).float().sum(0)

        # Move the counts next to the buffer so inputs living on another device can be accumulated
        activating_tokens_count = activating_tokens_count.to(self.feature_densities.device)
        self.feature_densities += activating_tokens_count / self.n_tokens

    def plot(self, num_bins=None, y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
        """Plot the log10 histogram of the accumulated densities.

        `num_bins` falls back to the value given to the constructor when omitted;
        the remaining arguments are forwarded to `plot_log10_hist`.
        """
        if num_bins is None:
            num_bins = self.num_bins

        plot_log10_hist(self.feature_densities, 'Density', num_bins=num_bins, first_bin_name='Dead features density',
                        y_scalar=y_scalar, y_scale_bin=y_scale_bin, log_epsilon=log_epsilon)

### Task 4.1 Pretrained case

In [None]:
base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model mistral-7b into HookedTransformer


In [None]:
# import the required libraries
from sae_lens import SAE

sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}'

sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)
cfg_dict

{'d_in': 4096,
 'd_sae': 65536,
 'dtype': 'float32',
 'device': 'cuda',
 'model_name': 'mistral-7b',
 'hook_name': 'blocks.8.hook_resid_pre',
 'hook_layer': 8,
 'hook_head_index': None,
 'activation_fn_str': 'relu',
 'apply_b_dec_to_input': False,
 'finetuning_scaling_factor': False,
 'sae_lens_training_version': None,
 'prepend_bos': False,
 'dataset_path': 'monology/pile-uncopyrighted',
 'context_size': 256,
 'normalize_activations': 'constant_norm_rescale',
 'dataset_trust_remote_code': True,
 'architecture': 'standard',
 'neuronpedia': None}

In [None]:
# this must be checked for the forward method of sae.encode_xxx
cfg_dict["activation_fn_str"]

'relu'

In [None]:
from sae_lens import ActivationsStore
batch_size_prompts = 5

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=batch_size_prompts,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

batch_size_tokens = activation_store.context_size * batch_size_prompts

batch_size_prompts, batch_size_tokens

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]



(5, 1280)

#### 4.1.1 L0 loss

In [None]:
from tqdm import tqdm

all_L0 = []

total_batches = get_batch_size(Experiment.L0_LOSS)
# Collect the batches in a fresh list: once this cell has run, the stored sample
# is a single tensor (see the end of the cell) and can no longer be appended to,
# so re-running the cell would otherwise fail.
all_tokens_L0 = []

for k in tqdm(range(total_batches)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]

    # Store tokens for later reuse
    all_tokens_L0.append(tokens)

    # Run the model only up to the SAE layer and cache the activations at the hook point
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])

    # Get the activations from the cache at the sae_id
    activations_original = cache[sae_id]  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(activations_original)

    # Store the L0 loss of this batch
    all_L0.append(L0_loss(feature_activations))

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del activations_original
    del feature_activations
    torch.cuda.empty_cache()

# Concatenate all tokens into a single tensor for reuse
set_tokens_sample(Experiment.L0_LOSS, torch.cat(all_tokens_L0))  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 50/50 [00:11<00:00,  4.19it/s]


In [None]:
torch.tensor(all_L0).mean()

tensor(85.3981)

#### 4.1.2 Substitution Loss

In [None]:
from tqdm import tqdm
from torcheval.metrics import R2Score
sae_reconstruction_metric = R2Score().to(device)

all_SL_clean = []
all_SL_reconstructed = []

total_batches = get_batch_size(Experiment.SUBSTITUTION_LOSS)
# Collect the batches in a fresh list: once this cell has run, the stored sample
# is a single tensor (see the end of the cell) and can no longer be appended to,
# so re-running the cell would otherwise fail.
all_tokens_SL = []

for k in tqdm(range(total_batches)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    # Store tokens for later reuse
    all_tokens_SL.append(tokens)

    # Clean loss and loss with SAE reconstructions substituted; the metric is updated as a side effect
    clean_loss, reconstructed_loss = get_substitution_loss(tokens, base_model, sae, sae_id, sae_reconstruction_metric)

    all_SL_clean.append(clean_loss)
    all_SL_reconstructed.append(reconstructed_loss)

# Concatenate all tokens into a single tensor for reuse
set_tokens_sample(Experiment.SUBSTITUTION_LOSS, torch.cat(all_tokens_SL))  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 25/25 [00:42<00:00,  1.68s/it]


In [None]:
print('Clean vs substitution loss:')
torch.tensor(all_SL_clean).mean().item(), torch.tensor(all_SL_reconstructed).mean().item()

Clean vs substitution loss:


(1.7880859375, 2.4296875)

In [None]:
print('Varience explained by SAE: ')
sae_reconstruction_metric.compute().item()

Varience explained by SAE: 


0.6823176145553589

#### 4.1.3 Feature activations histogram

In [None]:
all_feature_acts = []

total_batches = get_batch_size(Experiment.FEATURE_ACTS)
# Collect the batches in a fresh list: once this cell has run, the stored sample
# is a single tensor (see the end of the cell) and can no longer be appended to,
# so re-running the cell would otherwise fail.
all_histogram_tokens = []

for k in tqdm(range(total_batches)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    all_histogram_tokens.append(tokens)

    # Run the model only up to the SAE layer and cache the activations at the hook point
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])

    # Get the activations from the cache at the sae_id (model activations, not SAE features)
    activations_original = cache[sae_id] # [N_BATCH, N_CONTEXT, D_MODEL]

    # Encode the activations with the SAE and move the result off the accelerator
    feature_activations = sae.encode_standard(activations_original)
    feature_activations = feature_activations.to('cpu')

    # Store the encoded activations
    all_feature_acts.append(feature_activations)

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del activations_original
    del feature_activations
    torch.cuda.empty_cache()

set_tokens_sample(Experiment.FEATURE_ACTS, torch.cat(all_histogram_tokens))  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 25/25 [00:11<00:00,  2.19it/s]


In [None]:
all_feature_acts = torch.cat(all_feature_acts)
plot_log10_hist(all_feature_acts, 'activations')

In [None]:
del all_feature_acts
clear_cache()

#### 4.1.4 Feature density histogram

In [None]:
# Collect the batches in a fresh list: once this cell has run, the stored sample
# is a single tensor (see the end of the cell) and can no longer be appended to,
# so re-running the cell would otherwise fail.
all_histogram_tokens = []
total_batches = get_batch_size(Experiment.FEATURE_DENSITY)

# Denominator of the densities: every batch is expected to contribute batch_size_tokens tokens
total_tokens = total_batches * batch_size_tokens
n_features = sae.cfg.d_sae

density_plotter = FeatureDensityPlotter(n_features, total_tokens)

for k in tqdm(range(total_batches)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    all_histogram_tokens.append(tokens)

    # Run the model only up to the SAE layer and cache the activations at the hook point
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])

    # Merge batch and context dimensions and convert to float32 for more accurate density computation
    activations_original = cache[sae_id].flatten(0, 1).float() # [N_BATCH * N_CONTEXT, D_MODEL]

    # Encode the activations with the SAE and move the result off the accelerator
    feature_activations = sae.encode_standard(activations_original)
    feature_activations = feature_activations.to('cpu')

    # Update the density histogram data
    density_plotter.update(feature_activations)

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del activations_original
    del feature_activations
    torch.cuda.empty_cache()

set_tokens_sample(Experiment.FEATURE_DENSITY, torch.cat(all_histogram_tokens))  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 50/50 [00:29<00:00,  1.72it/s]


In [None]:
density_plotter.plot(y_scalar=2, y_scale_bin=-2)

In [None]:
# Persist the feature densities measured on the base model
base_feature_densities = density_plotter.feature_densities

# File names use only the last path component of each identifier
# (naming kept consistent with saetuning/get_scores.py)
saving_name_base = BASE_MODEL.split("/")[-1]
saving_name_ft = FINETUNE_MODEL.split("/")[-1]
saving_name_ds = DATASET_NAME.split("/")[-1]

base_feature_densities_fname = f'Feature_densities_{saving_name_base}_on_{saving_name_ds}.pt'

if not IN_COLAB:
    # Local run: store the file in the data directory reported by get_env_var
    from saetuning.utils import get_env_var
    _, datapath = get_env_var()

    saving_path = datapath / base_feature_densities_fname
else:
    # Colab: store the file in the current working directory
    saving_path = f'./{base_feature_densities_fname}'

torch.save(base_feature_densities, saving_path)

In [None]:
del base_model, activation_store
clear_cache()

### Task 4.2 FineTuned case

In [None]:
# Load the finetune model and its tokenizer
finetune_tokenizer, finetune_model = load_hf_model(FINETUNE_PATH if FINETUNE_PATH is not None else FINETUNE_MODEL,
                                                   device=device, dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]



Loaded pretrained model mistral-7b into HookedTransformer


In [None]:
# import the required libraries
from sae_lens import SAE

sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}' # Gemma is post,

sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)
cfg_dict

#### 4.2.1 L0 loss

In [None]:
from tqdm import tqdm

# Evaluate on exactly the tokens recorded during the base-model L0 experiment
total_batches = get_batch_size(Experiment.L0_LOSS)
all_tokens_L0 = get_tokens_sample(Experiment.L0_LOSS)

all_L0 = []

for k in tqdm(range(total_batches)):
    # Rows of the stored sample that make up the k-th batch
    start_idx = k * batch_size_prompts
    end_idx = start_idx + batch_size_prompts
    tokens = all_tokens_L0[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Run the fine-tuned model only up to the SAE layer, caching the hook point
    _, cache = finetune_model.run_with_cache(
        tokens,
        stop_at_layer=layer_num + 1,
        names_filter=[sae_id],
    )
    activations_original = cache[sae_id]  # [N_BATCH, N_CONTEXT, D_MODEL]

    # SAE feature activations and their L0 loss for this batch
    feature_activations = sae.encode_standard(activations_original)
    all_L0.append(L0_loss(feature_activations))

    # Release the per-batch tensors before the next iteration
    del cache, activations_original, feature_activations
    torch.cuda.empty_cache()

100%|██████████| 50/50 [00:11<00:00,  4.50it/s]


In [None]:
torch.tensor(all_L0).mean()

tensor(91.9437)

In [None]:
clear_cache()

#### 4.2.2 Substitution Loss

In [None]:
from tqdm import tqdm

# Fresh metric so the fine-tuned run does not mix with the base-model statistics
sae_reconstruction_metric = R2Score().to(device)

# Evaluate on exactly the tokens recorded during the base-model substitution-loss experiment
total_batches = get_batch_size(Experiment.SUBSTITUTION_LOSS)
all_tokens_SL = get_tokens_sample(Experiment.SUBSTITUTION_LOSS)

all_SL_clean = []
all_SL_reconstructed = []

for k in tqdm(range(total_batches)):
    # Rows of the stored sample that make up the k-th batch
    start_idx = k * batch_size_prompts
    end_idx = start_idx + batch_size_prompts
    tokens = all_tokens_SL[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Clean loss and loss with SAE reconstructions substituted; the metric is updated as a side effect
    clean_loss, reconstructed_loss = get_substitution_loss(
        tokens, finetune_model, sae, sae_id, sae_reconstruction_metric
    )
    all_SL_clean.append(clean_loss)
    all_SL_reconstructed.append(reconstructed_loss)

100%|██████████| 25/25 [00:41<00:00,  1.67s/it]


In [None]:
print('Clean vs substitution loss:')
torch.tensor(all_SL_clean).mean().item(), torch.tensor(all_SL_reconstructed).mean().item()

Clean vs substitution loss:


(1.935546875, 3.013671875)

In [None]:
print('Varience explained by SAE: ')
sae_reconstruction_metric.compute().item()

Varience explained by SAE: 


0.6007424592971802

In [None]:
loss_reconstructed_tensor = torch.tensor(all_SL_reconstructed)
loss_reconstructed_tensor.sort()

torch.return_types.sort(
values=tensor([0.6885, 1.6035, 2.0840, 2.4258, 2.5215, 2.5312, 2.5898, 2.6113, 2.6250,
        2.7148, 3.0801, 3.0938, 3.1211, 3.2070, 3.3574, 3.4609, 3.5410, 3.6094,
        3.6133, 3.7832, 3.7949, 3.8086, 3.8203, 3.8262, 3.8301],
       dtype=torch.float16),
indices=tensor([12, 11, 20, 19,  0, 24,  6, 23, 10,  8, 22, 13, 21, 18,  4, 15, 16,  3,
        14,  9,  5,  2, 17,  1,  7]))

In [None]:
# Drop non-finite losses (NaN as well as +/-inf) before averaging, if there are any
filtered_loss_reconstructed = loss_reconstructed_tensor[torch.isfinite(loss_reconstructed_tensor)]
print(f'Filtered substitution loss = {filtered_loss_reconstructed.mean().item()}')

Filtered substitution loss = 3.013671875


#### 4.2.3 Feature activations histogram

In [None]:
from tqdm import tqdm

# Evaluate on exactly the tokens recorded during the base-model feature-activation experiment
total_batches = get_batch_size(Experiment.FEATURE_ACTS)
all_histogram_tokens = get_tokens_sample(Experiment.FEATURE_ACTS)

all_feature_acts = []

for k in tqdm(range(total_batches)):
    # Rows of the stored sample that make up the k-th batch
    start_idx = k * batch_size_prompts
    end_idx = start_idx + batch_size_prompts
    tokens = all_histogram_tokens[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Run the fine-tuned model only up to the SAE layer, caching the hook point
    _, cache = finetune_model.run_with_cache(
        tokens,
        stop_at_layer=layer_num + 1,
        names_filter=[sae_id],
    )
    activations_original = cache[sae_id]  # [N_BATCH, N_CONTEXT, D_MODEL]

    # SAE feature activations, moved off the accelerator before being kept
    feature_activations = sae.encode_standard(activations_original).to('cpu')
    all_feature_acts.append(feature_activations)

    # Release the per-batch tensors before the next iteration
    del cache, activations_original, feature_activations
    torch.cuda.empty_cache()

100%|██████████| 25/25 [00:11<00:00,  2.18it/s]


In [None]:
all_feature_acts = torch.cat(all_feature_acts)
plot_log10_hist(all_feature_acts, 'activations')

#### 4.2.4 Feature density histogram

In [None]:
# Evaluate on exactly the tokens recorded during the base-model feature-density experiment
total_batches = get_batch_size(Experiment.FEATURE_DENSITY)
all_histogram_tokens = get_tokens_sample(Experiment.FEATURE_DENSITY)

# Denominator of the densities and number of tracked SAE features
n_features = sae.cfg.d_sae
total_tokens = total_batches * batch_size_tokens

density_plotter = FeatureDensityPlotter(n_features, total_tokens)

for k in tqdm(range(total_batches)):
    # Rows of the stored sample that make up the k-th batch
    start_idx = k * batch_size_prompts
    end_idx = start_idx + batch_size_prompts
    tokens = all_histogram_tokens[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Run the fine-tuned model only up to the SAE layer, caching the hook point
    _, cache = finetune_model.run_with_cache(
        tokens,
        stop_at_layer=layer_num + 1,
        names_filter=[sae_id],
    )

    # Merge batch and context dimensions and work in float32
    activations_original = cache[sae_id].flatten(0, 1).float()

    # SAE feature activations, moved off the accelerator before updating the densities
    feature_activations = sae.encode_standard(activations_original).to('cpu')
    density_plotter.update(feature_activations)

    # Release the per-batch tensors before the next iteration
    del cache, activations_original, feature_activations
    torch.cuda.empty_cache()

100%|██████████| 50/50 [00:30<00:00,  1.65it/s]


In [None]:
density_plotter.plot(y_scalar=1.5, y_scale_bin=-1)

In [None]:
# Persist the feature densities measured on the fine-tuned model
finetune_feature_densities = density_plotter.feature_densities

# Naming kept consistent with saetuning/get_scores.py
finetune_feature_densities_fname = f'Feature_densities_{saving_name_ft}_on_{saving_name_ds}.pt'

if not IN_COLAB:
    # Local run: store the file in the data directory reported by get_env_var
    from saetuning.utils import get_env_var
    _, datapath = get_env_var()

    saving_path = datapath / finetune_feature_densities_fname
else:
    # Colab: store the file in the current working directory
    saving_path = f'./{finetune_feature_densities_fname}'

torch.save(finetune_feature_densities, saving_path)

In [None]:
total_tokens