### Setup

In [1]:
try:
    import google.colab
    IN_COLAB = True
    %pip install sae-lens transformer-lens

    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    IN_COLAB = False

Collecting sae-lens
  Downloading sae_lens-3.22.2-py3-none-any.whl.metadata (5.1 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.7.0-py3-none-any.whl.metadata (12 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.6-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 (from sae-lens)
  Downloading pytest_profiling-1.7.0-py2.py3-none-any.whl.metadata (12 kB)
Collecting python-dot

In [2]:
# Standard imports
import os
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
from datasets import load_dataset
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sae_lens import LanguageModelSAERunnerConfig
from sae_lens import ActivationsStore
import os
from dotenv import load_dotenv
import typing
from dataclasses import dataclass
from tqdm import tqdm
import logging
import torch
import torch.nn.functional as F

# GPU memory saver (this script doesn't need gradients computation)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7972a73eba90>

### Config

In [3]:
# define the model to work with
MODEL = 'GPT2' # GEMMA, MISTRAL, GPT2

if MODEL == 'GEMMA':
    # Base model stuff
    BASE_MODEL = "google/gemma-2b"
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    BASE_TOKENIZER_NAME = BASE_MODEL

    # Finetuned model stuff
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    FINETUNE_PATH = None

    # SAE stuff
    RELEASE = 'gemma-2b-res-jb'
    hook_part = 'post'
    layer_num = 6
elif MODEL == 'MISTRAL':
    # Base model stuff
    BASE_MODEL = "mistral-7b"
    DATASET_NAME = "monology/pile-uncopyrighted"
    BASE_TOKENIZER_NAME = 'mistralai/Mistral-7B-v0.1'

    # Finetuned model stuff
    FINETUNE_MODEL = 'meta-math/MetaMath-Mistral-7B'
    FINETUNE_PATH = f'/content/drive/My Drive/Finetunes/MetaMath-Mistral-7B'

    # SAE stuff
    RELEASE = 'mistral-7b-res-wg'
    hook_part = 'pre'
    layer_num = 8
elif MODEL == 'GPT2':
    # Base model stuff
    BASE_MODEL = "gpt2-small"
    DATASET_NAME = "Skylion007/openwebtext"
    BASE_TOKENIZER_NAME = 'openai-community/gpt2'

    # Finetuned model stuff
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    FINETUNE_PATH = None

    # SAE stuff
    RELEASE = 'gpt2-small-res-jb'
    hook_part = 'pre'
    layer_num = 6

saving_name_base = BASE_MODEL if "/" not in BASE_MODEL else BASE_MODEL.split("/")[-1]
saving_name_ft = FINETUNE_MODEL if "/" not in FINETUNE_MODEL else FINETUNE_MODEL.split("/")[-1]
saving_name_ds = DATASET_NAME if "/" not in DATASET_NAME else DATASET_NAME.split("/")[-1]

SAE_HOOK = f'blocks.{layer_num}.hook_resid_{hook_part}'

### Loading the activations

Now we load our tensors from pre_3_compute_activations.ipynb:

In [4]:
from pathlib import Path
def get_base_and_data_paths(drive_folder='My Drive/sae_data'):
    if IN_COLAB:
      return Path('./'), Path(f'/content/drive/{drive_folder}')

    # Load environment variables from the .env file
    load_dotenv()
    # Access the PYTHONPATH variable
    pythonpath = Path(os.getenv('PYTHONPATH'))
    # Print to verify
    print(f"PYTHONPATH: {pythonpath}")
    datapath = pythonpath / 'data'
    print(f"DATAPATH: {datapath}")

    return pythonpath, datapath

_, datapath = get_base_and_data_paths()

import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

In [5]:
base_activations = torch.load(datapath / f"base_acts_{saving_name_base}_on_{saving_name_ds}.pt")
finetune_activations = torch.load(datapath / f"finetune_acts_{saving_name_ft}_on_{saving_name_ds}.pt")

base_activations.shape, finetune_activations.shape

  base_activations = torch.load(datapath / f"base_acts_{saving_name_base}_on_{saving_name_ds}.pt")
  finetune_activations = torch.load(datapath / f"finetune_acts_{saving_name_ft}_on_{saving_name_ds}.pt")


(torch.Size([250, 1024, 768]), torch.Size([250, 1024, 768]))

### Computing the similarities

In [6]:
#### Similarity and Distance Computations ####

# 1. Compute pairwise cosine similarity between base and finetune activations
def compute_cosine_similarity(base_activations, finetune_activations):
    # Normalize activations along the activation dimension
    base_norm = F.normalize(base_activations, dim=-1)
    finetune_norm = F.normalize(finetune_activations, dim=-1)

    # Compute dot product along activation dimension to get cosine similarity
    cosine_similarity = torch.einsum('bca,bca->bc', base_norm, finetune_norm)  # [N_BATCH, N_CONTEXT]
    return cosine_similarity

# 2. Compute pairwise Euclidean distance between base and finetune activations
def compute_euclidean_distance(base_activations, finetune_activations):
    # Compute squared difference and sum along activation dimension
    euclidean_distance = torch.norm(base_activations - finetune_activations, dim=-1)  # [N_BATCH, N_CONTEXT]
    return euclidean_distance

In [7]:
N_CONTEXT = base_activations.shape[1]

In [8]:
# Call the functions to compute cosine similarity and Euclidean distance
cosine_similarity = compute_cosine_similarity(base_activations, finetune_activations)
euclidean_distance = compute_euclidean_distance(base_activations, finetune_activations)

del base_activations, finetune_activations
clear_cache()

# Output tensor shapes
print(f"Cosine Similarity Tensor Shape: {cosine_similarity.shape}")  # [N_BATCH, N_CONTEXT]
print(f"Euclidean Distance Tensor Shape: {euclidean_distance.shape}")  # [N_BATCH, N_CONTEXT]

Cosine Similarity Tensor Shape: torch.Size([250, 1024])
Euclidean Distance Tensor Shape: torch.Size([250, 1024])


### Plotting the similiraties

In [9]:
import plotly.express as px
import plotly.graph_objects as go

In [10]:
# 1. Plot heatmap of similarity metric
def plot_similarity_heatmap(ST):
    fig = px.imshow(ST.cpu().numpy(), color_continuous_scale='Viridis',
                    labels={'x': 'Context (Tokens)', 'y': 'Batch Index'},
                    title="Similarity Heatmap (Batches vs Context)")
    fig.show()

# 2. Reduce ST across context dimension to [N_BATCH] and plot histogram across batches
def plot_batch_histogram(ST):
    batch_mean_similarity = torch.mean(ST, dim=1).cpu().numpy()  # Shape [N_BATCH]
    fig = px.histogram(batch_mean_similarity, nbins=30, labels={'value': 'Mean Similarity'},
                       title="Histogram of Mean Similarity Across Batches")
    fig.show()

# 3. Flatten ST into [N_BATCH * N_CONTEXT] and plot histogram across all tokens
def plot_token_histogram(ST):
    flattened_ST = ST.flatten().cpu().numpy()  # Shape [N_BATCH * N_CONTEXT]
    fig = px.histogram(flattened_ST, nbins=100, labels={'value': 'Similarity'},
                       title="Histogram of Similarity Across All Tokens")
    fig.show()

# 4. Reduce ST across batch dimension to [N_CONTEXT] and plot line plot across context
def plot_context_line(ST):
    context_mean_similarity = torch.mean(ST, dim=0).cpu().numpy()  # Shape [N_CONTEXT]
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=list(range(N_CONTEXT)), y=context_mean_similarity,
                             mode='lines', name='Context Mean Similarity'))
    fig.update_layout(title="Line Plot of Mean Similarity Across Context",
                      xaxis_title="Context (Tokens)", yaxis_title="Mean Similarity")
    fig.show()

# 5. Report the global mean value of the similarity metric
def report_global_mean(ST):
    global_mean_similarity = torch.mean(ST).item()
    print(f"Global mean similarity: {global_mean_similarity}")



In [11]:
ST1 = cosine_similarity
ST2 = euclidean_distance

In [12]:
# Call the functions in sequence
plot_similarity_heatmap(ST1)

In [13]:
plot_batch_histogram(ST1)

In [14]:
plot_token_histogram(ST1)

In [15]:
plot_context_line(ST1)

In [16]:
report_global_mean(ST1)

Global mean similarity: 0.48193359375


In [17]:
plot_similarity_heatmap(ST2)

In [18]:
plot_batch_histogram(ST2)

In [19]:
plot_token_histogram(ST2)

In [20]:
plot_context_line(ST2)

In [21]:
report_global_mean(ST2)

Global mean similarity: 80.375
