In [1]:
# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops

# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Device: cuda


In [2]:
%pip install sae-lens transformer-lens



In [3]:
# define the model to work with
GEMMA = True

if GEMMA == True:
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
else:
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    DATASET_NAME = "Skylion007/openwebtext"

layer_num = 6
hook_part = "pre"
TOTAL_BATCHES = 50


### Task 4.1 Pretrained case

In [4]:
import torch
import torch.nn.functional as F
from enum import Enum
import numpy as np
from scipy.stats import gamma
import os
from dotenv import load_dotenv

# Load environment variables from the .env file
load_dotenv()

# Access the PYTHONPATH variable
PYTHONPATH = os.getenv('PYTHONPATH')
DATAPATH = PYTHONPATH + '/data'


#### Enum for pretty code ####
class AggregationType(Enum):
    MEAN = 'mean'
    LAST= 'last'

class SimilarityMetric(Enum):
  COSINE = 'cosine'
  EUCLIDEAN = 'euclidean'


#### Similarity and Distance Computations ####

# 1. Compute pairwise cosine similarity between base and finetune activations
def compute_cosine_similarity(base_activations, finetune_activations):
    # Normalize activations along the activation dimension
    base_norm = F.normalize(base_activations, dim=-1)
    finetune_norm = F.normalize(finetune_activations, dim=-1)

    # Compute dot product along activation dimension to get cosine similarity
    cosine_similarity = torch.einsum('bca,bca->bc', base_norm, finetune_norm)  # [N_BATCH, N_CONTEXT]
    return cosine_similarity

# 2. Compute pairwise Euclidean distance between base and finetune activations
def compute_euclidean_distance(base_activations, finetune_activations):
    # Compute squared difference and sum along activation dimension
    euclidean_distance = torch.norm(base_activations - finetune_activations, dim=-1)  # [N_BATCH, N_CONTEXT]
    return euclidean_distance

#### CKA code ####

def linear_kernel(X, Y):
  """
  Compute the linear kernel (dot product) between matrices X and Y.
  """
  return torch.mm(X, Y.T)

def HSIC(K, L):
    """
    Calculate the Hilbert-Schmidt Independence Criterion (HSIC) between kernels K and L.
    """
    n = K.shape[0]  # Number of samples
    H = torch.eye(n) - (1./n) * torch.ones((n, n))

    KH = torch.mm(K, H)
    LH = torch.mm(L, H)
    return 1./((n-1)**2) * torch.trace(torch.mm(KH, LH))

def CKA(X, Y):
    """
    Calculate the Centered Kernel Alignment (CKA) for matrices X and Y.
    If no kernel is specified, the linear kernel will be used by default.
    """

    # Compute the kernel matrices for X and Y
    K = linear_kernel(X, X)
    L = linear_kernel(Y, Y)

    # Calculate HSIC values
    hsic = HSIC(K, L)
    varK = torch.sqrt(HSIC(K, K))
    varL = torch.sqrt(HSIC(L, L))

    # Return the CKA value
    return hsic / (varK * varL)

#### Quantitave SAE evaluation ####
def L0_loss(x, threshold=1e-8):
    """
    Expects a tensor x of shape [N_TOKENS, N_SAE].

    Returns a scalar representing the mean value of activated features (i.e. values across the N_SAE dimensions bigger than
    the threshhold), a.k.a. L0 loss.
    """
    return (x > threshold).float().sum(-1).mean()

import plotly.graph_objs as go

def plot_density(feature_acts, num_bins=100):
    """
    Expects a tensor feature_acts of shape [N_TOKENS, N_SAE].

    Computes the histogram using PyTorch and plots the feature density diagram with log-10 scale using Plotly.
    (see https://transformer-circuits.pub/2023/monosemantic-features#appendix-feature-density)
    """
    # Flatten the tensor
    feature_acts_flat = torch.flatten(feature_acts)

    # Compute the logarithmic transformation using PyTorch
    log_feature_acts_flat = torch.log10(torch.abs(feature_acts_flat) + 1e-10).detach().cpu()

    # Compute histogram using PyTorch
    hist_min = torch.min(log_feature_acts_flat).item()
    hist_max = torch.max(log_feature_acts_flat).item()
    hist_range = hist_max - hist_min
    bin_edges = torch.linspace(hist_min, hist_max, num_bins + 1)
    hist_counts, _ = torch.histogram(log_feature_acts_flat, bins=bin_edges)

    # Convert data to NumPy for Plotly
    bin_edges_np = bin_edges.detach().cpu().numpy()
    hist_counts_np = hist_counts.detach().cpu().numpy()

    # Prepare the Plotly plot
    fig = go.Figure(
        data=[go.Bar(
            x=bin_edges_np[:-1],  # Exclude the last bin edge
            y=hist_counts_np,
            width=hist_range / num_bins,
        )]
    )

    # Update the layout for the plot
    fig.update_layout(
        title="SAE Feature Density Diagram",
        xaxis_title="Log10 of Feature Activations",
        yaxis_title="Density",
        bargap=0.2,
        bargroupgap=0.1,
    )

    # Show the plot
    fig.show()

from transformer_lens import HookedTransformer
from functools import partial

import torch
from functools import partial

def get_substitution_loss(tokens, model, sae, sae_layer):
    '''
    Expects a tensor of input tokens of shape [N_BATCHES, N_CONTEXT].

    Returns two losses:
    1. Clean loss - loss of the normal forward pass of the model at the input tokens
    2. Substitution loss - loss when substituting SAE reconstructions of the residual stream at the SAE layer of the model
    '''
    batch_size, seq_len = tokens.shape

    # Run the model with cache to get the original activations and clean loss
    loss_clean, cache = model.run_with_cache(tokens, names_filter=[sae_layer], return_type="loss")

    # Fetch and detach the original activations to avoid any unnecessary graph retention
    original_activations = cache[sae_layer].detach()

    # Get the SAE reconstructed activations (forward pass through SAE)
    post_reconstructed = sae.forward(original_activations).detach()

    # Clear the cache and unused variables early
    del original_activations, cache
    torch.cuda.empty_cache()

    # Hook function to substitute activations in-place
    def hook_function(activations, hook, new_activations):
        activations.copy_(new_activations)  # In-place copy to save memory
        return activations

    # Run model again with hooks to substitute activations and get the substitution loss
    loss_reconstructed = model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(sae_layer, partial(hook_function, new_activations=post_reconstructed))]
    )

    # Clean up reconstructed activations and free up memory
    del post_reconstructed
    torch.cuda.empty_cache()

    return loss_clean, loss_reconstructed


In [None]:
base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [None]:
# import the required libraries
from sae_lens import SAE

RELEASE = 'gemma-2b-res-jb'
sae_id = 'blocks.6.hook_resid_post'

sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)
cfg_dict

{'model_name': 'gemma-2b',
 'model_class_name': 'HookedTransformer',
 'hook_point': 'blocks.6.hook_resid_post',
 'hook_point_eval': 'blocks.{layer}.attn.pattern',
 'hook_point_layer': 6,
 'hook_point_head_index': None,
 'dataset_path': 'HuggingFaceFW/fineweb',
 'streaming': True,
 'is_dataset_tokenized': False,
 'context_size': 1024,
 'use_cached_activations': False,
 'cached_activations_path': None,
 'd_in': 2048,
 'd_sae': 16384,
 'b_dec_init_method': 'zeros',
 'expansion_factor': 8,
 'activation_fn': 'relu',
 'normalize_sae_decoder': False,
 'noise_scale': 0.0,
 'from_pretrained_path': None,
 'apply_b_dec_to_input': False,
 'decoder_orthogonal_init': False,
 'decoder_heuristic_init': True,
 'init_encoder_as_decoder_transpose': True,
 'n_batches_in_buffer': 64,
 'training_tokens': 1228800000,
 'finetuning_tokens': 0,
 'store_batch_size_prompts': 8,
 'train_batch_size_tokens': 4096,
 'normalize_activations': 'none',
 'device': 'cuda',
 'seed': 42,
 'dtype': 'torch.float32',
 'prepend_

In [None]:
# this must be checked for the forward method of sae.encode_xxx
cfg_dict["activation_fn_str"]

'relu'

In [10]:
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # dataset=chess_dataset,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

NameError: name 'base_model' is not defined

#### 4.1.1 L0 loss

In [None]:
# L0_loss(x, threshold=1e-8)
# get_substitution_loss(tokens, model, sae, sae_layer)

from tqdm import tqdm

all_tokens_L0 = []  # This will store the tokens for reuse
all_L0 = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]

    # Store tokens for later reuse
    all_tokens_L0.append(tokens)

    # Run the model and store the activations
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Get the activations from the cache at the sae_id
    original_activations = cache[sae_id]

    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM)

    feature_activations.to('cpu')

    # Store the encoded activations
    all_L0.append(L0_loss(feature_activations))

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

# Concatenate all tokens into a single tensor for reuse
all_tokens_L0 = torch.cat(all_tokens_L0)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 50/50 [01:15<00:00,  1.50s/it]


In [None]:
torch.tensor(all_L0).mean()

tensor(53.6718)

#### 4.1.2 Substitution Loss

In [None]:
# L0_loss(x, threshold=1e-8)
# get_substitution_loss(tokens, model, sae, sae_layer)

from tqdm import tqdm

all_tokens_SL = []  # This will store the tokens for reuse
all_SL_clean = []
all_SL_reconstructed = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]

    # Store tokens for later reuse
    all_tokens_SL.append(tokens)

    # Store loss
    clean_loss, reconstructed_loss = get_substitution_loss(tokens, base_model, sae, sae_id)
    all_SL_clean.append(clean_loss)
    all_SL_reconstructed.append(reconstructed_loss)

# Concatenate all tokens into a single tensor for reuse
all_tokens = torch.cat(all_tokens_SL)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 50/50 [03:16<00:00,  3.93s/it]


In [None]:
torch.tensor(all_SL_clean).mean(), torch.tensor(all_SL_reconstructed).mean()

(tensor(2.6797, dtype=torch.float16), tensor(3.1758, dtype=torch.float16))

### Task 4.1 FineTuned case

In [6]:
# Load the finetune model and its tokenizer
finetune_tokenizer = AutoTokenizer.from_pretrained(FINETUNE_MODEL)
finetune_model_hf = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL)
finetune_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, hf_model=finetune_model_hf, dtype=torch.float16)

tokenizer_config.json:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/522 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [8]:
del finetune_model_hf
clear_cache()

In [12]:
# import the required libraries
from sae_lens import SAE

RELEASE = 'gemma-2b-res-jb'
layer_num = 6
sae_id = f'blocks.{layer_num}.hook_resid_post'

sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)
cfg_dict

(…)id_post_16384_anthropic_fast_lr/cfg.json:   0%|          | 0.00/2.18k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/65.6k [00:00<?, ?B/s]

{'model_name': 'gemma-2b',
 'model_class_name': 'HookedTransformer',
 'hook_point': 'blocks.6.hook_resid_post',
 'hook_point_eval': 'blocks.{layer}.attn.pattern',
 'hook_point_layer': 6,
 'hook_point_head_index': None,
 'dataset_path': 'HuggingFaceFW/fineweb',
 'streaming': True,
 'is_dataset_tokenized': False,
 'context_size': 1024,
 'use_cached_activations': False,
 'cached_activations_path': None,
 'd_in': 2048,
 'd_sae': 16384,
 'b_dec_init_method': 'zeros',
 'expansion_factor': 8,
 'activation_fn': 'relu',
 'normalize_sae_decoder': False,
 'noise_scale': 0.0,
 'from_pretrained_path': None,
 'apply_b_dec_to_input': False,
 'decoder_orthogonal_init': False,
 'decoder_heuristic_init': True,
 'init_encoder_as_decoder_transpose': True,
 'n_batches_in_buffer': 64,
 'training_tokens': 1228800000,
 'finetuning_tokens': 0,
 'store_batch_size_prompts': 8,
 'train_batch_size_tokens': 4096,
 'normalize_activations': 'none',
 'device': 'cuda',
 'seed': 42,
 'dtype': 'torch.float32',
 'prepend_

In [15]:
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=finetune_model,
    sae=sae,
    streaming=True,
    # dataset=chess_dataset,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]



#### 4.2.1 L0 loss

In [16]:
# L0_loss(x, threshold=1e-8)
# get_substitution_loss(tokens, model, sae, sae_layer)

from tqdm import tqdm

all_tokens_L0 = []  # This will store the tokens for reuse
all_L0 = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]

    # Store tokens for later reuse
    all_tokens_L0.append(tokens)

    # Run the model and store the activations
    _, cache = finetune_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Get the activations from the cache at the sae_id
    original_activations = cache[sae_id]

    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM)

    feature_activations.to('cpu')

    # Store the encoded activations
    all_L0.append(L0_loss(feature_activations))

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

# Concatenate all tokens into a single tensor for reuse
all_tokens_L0 = torch.cat(all_tokens_L0)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 50/50 [01:00<00:00,  1.20s/it]


In [17]:
torch.tensor(all_L0).mean()

tensor(84.3003)

In [18]:
clear_cache()

#### 4.2.2 Substitution Loss

In [19]:
# L0_loss(x, threshold=1e-8)
# get_substitution_loss(tokens, model, sae, sae_layer)

from tqdm import tqdm

all_tokens_SL = []  # This will store the tokens for reuse
all_SL_clean = []
all_SL_reconstructed = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]

    # Store tokens for later reuse
    all_tokens_SL.append(tokens)

    # Store loss
    clean_loss, reconstructed_loss = get_substitution_loss(tokens, finetune_model, sae, sae_id)
    all_SL_clean.append(clean_loss)
    all_SL_reconstructed.append(reconstructed_loss)

# Concatenate all tokens into a single tensor for reuse
all_tokens = torch.cat(all_tokens_SL)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 50/50 [03:16<00:00,  3.93s/it]


In [20]:
torch.tensor(all_SL_clean).mean(), torch.tensor(all_SL_reconstructed).mean()

(tensor(3.3379, dtype=torch.float16), tensor(inf, dtype=torch.float16))