In [None]:
try:
    import google.colab
    IN_COLAB = True
    %pip install sae-lens transformer-lens
except ImportError:
    IN_COLAB = False

Collecting sae-lens
  Downloading sae_lens-3.21.1-py3-none-any.whl.metadata (5.1 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.6.0-py3-none-any.whl.metadata (12 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.6-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 (from sae-lens)
  Downloading pytest_profiling-1.7.0-py2.py3-none-any.whl.metadata (12 kB)
Collecting python-dot

In [None]:
# Standard imports
import os
import sys
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops

# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Gradients are switched off globally for the whole notebook.
torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Device preference order: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

# Utility that garbage-collects unreferenced Python objects and releases cached GPU memory
import gc
def clear_cache():
    """
    Free memory held by unreferenced objects and by the accelerator's caching allocator.

    Runs Python's garbage collector first so that tensors which are no longer referenced
    are actually released, then empties the CUDA cache and, because this notebook may
    also select the "mps" device, the MPS cache when that backend is available.

    Returns:
    - None
    """
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # torch.backends.mps / torch.mps are looked up defensively so the helper keeps working
    # on PyTorch builds that do not ship them.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available() and hasattr(torch, "mps"):
        torch.mps.empty_cache()

Device: cuda


# Config

In [None]:
# Define the model to work with. Every branch below sets the same group of constants:
#   RELEASE             - SAE release name (passed to SAE.from_pretrained further down)
#   BASE_MODEL          - base model name (passed to HookedSAETransformer.from_pretrained)
#   FINETUNE_MODEL      - finetuned model identifier (used to build the saving names)
#   FINETUNE_PATH       - local checkpoint path of the finetune (passed to load_hf_model), or None
#   DATASET_NAME        - dataset identifier (used to build the saving names)
#   BASE_TOKENIZER_NAME - tokenizer identifier of the base model
#   hook_part/layer_num - which residual-stream hook the SAE is attached to
MODEL = 'MISTRAL' # GEMMA, GPT2

if MODEL == 'GEMMA':
    RELEASE = 'gemma-2b-res-jb'
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    FINETUNE_PATH = None
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'post'
    layer_num = 6
elif MODEL == 'GPT2':
    RELEASE = 'gpt2-small-res-jb'
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    FINETUNE_PATH = None
    DATASET_NAME = "Skylion007/openwebtext"
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'pre'
    layer_num = 6
elif MODEL == 'MISTRAL':
    RELEASE = 'mistral-7b-res-wg'
    BASE_MODEL = "mistral-7b"
    DATASET_NAME = "monology/pile-uncopyrighted"
    BASE_TOKENIZER_NAME = 'mistralai/Mistral-7B-v0.1'

    FINETUNE_MODEL = 'meta-math/MetaMath-Mistral-7B' #DeepMount00/Mistral-Ita-7b
    # NOTE(review): this is a Google Drive path; it only exists after Drive is mounted in Colab.
    FINETUNE_PATH = '/content/drive/My Drive/Finetunes/MetaMath-Mistral-7B'

    hook_part = 'pre'
    layer_num = 8
else:
    # Fail here with a clear message instead of a NameError on layer_num/hook_part below.
    raise ValueError(f"Unknown MODEL {MODEL!r}; expected 'GEMMA', 'GPT2' or 'MISTRAL'")

# Name of the hook point whose activations are fed to the SAE.
SAE_HOOK = f'blocks.{layer_num}.hook_resid_{hook_part}'

In [None]:
import plotly.graph_objs as go
from functools import partial

def plot_log10_hist(y_data, y_value, num_bins=100, first_bin_name = 'First bin value',
                    y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
    """
    Plot a histogram of log10(|y_data| + log_epsilon) with Plotly.

    The histogram is computed with PyTorch over `num_bins` equal-width bins spanning the
    minimum and maximum of the log-transformed values. The y-axis is clipped relative to
    one of the larger bins (the second-largest by default) so that a single dominant bin
    does not suppress the smaller ones; the count of the first (lowest) bin is reported
    in the title and in an annotation instead.

    Parameters:
    - y_data (torch.Tensor): Values to histogram; flattened before use.
    - y_value (str): Human-readable name of the quantity, used in the title and x-axis label.
    - num_bins (int): Number of histogram bins.
    - first_bin_name (str): Label used when reporting the count of the first bin.
    - y_scalar (float): Multiplier applied to the reference bin count to get the y-axis upper limit.
    - y_scale_bin (int): Index into the ascending-sorted bin counts that selects the reference
      bin (-2 is the second-largest count).
    - log_epsilon (float): Added before taking log10 so that zeros stay finite.

    Returns:
    - None. The figure is displayed with `fig.show()`.
    """
    # Log-transform on the input's device, then move to the CPU for the histogram.
    log_values = torch.log10(torch.abs(torch.flatten(y_data)) + log_epsilon).detach().cpu()

    hist_min = torch.min(log_values).item()
    hist_max = torch.max(log_values).item()
    if hist_max == hist_min:
        # All values are identical: widen the range so bins (and bars) have non-zero width.
        hist_min -= 0.5
        hist_max += 0.5
    hist_range = hist_max - hist_min

    bin_edges = torch.linspace(hist_min, hist_max, num_bins + 1)
    hist_counts, _ = torch.histogram(log_values, bins=bin_edges)

    # Convert data to NumPy for Plotly
    bin_edges_np = bin_edges.numpy()
    hist_counts_np = hist_counts.numpy()

    # Plotly centres each bar on its x value, so use bin centres (not left edges);
    # otherwise the whole histogram is drawn shifted left by half a bin.
    bin_centers_np = (bin_edges_np[:-1] + bin_edges_np[1:]) / 2

    first_bin_value = hist_counts_np[0]
    # Reference count for the y-axis limit (second-largest bin by default).
    scale_bin_value = sorted(hist_counts_np)[y_scale_bin]

    fig = go.Figure(
        data=[go.Bar(
            x=bin_centers_np,
            y=hist_counts_np,
            width=hist_range / num_bins,
        )]
    )

    # Clip the y-axis relative to the reference bin so one huge bin does not flatten the rest.
    fig.update_layout(
        title=f"SAE Features {y_value} histogram ({first_bin_name}: {first_bin_value:.2e})",
        xaxis_title=f"Log10 of {y_value}",
        yaxis_title="Density",
        yaxis_range=[0, scale_bin_value * y_scalar],
        bargap=0.2,
        bargroupgap=0.1,
    )

    # The first bin may be clipped by the y-axis range, so its value is shown explicitly.
    fig.add_annotation(
        text=f"{first_bin_name}: {first_bin_value:.2e}",
        xref="paper", yref="paper",
        x=0.95, y=0.95,
        showarrow=False,
        font=dict(size=12, color="red"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    fig.show()

class FeatureDensityPlotter:
    """
    Accumulates and plots SAE feature densities.

    A feature's density is the fraction of tokens on which its activation exceeds
    `activation_threshold`. The plotter can either accumulate densities batch by batch
    through `update`, or wrap an already computed density tensor (`feature_densities`)
    purely for plotting.
    """

    def __init__(self, n_features=None, n_tokens=None, activation_threshold=1e-10, num_bins=100,
                 feature_densities=None):
        """
        Parameters:
        - n_features (int | None): Number of SAE features. Required unless `feature_densities` is given.
        - n_tokens (int | None): Total number of tokens that will be passed through `update`
          over all batches; used as the denominator of the density. Required for `update`.
        - activation_threshold (float): A feature counts as active on a token when its activation is above this.
        - num_bins (int): Stored on the instance; note that `plot` takes its own `num_bins` argument.
        - feature_densities (torch.Tensor | None): Precomputed densities to wrap instead of a zero buffer.

        Raises:
        - ValueError: if neither `feature_densities` nor `n_features` is provided.
        """
        # Set every attribute on both construction paths so that no method can hit an
        # AttributeError on an instance built from precomputed densities.
        self.num_bins = num_bins
        self.activation_threshold = activation_threshold
        self.n_tokens = n_tokens

        if feature_densities is not None:
            self.feature_densities = feature_densities
            if n_features is None and torch.is_tensor(feature_densities):
                n_features = feature_densities.numel()
            self.n_features = n_features
            return

        if n_features is None:
            raise ValueError("Provide either n_features or precomputed feature_densities")
        self.n_features = n_features

        # Initialize a tensor of feature densities for all features,
        # where feature density is defined as the fraction of tokens on which the feature has a nonzero value.
        self.feature_densities = torch.zeros(n_features, dtype=torch.float32)

    def update(self, feature_acts):
        """
        Expects a tensor feature_acts of shape [N_TOKENS, N_FEATURES].

        Updates the feature_densities buffer:
        1. For each feature, count the number of tokens that the feature activated on (i.e. had an activation greater than the activation_threshold)
        2. Add this count at the feature's position in the feature_densities tensor, divided by the total number of tokens (to compute the fraction)

        Raises:
        - ValueError: if the plotter was created without `n_tokens`.
        """
        if self.n_tokens is None:
            raise ValueError("n_tokens must be set to accumulate densities with update()")

        activating_tokens_count = (feature_acts > self.activation_threshold).float().sum(0)
        # The buffer may live on a different device than the activations (it is created on the CPU).
        batch_density = (activating_tokens_count / self.n_tokens).to(self.feature_densities.device)
        self.feature_densities += batch_density

    def plot(self, num_bins=100, y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
        """Plot the log10 histogram of the stored densities; arguments are forwarded to plot_log10_hist."""
        plot_log10_hist(self.feature_densities, 'Density', num_bins=num_bins, first_bin_name='Dead features density',
                        y_scalar=y_scalar, y_scale_bin=y_scale_bin, log_epsilon=log_epsilon)

# Feature densities loading & plotting

## Base model feature densities

In [None]:
from pathlib import Path

# Choose saving names consistent with saetuning/get_scores.py:
# only the part after the last "/" of each identifier is kept.
saving_name_base = BASE_MODEL.split("/")[-1]
saving_name_ft = FINETUNE_MODEL.split("/")[-1]
saving_name_ds = DATASET_NAME.split("/")[-1]

base_feature_densities_fname = f'Feature_densities_{saving_name_base}_on_{saving_name_ds}.pt'

if IN_COLAB:
    # If in Colab, mount Google Drive and read/write data there
    from google.colab import drive
    drive.mount('/content/drive')
    datapath = Path('/content/drive/My Drive/sae_data')
else:
    # If not in Colab, use local folder
    # Assuming this is being run from the 'notebooks' folder
    project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
    if project_root not in sys.path:  # keep re-runs of this cell from growing sys.path
        sys.path.append(project_root)

    from saetuning.utils import get_env_var
    _, datapath = get_env_var()

# Path of the saved base-model density tensor (.pt)
load_path = datapath / base_feature_densities_fname

# weights_only=True restricts unpickling to tensors/plain containers, which is all this
# file needs, and avoids the arbitrary-code-execution risk of a full pickle load.
base_feature_densities = torch.load(load_path, weights_only=True)

Mounted at /content/drive


  base_feature_densities = torch.load(load_path)


In [None]:
# Total number of entries in the loaded density tensor (displayed as the cell output).
n_features = base_feature_densities.numel()
n_features

65536

In [None]:
# Wrap the precomputed base-model densities (no accumulation needed) and plot their log10 histogram.
density_plotter = FeatureDensityPlotter(feature_densities=base_feature_densities)
density_plotter.plot()

## Finetune model feature densities

In [None]:
finetune_feature_densities_fname = f'Feature_densities_{saving_name_ft}_on_{saving_name_ds}.pt'

if IN_COLAB:
    # If in Colab, mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    COLAB_BASE_PATH = '/content/drive/My Drive/sae_data'

    # Path of the saved finetune-model density tensor (.pt) in Google Drive
    load_path = os.path.join(COLAB_BASE_PATH, finetune_feature_densities_fname)
else:
    # If not in Colab, use local folder
    # Assuming this is being run from the 'notebooks' folder
    project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
    if project_root not in sys.path:  # keep re-runs of this cell from growing sys.path
        sys.path.append(project_root)

    from saetuning.utils import get_env_var
    _, datapath = get_env_var()

    load_path = datapath / finetune_feature_densities_fname

# weights_only=True restricts unpickling to tensors/plain containers, which is all this
# file needs, and avoids the arbitrary-code-execution risk of a full pickle load.
finetune_feature_densities = torch.load(load_path, weights_only=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [None]:
# Same histogram as for the base model, now for the finetuned model's densities.
# Note that this rebinds `density_plotter`, replacing the base-model plotter.
density_plotter = FeatureDensityPlotter(feature_densities=finetune_feature_densities)
density_plotter.plot()

In [None]:
# Replace both density tensors by log10(density + log_epsilon); the epsilon keeps zero
# densities finite (a density of 0 maps to log10(1e-10) = -10).
# NOTE(review): this cell overwrites its own inputs, so it is NOT idempotent -- running it
# a second time applies log10 to the already log-scaled values. Re-run the two loading
# cells before re-running this one.
log_epsilon=1e-10

base_feature_densities = torch.log10(base_feature_densities + log_epsilon)
finetune_feature_densities = torch.log10(finetune_feature_densities + log_epsilon)

In [None]:
import plotly.graph_objs as go
from scipy.stats import linregress

# Convert tensors to NumPy arrays for compatibility with other libraries.
# At this point both tensors hold log10 densities (see the transformation cell above).
base_feature_densities_np = base_feature_densities.cpu().numpy()
finetune_feature_densities_np = finetune_feature_densities.cpu().numpy()

# Least-squares fit of the finetune values against the base values
slope, intercept, r_value, p_value, std_err = linregress(base_feature_densities_np, finetune_feature_densities_np)

# Fitted value for every feature
regression_line = slope * base_feature_densities_np + intercept

# A straight line is fully determined by its two end points, so the line trace is drawn
# from them instead of from one (unsorted) point per feature.
line_x = np.array([base_feature_densities_np.min(), base_feature_densities_np.max()])
line_y = slope * line_x + intercept

# One marker per SAE feature
scatter_trace = go.Scatter(
    x=base_feature_densities_np,
    y=finetune_feature_densities_np,
    mode='markers',
    name='Data points',
    marker=dict(size=5, opacity=0.7)
)

# Regression line
line_trace = go.Scatter(
    x=line_x,
    y=line_y,
    mode='lines',
    name=f'Regression line (R = {r_value:.2f})',
    line=dict(color='red')
)

# Axis titles state explicitly that the plotted values are log10 densities
layout = go.Layout(
    title='Scatter Plot of SAE Features Densities (log10)',
    xaxis=dict(title='Base Model SAE log10 Density'),
    yaxis=dict(title='Finetuned Model SAE log10 Density'),
    showlegend=True
)

# Combine the traces into a figure
fig = go.Figure(data=[scatter_trace, line_trace], layout=layout)

# Show the plot
fig.show()

# Print correlation coefficient
print(f"Correlation coefficient (R): {r_value:.4f}")

Correlation coefficient (R): 0.8622


In [None]:
# Mean of the log10 densities for the base and the finetuned model (the tensors were
# log-transformed in an earlier cell, so these are means in log space).
base_feature_densities.mean(), finetune_feature_densities.mean()

(tensor(-3.8565), tensor(-3.7167))

In [None]:
# Side-by-side table of the two log10 density arrays, one row per SAE feature
# (row index = feature index).
df = pd.DataFrame.from_dict({
    'base_feature_densities': base_feature_densities_np,
    'finetune_feature_densities': finetune_feature_densities_np
})
df

Unnamed: 0,base_feature_densities,finetune_feature_densities
0,-3.630089,-3.630089
1,-3.065817,-3.000000
2,-2.823909,-2.638863
3,-3.660052,-3.527426
4,-3.602060,-3.575731
...,...,...
65531,-10.000000,-10.000000
65532,-3.550907,-3.575731
65533,-2.593992,-2.578293
65534,-3.630089,-3.602060


In [None]:
import plotly.express as px

# Parallel-coordinates view: each feature is a line connecting its base value (left axis)
# to its finetune value (right axis). Note that `fig` is rebound here.
fig = px.parallel_coordinates(df,
                              dimensions=['base_feature_densities', 'finetune_feature_densities']
                             )
fig.show()

# Sampling features from density intervals

In [None]:
import torch

def _sample_at_most(indices, max_samples, generator=None):
    """Return `indices` unchanged if it has at most `max_samples` entries, else a random subset of that size."""
    if len(indices) > max_samples:
        permutation = torch.randperm(len(indices), generator=generator)
        indices = indices[permutation[:max_samples]]
    return indices


def subsample_indices(tensor, total_samples=10, low_threshold=-8, low_percentage=0.2, high_percentage=0.8,
                      log=False, high_interval=(-5, -1), generator=None):
    """
    Randomly pick feature indices from two regions of a log10 density tensor.

    Parameters:
    - tensor (torch.Tensor): 1-D tensor of feature densities. Expected to be on the log10
      scale already, unless `log=True`.
    - total_samples (int): Sample budget that the two fractions below are applied to.
    - low_threshold (float): Features with log10 density <= this value form the lowest
      ("dead") group (default: -8).
    - low_percentage (float): Fraction (0..1) of total_samples taken from the lowest group (default: 0.2).
    - high_percentage (float): Fraction (0..1) of total_samples taken from the high-density interval (default: 0.8).
    - log (bool): If True, `tensor` holds raw densities and is converted with log10 first.
      No epsilon is added, so zero densities become -inf and fall into the lowest group.
    - high_interval (tuple[float, float]): Inclusive (low, high) log10 bounds of the
      high-density group (default: (-5, -1)).
    - generator (torch.Generator | None): Optional random generator for reproducible sampling.

    Returns:
    - (lowest_bar_sample_indices, high_density_sample_indices): two 1-D index tensors.
      Each contains at most int(total_samples * fraction) indices; if a group has fewer
      members than requested, all of its indices are returned in their original order.
    """
    if log:
        # Convert tensor to log10 scale
        tensor = torch.log10(tensor)

    # 1. Lowest bar: everything at or below the threshold
    num_lowest_samples = int(total_samples * low_percentage)
    lowest_bar_sample_indices = _sample_at_most(
        torch.nonzero(tensor <= low_threshold).squeeze(1), num_lowest_samples, generator)

    # 2. High-density interval (inclusive on both ends)
    interval_low, interval_high = high_interval
    num_high_density_samples = int(total_samples * high_percentage)
    high_density_sample_indices = _sample_at_most(
        torch.nonzero((tensor >= interval_low) & (tensor <= interval_high)).squeeze(1),
        num_high_density_samples, generator)

    return lowest_bar_sample_indices, high_density_sample_indices

In [None]:
# Draw a small random sample of "dead" (lowest-bar) and "dense" (high-density interval)
# feature indices from the base model's log10 densities.
# NOTE: the sampling is unseeded, so a re-run selects different features.
dead_base_features, dense_base_features = subsample_indices(base_feature_densities)

# Persist the selected indices so later sections can reload them
torch.save(dead_base_features, datapath / f'{saving_name_base}_dead_features.pt')
torch.save(dense_base_features, datapath / f'{saving_name_base}_dense_features.pt')

# Only the last expression of a cell is displayed, so the selection is shown here
dead_base_features, dense_base_features

(tensor([11108, 25007]),
 tensor([ 8696, 39413, 20898, 19819, 32480, 13560, 14911,  6419]))

In [None]:
# Sanity check: look up the log10 densities of the sampled dead and dense features.
base_feature_densities[dead_base_features], base_feature_densities[dense_base_features]

(tensor([-10., -10.]),
 tensor([-2.8150, -3.8062, -4.8062, -3.8062, -2.5184, -3.3912, -4.5051, -2.4923]))

# Interpreting features from different density intervals

In [None]:
# import the required libraries
from sae_lens import SAE

# Hook-point name identifying which SAE of the release to load; built from the same
# layer_num/hook_part values (and the same pattern) as SAE_HOOK in the config cell.
sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}'

# from_pretrained returns three values here; only `sae` is used by the cells below.
sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)

mistral_7b_layer_8/cfg.json:   0%|          | 0.00/430 [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

## Base model

In [None]:
# Load the base model as a HookedSAETransformer in float16 on the selected device.
base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]



Loaded pretrained model mistral-7b into HookedTransformer


In [None]:
from sae_lens import ActivationsStore

# Number of prompts pulled from the activation store per batch.
batch_size_prompts = 5

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=batch_size_prompts,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

# Tokens per batch = context length of the store * prompts per batch (displayed as the cell output).
batch_size_tokens = activation_store.context_size * batch_size_prompts
batch_size_tokens

Downloading readme:   0%|          | 0.00/776 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]



1280

### Feature activation vectors

In [None]:
def get_feature_activations(features_ids, model, sae=sae, total_batches=50, activation_store=None,
                            batch_size_prompts=batch_size_prompts, sae_id=sae_id, layer_num=layer_num,
                            model_name=BASE_MODEL, exclude_bos=True, tokens_sample=None, return_dtype=torch.float16):
    """
    Collect the SAE activations of selected features over `total_batches` batches of tokens.

    The function has two modes:
    - `tokens_sample is None`: every batch of tokens is drawn from `activation_store` and the
      drawn tokens are returned so that another model can later be run on exactly the same inputs.
    - `tokens_sample` given: batch `k` is `tokens_sample[k]`; nothing is sampled and
      `exclude_bos` is not applied.

    Parameters:
    - features_ids: Indices of the SAE features to keep (used to index the feature dimension).
    - model: Model exposing `run_with_cache(tokens, stop_at_layer=..., names_filter=...)`.
    - sae: SAE exposing `encode`.
    - total_batches (int): Number of batches to process.
    - activation_store: Token source; required when `tokens_sample` is None.
    - batch_size_prompts, model_name: Not used by this function; kept for interface compatibility.
    - sae_id (str): Hook name whose cached activations are fed to the SAE.
    - layer_num (int): The forward pass is stopped right after this layer.
    - exclude_bos (bool): In sampling mode, drop the first token of every prompt when the
      store prepends a BOS token.
    - tokens_sample: Indexable collection of at least `total_batches` token batches.
    - return_dtype (torch.dtype): Dtype the collected activations are stored in.

    Returns:
    - (all_feature_activations, tokens_dataset): activations concatenated over all batches
      along dim 0 ([total tokens, len(features_ids)]), and the stacked sampled token batches
      (or None when `tokens_sample` was given).

    Raises:
    - AssertionError: if neither `tokens_sample` nor `activation_store` is given.
    - ValueError: if `tokens_sample` holds fewer than `total_batches` batches.
    """
    base_model_run = tokens_sample is None

    if base_model_run:
        assert activation_store is not None, "activation_store is required when tokens_sample is not given"
    elif len(tokens_sample) < total_batches:
        # Fail before any forward pass instead of with an IndexError part-way through the loop.
        raise ValueError(
            f"tokens_sample holds {len(tokens_sample)} batches but total_batches={total_batches}"
        )

    def get_tokens(k):
        """Returns the tokens for the k-th batch, where 0 <= k < total_batches"""
        if not base_model_run:
            return tokens_sample[k]  # [N_BATCH, N_CONTEXT]

        tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
        if exclude_bos and activation_store.prepend_bos:
            tokens = tokens[:, 1:]
        return tokens

    all_tokens = []
    all_feature_activations = []

    for k in tqdm(range(total_batches)):
        # Get a batch of tokens from the dataset
        tokens = get_tokens(k)  # [N_BATCH, N_CONTEXT]

        # Run the model only as far as needed and cache only the SAE's hook point
        _, cache = model.run_with_cache(tokens, stop_at_layer=layer_num + 1,
                                        names_filter=[sae_id])

        # Merge the batch and position dimensions
        sae_in = cache[sae_id].flatten(0, 1) # [N_BATCH * N_CONTEXT, D_MODEL]

        del cache
        clear_cache()

        # Store tokens for later reuse
        if base_model_run:
            all_tokens.append(tokens)

        # Encode the activations with the SAE
        sae_hidden = sae.encode(sae_in) # [N_BATCH * N_CONTEXT, N_HIDDEN]

        # Keep only the requested features, already in the output dtype so that the
        # accumulated list never holds the activations at a higher precision than returned.
        feature_activations = sae_hidden[:, features_ids].to(return_dtype) # [N_BATCH * N_CONTEXT, len(feature_ids)]
        all_feature_activations.append(feature_activations)

        # Explicitly free the large intermediates before the next forward pass
        del sae_in, sae_hidden
        clear_cache()

    tokens_dataset = torch.stack(all_tokens, dim=0) if base_model_run else None
    all_feature_activations = torch.cat(all_feature_activations, dim=0)

    return all_feature_activations, tokens_dataset

In [None]:
# Base-model activations of the sampled dead features, on tokens freshly drawn from the
# activation store; the drawn tokens are kept for re-use with the finetuned model.
base_model_dead_act, base_model_dead_act_tokens = get_feature_activations(dead_base_features, base_model,
                                                                          activation_store=activation_store)
base_model_dead_act.shape, base_model_dead_act_tokens.shape

100%|██████████| 50/50 [00:27<00:00,  1.82it/s]


(torch.Size([64000, 2]), torch.Size([50, 5, 256]))

In [None]:
# Same for the sampled dense features. This is a separate call, so it draws its own
# batches from the activation store rather than re-using the dead-feature tokens.
base_model_dense_act, base_model_dense_act_tokens = get_feature_activations(dense_base_features, base_model,
                                                                            activation_store=activation_store)
base_model_dense_act.shape, base_model_dense_act_tokens.shape

100%|██████████| 50/50 [00:26<00:00,  1.89it/s]


(torch.Size([64000, 8]), torch.Size([50, 5, 256]))

In [None]:
# Persist the base-model activations ...
torch.save(base_model_dead_act, datapath / f'{saving_name_base}_dead_act.pt')
torch.save(base_model_dense_act, datapath / f'{saving_name_base}_dense_act.pt')

# ... and the token batches they were computed on, so the finetuned model can be run on the same inputs.
torch.save(base_model_dead_act_tokens, datapath / f'{saving_name_base}_dead_act_tokens.pt')
torch.save(base_model_dense_act_tokens, datapath / f'{saving_name_base}_dense_act_tokens.pt')

### Feature logit vectors

In [None]:
def get_feature_vector(feature_id, sae=sae):
    """Return the row of the SAE decoder matrix `W_dec` selected by `feature_id`."""
    return sae.W_dec[feature_id]

def get_feature_logits(feature_id, model, sae=sae):
    """
    Project one SAE decoder direction through the model's unembedding matrix.

    The unembedding `model.W_U` is cast to float32 before the product.

    Returns:
    - The product `feature_vector @ W_U` for the decoder row selected by `feature_id`.
    """
    unembed_fp32 = model.W_U.to(torch.float32)
    return get_feature_vector(feature_id, sae) @ unembed_fp32

In [None]:
def get_features_logit_vectors(features_ids, model, sae=sae):
    """
    Project several SAE decoder directions through the model's unembedding matrix.

    Parameters:
    - features_ids: Indices of the SAE features (tensor or iterable of ints).
    - model: Model exposing the unembedding matrix as `W_U`.
    - sae: SAE exposing the decoder matrix as `W_dec`.

    Returns:
    - Tensor with one column per requested feature: column j holds
      `sae.W_dec[features_ids[j]] @ W_U`, i.e. the result has shape
      [W_U.shape[1], len(features_ids)].
    """
    if not torch.is_tensor(features_ids):
        features_ids = torch.as_tensor(list(features_ids), dtype=torch.long)

    # Cast the unembedding once and project all features with a single matmul, instead of
    # re-casting the whole matrix for every feature.
    W_U = model.W_U.to(torch.float32)
    feature_vectors = sae.W_dec[features_ids]  # one decoder row per requested feature

    return (feature_vectors @ W_U).T

In [None]:
# Logit vectors of the sampled dense features under the base model's unembedding,
# saved for the comparison section; the shape is displayed as the cell output.
dense_base_features_logit_vectors = get_features_logit_vectors(dense_base_features, base_model)
torch.save(dense_base_features_logit_vectors, datapath / f'{saving_name_base}_dense_logit_vectors.pt')

dense_base_features_logit_vectors.shape

torch.Size([32000, 8])

In [None]:
# Keep an independent copy of the base model's unembedding matrix before the model is
# deleted. torch.Tensor has no .copy() method; .detach().clone() is the tensor equivalent.
base_model_unembed = base_model.W_U.detach().clone()

# Drop the base model and its activation store to free memory for the finetuned model.
del base_model, activation_store
clear_cache()

## Finetune model

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def adjust_state_dict(model, base_model_vocab_size):
    """
    Return the model's state dict with oversized vocabulary matrices cut down.

    For the embedding ('model.embed_tokens.weight') and the unembedding
    ('lm_head.weight') entries, rows beyond `base_model_vocab_size` are dropped;
    entries that already have at most that many rows are left untouched.
    """
    weights = model.state_dict()

    for name in ('model.embed_tokens.weight', 'lm_head.weight'):
        matrix = weights[name]
        if matrix.shape[0] > base_model_vocab_size:
            weights[name] = matrix[:base_model_vocab_size, :]

    return weights

def load_hf_model(path, base_model=BASE_MODEL, device='cuda', dtype=None, base_model_vocab_size=32000):
    """
    Load a finetuned Hugging Face checkpoint and wrap it into a HookedSAETransformer.

    The finetune's embedding and unembedding matrices are cut down to
    `base_model_vocab_size` rows so that its weights fit the base model's architecture.

    Parameters:
    - path: Checkpoint location passed to the Hugging Face `from_pretrained` loaders.
    - base_model (str): Name of the base architecture passed to HookedSAETransformer.from_pretrained.
    - device (str): Device the wrapped model is placed on.
    - dtype (torch.dtype | None): Dtype passed to HookedSAETransformer.from_pretrained.
    - base_model_vocab_size (int): Vocabulary size of the base model (default: 32000).

    Returns:
    - (tokenizer, finetune_model): the finetune's tokenizer and the wrapped model.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)

    # Cut the vocabulary-sized matrices of the state dict down to the base model's vocab size
    adjusted_state_dict = adjust_state_dict(model, base_model_vocab_size)

    # Adjust model architecture to match the new vocab size
    model.resize_token_embeddings(base_model_vocab_size)

    # Load the adjusted state dict back into the model
    model.load_state_dict(adjusted_state_dict, strict=False)

    # Now load the fine-tuned model into the HookedSAETransformer
    finetune_model = HookedSAETransformer.from_pretrained(
        base_model, device=device, hf_model=model, dtype=dtype
    )

    del model  # offload the HF model as it's already wrapped into HookedSAETransformer (finetune_model)
    clear_cache()

    return tokenizer, finetune_model

In [None]:
# Load the finetune model and its tokenizer
# Load the finetuned checkpoint from FINETUNE_PATH in float16 on the selected device.
finetune_tokenizer, finetune_model = load_hf_model(FINETUNE_PATH, device=device, dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]



Loaded pretrained model mistral-7b into HookedTransformer


In [None]:
# Set to True to reload the saved token batches from disk (e.g. after a kernel restart)
# instead of re-using the ones still in memory from the base-model run.
PRELOAD_TOKENS = False
if PRELOAD_TOKENS:
    # weights_only=True: the files contain plain tensors, so the restricted loader is sufficient
    base_model_dead_act_tokens = torch.load(datapath / f'{saving_name_base}_dead_act_tokens.pt', weights_only=True)
    base_model_dense_act_tokens = torch.load(datapath / f'{saving_name_base}_dense_act_tokens.pt', weights_only=True)

base_model_dense_act_tokens.shape, base_model_dead_act_tokens.shape

(torch.Size([50, 5, 256]), torch.Size([50, 5, 256]))

In [None]:
# Run the finetuned model on exactly the token batches recorded during the base-model runs,
# so that the activations of the two models can be compared token by token.
# The second return value is None in this mode and is discarded.
finetune_model_dense_act, _ = get_feature_activations(dense_base_features, finetune_model,
                                                      tokens_sample=base_model_dense_act_tokens)
finetune_model_dead_act, _ = get_feature_activations(dead_base_features, finetune_model,
                                                     tokens_sample=base_model_dead_act_tokens)

100%|██████████| 50/50 [00:26<00:00,  1.92it/s]
100%|██████████| 50/50 [00:26<00:00,  1.88it/s]


In [None]:
# Persist the finetuned-model activations next to the base-model ones.
torch.save(finetune_model_dense_act, datapath / f'{saving_name_ft}_dense_act.pt')
torch.save(finetune_model_dead_act, datapath / f'{saving_name_ft}_dead_act.pt')

In [None]:
# Logit vectors of the same dense features, now through the finetuned model's unembedding;
# saved for the comparison section, shape displayed as the cell output.
dense_finetune_features_logit_vectors = get_features_logit_vectors(dense_base_features, finetune_model)
torch.save(dense_finetune_features_logit_vectors, datapath / f'{saving_name_ft}_dense_logit_vectors.pt')

dense_finetune_features_logit_vectors.shape

torch.Size([32000, 8])

In [None]:
# Keep an independent copy of the finetuned model's unembedding matrix before the model is
# deleted. torch.Tensor has no .copy() method; .detach().clone() is the tensor equivalent.
finetune_model_unembed = finetune_model.W_U.detach().clone()

del finetune_model
clear_cache()

In [None]:
# TODO: check the norm ratios & L0 distance between the unembed matrices

# Plotting & reporting the feature similarities

In [None]:
# Set to True to reload everything the comparison below needs from the .pt files written
# by the earlier cells (e.g. after a kernel restart), instead of relying on in-memory variables.
PRELOAD_FEATURES = True

if PRELOAD_FEATURES:
    # Paths of the saved tensors
    base_model_dense_path = datapath / f'{saving_name_base}_dense_act.pt'
    base_model_dead_path = datapath / f'{saving_name_base}_dead_act.pt'

    finetune_model_dead_path = datapath / f'{saving_name_ft}_dead_act.pt'
    finetune_model_dense_path = datapath / f'{saving_name_ft}_dense_act.pt'

    dense_base_features_logit_vectors_path = datapath / f'{saving_name_base}_dense_logit_vectors.pt'
    dense_finetune_features_logit_vectors_path = datapath / f'{saving_name_ft}_dense_logit_vectors.pt'

    dead_base_features_path = datapath / f'{saving_name_base}_dead_features.pt'
    dense_base_features_path = datapath / f'{saving_name_base}_dense_features.pt'

    ### RETRIEVE THE LOST TENSORS
    # weights_only=True: all files contain plain tensors, so the restricted loader is
    # sufficient and avoids executing arbitrary pickled code.
    base_model_dense_act = torch.load(base_model_dense_path, weights_only=True)
    base_model_dead_act = torch.load(base_model_dead_path, weights_only=True)

    finetune_model_dead_act = torch.load(finetune_model_dead_path, weights_only=True)
    finetune_model_dense_act = torch.load(finetune_model_dense_path, weights_only=True)

    dense_base_features_logit_vectors = torch.load(dense_base_features_logit_vectors_path, weights_only=True)
    dense_finetune_features_logit_vectors = torch.load(dense_finetune_features_logit_vectors_path, weights_only=True)

    dead_base_features = torch.load(dead_base_features_path, weights_only=True)
    dense_base_features = torch.load(dense_base_features_path, weights_only=True)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is poss

In [None]:
# The per-token comparison below pairs rows of the base and finetune tensors,
# so their shapes must agree exactly.
assert finetune_model_dense_act.shape == base_model_dense_act.shape
assert finetune_model_dead_act.shape == base_model_dead_act.shape

finetune_model_dense_act.shape, finetune_model_dead_act.shape

(torch.Size([64000, 8]), torch.Size([64000, 2]))

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def make_similarity_plots(x, y, feature_ids, plot_rows, plot_cols, similarity_score='Activation'):
    """
    Draw a grid of scatter plots, one per feature, of base values against finetune values.

    Parameters:
    - x, y (torch.Tensor): Base and finetune values; column i of each belongs to feature i.
    - feature_ids: Feature identifiers, indexed by column position (used for labelling).
    - plot_rows, plot_cols (int): Grid layout; their product must equal the number of columns.
    - similarity_score (str): Name of the compared quantity, used in axis and subplot titles.

    Each subplot title reports the Pearson correlation between the two columns.

    Raises:
    - AssertionError: if the grid does not have exactly one cell per column of `x`.
    """
    n_plots = x.shape[1]
    assert n_plots == plot_rows * plot_cols

    x = x.cpu().numpy()
    y = y.cpu().numpy()

    # Pearson correlation per feature, computed once and reused for titles and trace names
    correlations = [np.corrcoef(x[:, i], y[:, i])[0, 1] for i in range(n_plots)]

    # The legend is hidden below, so the subplot title is what identifies each feature.
    fig = make_subplots(rows=plot_rows, cols=plot_cols,
                        x_title=f'Base Feature {similarity_score}',
                        y_title=f'Finetune Feature {similarity_score}',
                        subplot_titles=[f'Feature {feature_ids[i]}: {similarity_score} similarity = {corr_coef:.2f}'
                                        for i, corr_coef in enumerate(correlations)])

    for i, corr_coef in enumerate(correlations):
        scatter = go.Scatter(
            x=x[:, i],
            y=y[:, i],
            mode='markers',
            marker=dict(color='red', opacity=0.5),
            name=f'Feature {feature_ids[i]} (Corr={corr_coef:.2f})'
        )

        # Fill the grid row by row (Plotly rows/cols are 1-based)
        fig.add_trace(scatter, row=(i//plot_cols)+1, col=(i%plot_cols)+1)

    fig.update_layout(
        title="Feature Activations Scatter Plots and Correlations",
        height=600,
        width=1000,
        showlegend=False
    )

    fig.show()

def make_single_similarity_plot(x, y, feature_id, features_family, similarity_score='Activation'):
    """
    Scatter-plot base values against finetune values for a single feature.

    Parameters:
    - x, y (torch.Tensor): Base and finetune values; one column per feature.
    - feature_id (int): COLUMN index into `x`/`y` (and position in `features_family`),
      not the SAE feature id itself.
    - features_family: Feature identifiers indexed by column position; the displayed
      feature id is `features_family[feature_id]`.
    - similarity_score (str): Name of the compared quantity, used in the titles.
    """
    # Convert tensors to numpy arrays
    x = x.cpu().numpy()
    y = y.cpu().numpy()

    # Pearson correlation between the two models for the selected column
    corr_coef = np.corrcoef(x[:, feature_id], y[:, feature_id])[0, 1]

    # Label with the actual feature id rather than the column index
    feature_label = features_family[feature_id]

    scatter = go.Scatter(
        x=x[:, feature_id],
        y=y[:, feature_id],
        mode='markers',
        marker=dict(color='red', opacity=0.5),
        name=f'Feature {feature_label} (Corr={corr_coef:.2f})'
    )

    layout = go.Layout(
        title=f'Feature Activations Scatter Plot (Feature {feature_label})<br>{similarity_score} similarity = {corr_coef:.2f}',
        xaxis=dict(title=f'Base Feature {similarity_score}'),
        yaxis=dict(title=f'Finetune Feature {similarity_score}'),
        height=600,
        width=600,
        showlegend=False
    )

    fig = go.Figure(data=[scatter], layout=layout)
    fig.show()

def show_features_similarities(x, y, features_family, similarity_score='Activation'):
    """
    Print the Pearson correlation between base and finetune values for every feature.

    Parameters:
    - x, y (torch.Tensor): Base and finetune values; column i of each belongs to feature i.
    - features_family: Feature identifiers indexed by column position (used for labelling).
    - similarity_score (str): Name of the compared quantity, used in the printed lines.

    The correlation is undefined when either column is constant (e.g. a feature that never
    activates); such features are reported as nan without raising a NumPy warning, and are
    left out of the printed mean. If no feature has a defined correlation, the mean is nan.
    """
    n_features = x.shape[1]

    # Convert tensors to numpy arrays
    x = x.cpu().numpy()
    y = y.cpu().numpy()

    def _pearson_or_nan(a, b):
        """Pearson correlation of two 1-D arrays, or nan when it is undefined."""
        if a.size < 2 or a.min() == a.max() or b.min() == b.max():
            return float('nan')
        return np.corrcoef(a, b)[0, 1]

    activations_similarity = [_pearson_or_nan(x[:, i], y[:, i]) for i in range(n_features)]

    for i, similarity in enumerate(activations_similarity):
        print(f'Feature {features_family[i]} {similarity_score} similarity = {similarity:.2f}')

    # Average only over features whose similarity is defined, so one nan does not hide the rest
    defined = [similarity for similarity in activations_similarity if not np.isnan(similarity)]
    mean_similarity = np.mean(defined) if defined else float('nan')
    print(f'Mean {similarity_score} similarity = {mean_similarity:.2f}')

## Dead features

In [None]:
# Activation similarity for the sampled dead features. Pearson correlation is undefined
# for a constant column, so a nan here is what one would expect for features that never
# fire on the sampled tokens -- TODO confirm the activations are in fact all zero.
show_features_similarities(base_model_dead_act, finetune_model_dead_act, dead_base_features)

Feature 11108 Activation similarity = nan
Feature 25007 Activation similarity = nan



invalid value encountered in divide



In [None]:
# make_similarity_plots(base_model_dead_act, finetune_model_dead_act, dead_base_features,
#                       plot_rows=1, plot_cols=2)

## Dense features

In [None]:
# Activation similarity (per-token Pearson correlation between base and finetuned model)
# for the sampled dense features.
show_features_similarities(base_model_dense_act, finetune_model_dense_act, dense_base_features)

Feature 8696 Activation similarity = 0.99
Feature 39413 Activation similarity = 1.00
Feature 20898 Activation similarity = 0.96
Feature 19819 Activation similarity = 1.00
Feature 32480 Activation similarity = 0.96
Feature 13560 Activation similarity = 0.89
Feature 14911 Activation similarity = 0.83
Feature 6419 Activation similarity = 1.00


In [None]:
# Same comparison on the features' logit vectors: correlation, across the vocabulary,
# between the base and the finetuned model's projection of each feature.
show_features_similarities(dense_base_features_logit_vectors, dense_finetune_features_logit_vectors, dense_base_features, 'Logits')

Feature 8696 Logits similarity = 1.00
Feature 39413 Logits similarity = 1.00
Feature 20898 Logits similarity = 1.00
Feature 19819 Logits similarity = 1.00
Feature 32480 Logits similarity = 1.00
Feature 13560 Logits similarity = 1.00
Feature 14911 Logits similarity = 1.00
Feature 6419 Logits similarity = 1.00


In [None]:
# Number of sampled dense features (displayed as the cell output).
n_dense = dense_base_features.numel()
n_dense

8

In [None]:
# Scatter plot for a single dense feature. The third argument is the COLUMN index into the
# activation tensors (position in dense_base_features), not the SAE feature id.
make_single_similarity_plot(base_model_dense_act, finetune_model_dense_act, 6, dense_base_features)