In [None]:
# Select the model family to analyse: one of 'MISTRAL', 'GEMMA', 'GPT2'.
MODEL = 'MISTRAL'

# Each branch below sets:
#   RELEASE             - SAE release name
#   BASE_MODEL          - base model name
#   FINETUNE_MODEL      - fine-tuned model compared against the base
#   FINETUNE_PATH       - local path of the fine-tuned weights, or None
#   DATASET_NAME        - dataset the activations were collected on
#   BASE_TOKENIZER_NAME - tokenizer name for the base model
#   hook_part/layer_num - pick the hook combined into SAE_HOOK below
if MODEL == 'GEMMA':
    RELEASE = 'gemma-2b-res-jb'
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    FINETUNE_PATH = None
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'post'
    layer_num = 6
elif MODEL == 'GPT2':
    RELEASE = 'gpt2-small-res-jb'
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    FINETUNE_PATH = None
    DATASET_NAME = "Skylion007/openwebtext"
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'pre'
    layer_num = 6
elif MODEL == 'MISTRAL':
    RELEASE = 'mistral-7b-res-wg'
    BASE_MODEL = "mistral-7b"
    DATASET_NAME = "monology/pile-uncopyrighted"
    BASE_TOKENIZER_NAME = 'mistralai/Mistral-7B-v0.1'

    FINETUNE_MODEL = 'meta-math/MetaMath-Mistral-7B'  # alternative: DeepMount00/Mistral-Ita-7b
    FINETUNE_PATH = '/content/drive/My Drive/Finetunes/MetaMath-Mistral-7B'

    hook_part = 'pre'
    layer_num = 8
else:
    # Fail immediately instead of with a NameError on BASE_MODEL further down.
    raise ValueError(f"Unknown MODEL {MODEL!r}; expected 'GEMMA', 'GPT2' or 'MISTRAL'.")

# Short names used when building activation file names: keep only the part after
# the last '/' (str.split returns the whole string when there is no '/').
saving_name_base = BASE_MODEL.split("/")[-1]
saving_name_ft = FINETUNE_MODEL.split("/")[-1]
saving_name_ds = DATASET_NAME.split("/")[-1]

# Hook name for the SAE, e.g. 'blocks.8.hook_resid_pre' for the MISTRAL settings.
SAE_HOOK = f'blocks.{layer_num}.hook_resid_{hook_part}'

In [None]:
try:
    import google.colab
    IN_COLAB = True
    %pip install sae-lens transformer-lens

    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    IN_COLAB = False

Mounted at /content/drive


In [None]:
# Standard imports
import os
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
from datasets import load_dataset
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sae_lens import LanguageModelSAERunnerConfig
from sae_lens import ActivationsStore
import os
from dotenv import load_dotenv
import typing
from dataclasses import dataclass
from tqdm import tqdm
import logging
import torch
import torch.nn.functional as F

# Inference only: disable autograd globally to save GPU memory.
# The trailing ';' suppresses the notebook echo of the returned context-manager object.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7d0ea43292a0>

Next, we load the activation tensors saved in tasks 1 and 2.

In [None]:
from pathlib import Path
def get_env_var():
    """Return ``(project_root, data_dir)`` paths for the current environment.

    In Colab (``IN_COLAB`` is set by the setup cell) the data directory lives on
    the mounted Google Drive. Otherwise, environment variables are loaded from a
    ``.env`` file and both paths are derived from ``PYTHONPATH``.

    Returns:
        tuple[Path, Path]: the project root and its data directory.

    Raises:
        RuntimeError: outside Colab, if ``PYTHONPATH`` is not set.
    """
    if IN_COLAB:
        return Path('./'), Path('/content/drive/My Drive/sae_data')

    # Load environment variables from the .env file
    load_dotenv()
    pythonpath_value = os.getenv('PYTHONPATH')
    if pythonpath_value is None:
        # Path(None) would raise an opaque TypeError; fail with an actionable message.
        raise RuntimeError(
            "PYTHONPATH is not set: add it to your .env file or export it before starting the kernel."
        )
    pythonpath = Path(pythonpath_value)
    print(f"PYTHONPATH: {pythonpath}")
    datapath = pythonpath / 'data'
    print(f"DATAPATH: {datapath}")

    return pythonpath, datapath


_, datapath = get_env_var()

import gc


def clear_cache():
    """Free memory between heavy steps.

    First runs Python's garbage collector so objects dropped with ``del`` are
    actually released, then asks PyTorch's CUDA caching allocator to return its
    unused cached blocks.
    """
    # Order matters: collect garbage before emptying the CUDA cache.
    for release_step in (gc.collect, torch.cuda.empty_cache):
        release_step()

### Compute Similarities

In [None]:
#### Similarity and Distance Computations ####

# 1. Compute pairwise cosine similarity between base and finetune activations
def compute_cosine_similarity(base_activations, finetune_activations):
    """Cosine similarity between matching activation vectors of two models.

    Args:
        base_activations: tensor of shape [..., d]; in this notebook
            [N_BATCH, N_CONTEXT, d_model].
        finetune_activations: tensor with exactly the same shape.

    Returns:
        Tensor of shape [...] (here [N_BATCH, N_CONTEXT]) with values in [-1, 1]
        up to rounding. A zero vector yields similarity 0.

    Raises:
        ValueError: if the two tensors do not have the same shape.
    """
    if base_activations.shape != finetune_activations.shape:
        raise ValueError(
            f"shape mismatch: {tuple(base_activations.shape)} vs {tuple(finetune_activations.shape)}"
        )
    # Unit-normalise along the activation dimension (F.normalize clamps the norm by eps,
    # so zero vectors stay zero instead of producing NaN).
    base_norm = F.normalize(base_activations, dim=-1)
    finetune_norm = F.normalize(finetune_activations, dim=-1)

    # Dot product over the last dimension; '...' keeps any number of leading dims.
    return torch.einsum('...a,...a->...', base_norm, finetune_norm)

# 2. Compute pairwise Euclidean distance between base and finetune activations
def compute_euclidean_distance(base_activations, finetune_activations):
    """Euclidean (L2) distance between matching activation vectors of two models.

    Args:
        base_activations: tensor of shape [..., d]; in this notebook
            [N_BATCH, N_CONTEXT, d_model].
        finetune_activations: tensor with exactly the same shape.

    Returns:
        Tensor of shape [...] (here [N_BATCH, N_CONTEXT]) of non-negative distances.

    Raises:
        ValueError: if the two tensors do not have the same shape (plain subtraction
            would otherwise broadcast mismatched shapes silently).
    """
    if base_activations.shape != finetune_activations.shape:
        raise ValueError(
            f"shape mismatch: {tuple(base_activations.shape)} vs {tuple(finetune_activations.shape)}"
        )
    # sqrt(sum((base - finetune) ** 2)) over the activation dimension.
    return torch.linalg.vector_norm(base_activations - finetune_activations, dim=-1)

In [None]:
# Base-model and fine-tuned-model activations collected on the same dataset.
# BUG FIX: the base tensor was previously loaded from the *finetune* file, so both
# tensors were identical and the cosine similarity came out as exactly 1.0.
# NOTE(review): the base file name mirrors the fine-tune naming scheme
# (`base_acts_<model>_on_<dataset>.pt`); confirm it matches what tasks 1+2 wrote.
base_acts_path = datapath / f"base_acts_{saving_name_base}_on_{saving_name_ds}.pt"
finetune_acts_path = datapath / f"finetune_acts_{saving_name_ft}_on_{saving_name_ds}.pt"

# weights_only=True: the files hold plain tensors, so refuse to unpickle arbitrary objects
# (this also silences torch.load's FutureWarning about the default).
base_activations = torch.load(base_acts_path, weights_only=True)
finetune_activations = torch.load(finetune_acts_path, weights_only=True)

# The metrics compare activations position by position, so the shapes must match.
assert base_activations.shape == finetune_activations.shape, (
    f"shape mismatch: {tuple(base_activations.shape)} vs {tuple(finetune_activations.shape)}"
)
base_activations.shape, finetune_activations.shape

  base_activations = torch.load(datapath / f"finetune_acts_{saving_name_ft}_on_{saving_name_ds}.pt")
  finetune_activations = torch.load(datapath / f"finetune_acts_{saving_name_ft}_on_{saving_name_ds}.pt")


(torch.Size([250, 1024, 4096]), torch.Size([250, 1024, 4096]))

In [None]:
# Context length (tokens per row): size of dim 1 of the [N_BATCH, N_CONTEXT, d] activations.
N_CONTEXT = base_activations.size(1)

In [None]:
# Collapse the activation dimension: each metric compares the base and fine-tuned
# activation vectors at every (batch, token) position -> shape [N_BATCH, N_CONTEXT].
cosine_similarity, euclidean_distance = (
    compute_cosine_similarity(base_activations, finetune_activations),
    compute_euclidean_distance(base_activations, finetune_activations),
)

# The full activation tensors are no longer needed; release them before plotting.
# NOTE: because of this `del`, re-running this cell requires re-running the load cell first.
del base_activations, finetune_activations
clear_cache()

# Report the resulting tensor shapes ([N_BATCH, N_CONTEXT] each)
print(f"Cosine Similarity Tensor Shape: {cosine_similarity.shape}")
print(f"Euclidean Distance Tensor Shape: {euclidean_distance.shape}")

Cosine Similarity Tensor Shape: torch.Size([250, 1024])
Euclidean Distance Tensor Shape: torch.Size([250, 1024])


### Quantitative Analysis of Similarities

#### **Questions:**

What should we do next with these similarity tensors, i.e. how do we reduce the `[N_BATCH, N_CONTEXT]` dimensions? We could simply take the global mean, but there may be better options.

- **Global Mean**: One simple option is to take the global mean across both batch and context dimensions, giving a single scalar representing the overall similarity.
- **Per-context Mean**: You could compute the mean across batches but preserve the context dimension, giving a tensor of shape `[N_CONTEXT]`, which reflects how similarity changes over the context length.
    - I don’t really expect that similarity should change with the token position in the context, but we could try it anyway
- **Per-batch Mean**: Similarly, you can compute the mean across the context (token) dimension to see how differently the models treat different inputs (batch rows), resulting in a tensor of shape `[N_BATCH]`.
    - This is meaningful only if each batch row comes from a different context (document), which is not always the case (I think that if the dataset is not pre-tokenized, the contexts are shuffled together).

In [None]:
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# 1. Plot heatmap of similarity metric
def plot_similarity_heatmap(ST, title="Similarity Heatmap (Batches vs Context)"):
    """Show a heatmap of a per-token metric.

    Args:
        ST: 2-D tensor of shape [N_BATCH, N_CONTEXT] (e.g. cosine similarity or
            Euclidean distance between base and fine-tuned activations).
        title: figure title; override it when plotting a distance rather than a similarity.
    """
    # .float() first: numpy cannot represent bfloat16, so .numpy() would fail on such tensors.
    values = ST.detach().float().cpu().numpy()
    fig = px.imshow(values, color_continuous_scale='Viridis',
                    labels={'x': 'Context (Tokens)', 'y': 'Batch Index'},
                    title=title)
    fig.show()

# 2. Reduce ST across context dimension to [N_BATCH] and plot histogram across batches
def plot_batch_histogram(ST, title="Histogram of Mean Similarity Across Batches",
                         value_label='Mean Similarity', nbins=30):
    """Histogram of the per-row mean of ``ST``.

    Args:
        ST: 2-D tensor of shape [N_BATCH, N_CONTEXT].
        title: figure title.
        value_label: x-axis label (e.g. 'Mean Distance' for Euclidean distances).
        nbins: number of histogram bins.
    """
    # Mean over the context dimension -> [N_BATCH]; .float() keeps bfloat16 inputs convertible to numpy.
    batch_mean_similarity = ST.detach().float().mean(dim=1).cpu().numpy()
    fig = px.histogram(batch_mean_similarity, nbins=nbins, labels={'value': value_label},
                       title=title)
    fig.show()

# 3. Flatten ST into [N_BATCH * N_CONTEXT] and plot histogram across all tokens
def plot_token_histogram(ST, title="Histogram of Similarity Across All Tokens",
                         value_label='Similarity', nbins=100):
    """Histogram of every element of ``ST`` (one value per token position).

    Args:
        ST: tensor of shape [N_BATCH, N_CONTEXT] (any shape is flattened).
        title: figure title.
        value_label: x-axis label (e.g. 'Distance' for Euclidean distances).
        nbins: number of histogram bins.
    """
    # Flatten to [N_BATCH * N_CONTEXT]; .float() keeps bfloat16 inputs convertible to numpy.
    flattened_ST = ST.detach().float().flatten().cpu().numpy()
    fig = px.histogram(flattened_ST, nbins=nbins, labels={'value': value_label},
                       title=title)
    fig.show()

# 4. Reduce ST across batch dimension to [N_CONTEXT] and plot line plot across context
def plot_context_line(ST, title="Line Plot of Mean Similarity Across Context",
                      yaxis_title="Mean Similarity"):
    """Line plot of the per-position mean of ``ST`` across the batch.

    Args:
        ST: 2-D tensor of shape [N_BATCH, N_CONTEXT].
        title: figure title.
        yaxis_title: y-axis label (e.g. 'Mean Distance' for Euclidean distances).
    """
    # Mean over the batch dimension -> [N_CONTEXT]; .float() keeps bfloat16 inputs convertible.
    context_mean_similarity = ST.detach().float().mean(dim=0).cpu().numpy()
    # Derive the x axis from the data rather than the global N_CONTEXT, so the function
    # works for any context length and does not depend on hidden notebook state.
    positions = list(range(len(context_mean_similarity)))
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=positions, y=context_mean_similarity,
                             mode='lines', name='Context Mean Similarity'))
    fig.update_layout(title=title,
                      xaxis_title="Context (Tokens)", yaxis_title=yaxis_title)
    fig.show()

# 5. Report the global mean value of the similarity metric
def report_global_mean(ST, name="similarity"):
    """Print the mean of ``ST`` over all of its elements.

    Args:
        ST: tensor of any shape (here [N_BATCH, N_CONTEXT]).
        name: metric name used in the message, e.g. "distance" for Euclidean distances.
    """
    # Accumulate in float64 so low-precision (e.g. bfloat16) or integer inputs give an accurate mean.
    global_mean_similarity = ST.to(torch.float64).mean().item()
    print(f"Global mean {name}: {global_mean_similarity}")



In [None]:
# Short aliases for the two [N_BATCH, N_CONTEXT] metrics: ST1 = cosine similarity, ST2 = Euclidean distance.
ST1, ST2 = cosine_similarity, euclidean_distance

In [None]:
# Cosine similarity (ST1): heatmap over batch rows vs. token positions.
plot_similarity_heatmap(ST=ST1)

In [None]:
# Cosine similarity (ST1): distribution of per-row means.
plot_batch_histogram(ST=ST1)

In [None]:
# Cosine similarity (ST1): distribution over all token positions.
plot_token_histogram(ST=ST1)

In [None]:
# Cosine similarity (ST1): batch-averaged value at each token position.
plot_context_line(ST=ST1)

In [None]:
# Cosine similarity (ST1): single scalar summary.
report_global_mean(ST=ST1)

Global mean similarity: 1.0


In [None]:
# Euclidean distance (ST2), plotted with the function's default (similarity-worded) title.
plot_similarity_heatmap(ST=ST2)

In [None]:
# Euclidean distance (ST2): distribution of per-row means.
plot_batch_histogram(ST=ST2)

In [None]:
# Euclidean distance (ST2): distribution over all token positions.
plot_token_histogram(ST=ST2)

In [None]:
# Euclidean distance (ST2): batch-averaged value at each token position.
plot_context_line(ST=ST2)

### TODO:

- repeat this analysis on the fine-tuned model's fine-tuning data (instead of the SAE training data)
- repeat it on the base model's pre-training data (which should be the same as the SAE training data)