In [16]:
%pip install sae-lens transformer-lens torcheval

Collecting torcheval
  Downloading torcheval-0.0.7-py3-none-any.whl.metadata (8.6 kB)
Downloading torcheval-0.0.7-py3-none-any.whl (179 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.2/179.2 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torcheval
Successfully installed torcheval-0.0.7


In [2]:
# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops

# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Device: cuda


In [3]:
# define the model to work with
GEMMA = True

if GEMMA == True:
    RELEASE = 'gemma-2b-res-jb'
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    hook_part = 'post'
    layer_num = 6
else:
    RELEASE = 'gpt2-small-res-jb'
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    DATASET_NAME = "Skylion007/openwebtext"
    hook_part = 'pre'
    layer_num = 6

TOTAL_BATCHES = 25

In [4]:
# @title
import torch
import torch.nn.functional as F
from enum import Enum
import numpy as np
from scipy.stats import gamma
import os
from dotenv import load_dotenv

# Load environment variables from the .env file
load_dotenv()

# Access the PYTHONPATH variable
PYTHONPATH = os.getenv('PYTHONPATH')
DATAPATH = PYTHONPATH + '/data'


#### Enum for pretty code ####
class AggregationType(Enum):
    MEAN = 'mean'
    LAST= 'last'

class SimilarityMetric(Enum):
  COSINE = 'cosine'
  EUCLIDEAN = 'euclidean'


#### Similarity and Distance Computations ####

# 1. Compute pairwise cosine similarity between base and finetune activations
def compute_cosine_similarity(base_activations, finetune_activations):
    # Normalize activations along the activation dimension
    base_norm = F.normalize(base_activations, dim=-1)
    finetune_norm = F.normalize(finetune_activations, dim=-1)

    # Compute dot product along activation dimension to get cosine similarity
    cosine_similarity = torch.einsum('bca,bca->bc', base_norm, finetune_norm)  # [N_BATCH, N_CONTEXT]
    return cosine_similarity

# 2. Compute pairwise Euclidean distance between base and finetune activations
def compute_euclidean_distance(base_activations, finetune_activations):
    # Compute squared difference and sum along activation dimension
    euclidean_distance = torch.norm(base_activations - finetune_activations, dim=-1)  # [N_BATCH, N_CONTEXT]
    return euclidean_distance

#### CKA code ####

def linear_kernel(X, Y):
  """
  Compute the linear kernel (dot product) between matrices X and Y.
  """
  return torch.mm(X, Y.T)

def HSIC(K, L):
    """
    Calculate the Hilbert-Schmidt Independence Criterion (HSIC) between kernels K and L.
    """
    n = K.shape[0]  # Number of samples
    H = torch.eye(n) - (1./n) * torch.ones((n, n))

    KH = torch.mm(K, H)
    LH = torch.mm(L, H)
    return 1./((n-1)**2) * torch.trace(torch.mm(KH, LH))

def CKA(X, Y):
    """
    Calculate the Centered Kernel Alignment (CKA) for matrices X and Y.
    If no kernel is specified, the linear kernel will be used by default.
    """

    # Compute the kernel matrices for X and Y
    K = linear_kernel(X, X)
    L = linear_kernel(Y, Y)

    # Calculate HSIC values
    hsic = HSIC(K, L)
    varK = torch.sqrt(HSIC(K, K))
    varL = torch.sqrt(HSIC(L, L))

    # Return the CKA value
    return hsic / (varK * varL)

#### Quantitave SAE evaluation ####
def L0_loss(x, threshold=1e-8):
    """
    Expects a tensor x of shape [N_TOKENS, N_SAE].

    Returns a scalar representing the mean value of activated features (i.e. values across the N_SAE dimensions bigger than
    the threshhold), a.k.a. L0 loss.
    """
    return (x > threshold).float().sum(-1).mean()

import plotly.graph_objs as go

import torch
from functools import partial

def get_substitution_loss(tokens, model, sae, sae_layer, reconstruction_metric=None):
    '''
    Expects a tensor of input tokens of shape [N_BATCHES, N_CONTEXT].

    Returns two losses:
    1. Clean loss - loss of the normal forward pass of the model at the input tokens
    2. Substitution loss - loss when substituting SAE reconstructions of the residual stream at the SAE layer of the model
    '''
    batch_size, seq_len = tokens.shape

    # Run the model with cache to get the original activations and clean loss
    loss_clean, cache = model.run_with_cache(tokens, names_filter=[sae_layer], return_type="loss")

    # Fetch and detach the original activations
    original_activations = cache[sae_layer]

    # Get the SAE reconstructed activations (forward pass through SAE)
    post_reconstructed = sae.forward(original_activations)

    # Update the reconstruction quality metric (e.g. R2 score/variance explained)
    if reconstruction_metric:
        reconstruction_metric.update(post_reconstructed.flatten(), original_activations.flatten())

    # Clear the cache and unused variables early
    del original_activations, cache
    torch.cuda.empty_cache()

    # Hook function to substitute activations in-place
    def hook_function(activations, hook, new_activations):
        activations.copy_(new_activations)  # In-place copy to save memory
        return activations

    # Run model again with hooks to substitute activations and get the substitution loss
    loss_reconstructed = model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(sae_layer, partial(hook_function, new_activations=post_reconstructed))]
    )

    # Clean up reconstructed activations and free up memory
    del post_reconstructed
    torch.cuda.empty_cache()

    return loss_clean, loss_reconstructed

import torch
import plotly.graph_objects as go

def plot_log10_hist(y_data, y_value, num_bins=100, first_bin_name = 'First bin value',
                    y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
    """
    Computes the histogram using PyTorch and plots the feature density diagram with log-10 scale using Plotly.
    Y-axis is clipped to the value of the second-largest bin to prevent suppression of smaller values.
    """
    # Flatten the tensor
    y_data_flat = torch.flatten(y_data)

    # Compute the logarithmic transformation using PyTorch
    log_y_data_flat = torch.log10(torch.abs(y_data_flat) + log_epsilon).detach().cpu()

    # Compute histogram using PyTorch
    hist_min = torch.min(log_y_data_flat).item()
    hist_max = torch.max(log_y_data_flat).item()
    hist_range = hist_max - hist_min
    bin_edges = torch.linspace(hist_min, hist_max, num_bins + 1)
    hist_counts, _ = torch.histogram(log_y_data_flat, bins=bin_edges)

    # Convert data to NumPy for Plotly
    bin_edges_np = bin_edges.detach().cpu().numpy()
    hist_counts_np = hist_counts.detach().cpu().numpy()

    # Find the largest and second-largest bin values
    first_bin_value = hist_counts_np[0]
    second_largest_bin_value = sorted(hist_counts_np)[y_scale_bin]  # Get the second largest bin value (by default)

    # Prepare the Plotly plot
    fig = go.Figure(
        data=[go.Bar(
            x=bin_edges_np[:-1],  # Exclude the last bin edge
            y=hist_counts_np,
            width=hist_range / num_bins,
        )]
    )

    # Update the layout for the plot, clipping the y-axis at the second largest bin value
    fig.update_layout(
        title=f"SAE Features {y_value} histogram ({first_bin_name}: {first_bin_value:.2e})",
        xaxis_title=f"Log10 of {y_value}",
        yaxis_title="Density",
        yaxis_range=[0, second_largest_bin_value * y_scalar],  # Clipping to the second-largest value by default
        bargap=0.2,
        bargroupgap=0.1,
    )

    # Add an annotation to display the value of the first bin
    fig.add_annotation(
        text=f"{first_bin_name}: {first_bin_value:.2e}",
        xref="paper", yref="paper",
        x=0.95, y=0.95,
        showarrow=False,
        font=dict(size=12, color="red"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    # Show the plot
    fig.show()

In [5]:
class FeatureDensityPlotter:
    def __init__(self, n_features, n_tokens, activation_threshold=1e-10, num_bins=100):
        self.num_bins = num_bins
        self.activation_threshold = activation_threshold

        self.n_tokens = n_tokens
        self.n_features = n_features

        # Initialize a tensor of feature densities for all features,
        # where feature density is defined as the fraction of tokens on which the feature has a nonzero value.
        self.feature_densities = torch.zeros(n_features, dtype=torch.float32)

    def update(self, feature_acts):
        """
        Expects a tensor feature_acts of shape [N_TOKENS, N_FEATURES].

        Updates the feature_densities buffer:
        1. For each feature, count the number of tokens that the feature activated on (i.e. had an activation greater than the activation_threshold)
        2. Add this count at the feature's position in the feature_densities tensor, divided by the total number of tokens (to compute the fraction)
        """

        activating_tokens_count = (feature_acts > self.activation_threshold).float().sum(0)
        self.feature_densities += activating_tokens_count / self.n_tokens

    def plot(self, num_bins=100, y_scalar=1.5, y_scale_bin=-2, log_epsilon=1e-10):
        plot_log10_hist(self.feature_densities, 'Density', num_bins=num_bins, first_bin_name='Dead features density',
                        y_scalar=y_scalar, y_scale_bin=y_scale_bin, log_epsilon=log_epsilon)

### Task 4.1 Pretrained case

In [6]:
base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [7]:
# import the required libraries
from sae_lens import SAE

sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}'

sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)
cfg_dict

(…)id_post_16384_anthropic_fast_lr/cfg.json:   0%|          | 0.00/2.18k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/65.6k [00:00<?, ?B/s]

{'model_name': 'gemma-2b',
 'model_class_name': 'HookedTransformer',
 'hook_point': 'blocks.6.hook_resid_post',
 'hook_point_eval': 'blocks.{layer}.attn.pattern',
 'hook_point_layer': 6,
 'hook_point_head_index': None,
 'dataset_path': 'HuggingFaceFW/fineweb',
 'streaming': True,
 'is_dataset_tokenized': False,
 'context_size': 1024,
 'use_cached_activations': False,
 'cached_activations_path': None,
 'd_in': 2048,
 'd_sae': 16384,
 'b_dec_init_method': 'zeros',
 'expansion_factor': 8,
 'activation_fn': 'relu',
 'normalize_sae_decoder': False,
 'noise_scale': 0.0,
 'from_pretrained_path': None,
 'apply_b_dec_to_input': False,
 'decoder_orthogonal_init': False,
 'decoder_heuristic_init': True,
 'init_encoder_as_decoder_transpose': True,
 'n_batches_in_buffer': 64,
 'training_tokens': 1228800000,
 'finetuning_tokens': 0,
 'store_batch_size_prompts': 8,
 'train_batch_size_tokens': 4096,
 'normalize_activations': 'none',
 'device': 'cuda',
 'seed': 42,
 'dtype': 'torch.float32',
 'prepend_

In [8]:
# this must be checked for the forward method of sae.encode_xxx
cfg_dict["activation_fn_str"]

'relu'

In [9]:
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # dataset=chess_dataset,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts

batch_size_prompts, batch_size_tokens

Downloading readme:   0%|          | 0.00/38.7k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]



(8, 8192)

#### 4.1.1 L0 loss

In [10]:
from tqdm import tqdm

all_tokens_L0 = []  # This will store the tokens for reuse
all_L0 = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]

    # Store tokens for later reuse
    all_tokens_L0.append(tokens)

    # Run the model and store the activations
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Get the activations from the cache at the sae_id
    original_activations = cache[sae_id]

    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM)
    feature_activations.to('cpu')

    # Store the encoded activations
    all_L0.append(L0_loss(feature_activations))

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

# Concatenate all tokens into a single tensor for reuse
all_tokens_L0 = torch.cat(all_tokens_L0)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 25/25 [00:30<00:00,  1.21s/it]


In [11]:
torch.tensor(all_L0).mean()

tensor(53.6848)

#### 4.1.2 Substitution Loss

In [17]:
from tqdm import tqdm
from torcheval.metrics import R2Score
sae_reconstruction_metric = R2Score().to(device)

all_tokens_SL = []  # This will store the tokens for reuse
all_SL_clean = []
all_SL_reconstructed = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]

    # Store tokens for later reuse
    all_tokens_SL.append(tokens)

    # Store loss
    clean_loss, reconstructed_loss = get_substitution_loss(tokens, base_model, sae, sae_id, sae_reconstruction_metric)
    all_SL_clean.append(clean_loss)
    all_SL_reconstructed.append(reconstructed_loss)

# Concatenate all tokens into a single tensor for reuse
all_tokens_SL = torch.cat(all_tokens_SL)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 25/25 [01:38<00:00,  3.93s/it]


In [18]:
print('Clean vs substitution loss:')
torch.tensor(all_SL_clean).mean().item(), torch.tensor(all_SL_reconstructed).mean().item()

Clean vs substitution loss:


(2.703125, 3.1953125)

In [19]:
print('Varience explained by SAE: ')
sae_reconstruction_metric.compute().item()

Varience explained by SAE: 


nan

#### 4.1.3 Feature activations histogram

In [20]:
total_histogram_batches = 25

all_histogram_tokens = []
all_feature_acts = []

for k in tqdm(range(total_histogram_batches)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    all_histogram_tokens.append(tokens)

    # Run the model and store the activations
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Get the activations from the cache at the sae_id
    original_activations = cache[sae_id]  # [N_BATCH, N_CONTEXT, D_SAE]

    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM)
    feature_activations = feature_activations.flatten(0, 1).to('cpu')

    # Store the encoded activations
    all_feature_acts.append(feature_activations)

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

all_histogram_tokens = torch.cat(all_histogram_tokens)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 25/25 [00:28<00:00,  1.13s/it]


In [21]:
all_feature_acts = torch.cat(all_feature_acts)
plot_log10_hist(all_feature_acts, 'activations')

In [22]:
del all_feature_acts
clear_cache()

#### 4.1.4 Feature density histogram

In [23]:
total_tokens = total_histogram_batches * batch_size_tokens
n_features = sae.cfg.d_sae

density_plotter = FeatureDensityPlotter(n_features, total_tokens)

for k in tqdm(range(total_histogram_batches)):
    # Use the same sample to calculate the histogram.
    # Calculate the start and end indices for the current batch
    start_idx = k * batch_size_prompts
    end_idx = (k + 1) * batch_size_prompts

    # Get the corresponding batch of tokens from all_tokens
    tokens = all_histogram_tokens[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Run the model and store the activations
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Get the activations from the cache and convert to float32 for more accurate density computation
    original_activations = cache[sae_id].float()  # [N_BATCH, N_CONTEXT, D_SAE]

    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM)
    feature_activations = feature_activations.flatten(0, 1).to('cpu')
    assert feature_activations.dtype == torch.float32, str(feature_activations.dtype)

    # Update the density histogram data
    density_plotter.update(feature_activations)

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

100%|██████████| 25/25 [00:33<00:00,  1.32s/it]


In [24]:
density_plotter.plot(y_scalar=2, y_scale_bin=-2)

In [25]:
del base_model, activation_store
clear_cache()

### Task 4.1 FineTuned case

In [26]:
# Load the finetune model and its tokenizer
finetune_tokenizer = AutoTokenizer.from_pretrained(FINETUNE_MODEL)
finetune_model_hf = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL)
finetune_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, hf_model=finetune_model_hf, dtype=torch.float16)

tokenizer_config.json:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/522 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [27]:
del finetune_model_hf
clear_cache()

In [28]:
# import the required libraries
from sae_lens import SAE

sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}' # Gemma is post,

sae, cfg_dict, sparsity = SAE.from_pretrained(
                            release = RELEASE,
                            sae_id = sae_id,
                            device = device
)
cfg_dict

{'model_name': 'gemma-2b',
 'model_class_name': 'HookedTransformer',
 'hook_point': 'blocks.6.hook_resid_post',
 'hook_point_eval': 'blocks.{layer}.attn.pattern',
 'hook_point_layer': 6,
 'hook_point_head_index': None,
 'dataset_path': 'HuggingFaceFW/fineweb',
 'streaming': True,
 'is_dataset_tokenized': False,
 'context_size': 1024,
 'use_cached_activations': False,
 'cached_activations_path': None,
 'd_in': 2048,
 'd_sae': 16384,
 'b_dec_init_method': 'zeros',
 'expansion_factor': 8,
 'activation_fn': 'relu',
 'normalize_sae_decoder': False,
 'noise_scale': 0.0,
 'from_pretrained_path': None,
 'apply_b_dec_to_input': False,
 'decoder_orthogonal_init': False,
 'decoder_heuristic_init': True,
 'init_encoder_as_decoder_transpose': True,
 'n_batches_in_buffer': 64,
 'training_tokens': 1228800000,
 'finetuning_tokens': 0,
 'store_batch_size_prompts': 8,
 'train_batch_size_tokens': 4096,
 'normalize_activations': 'none',
 'device': 'cuda',
 'seed': 42,
 'dtype': 'torch.float32',
 'prepend_

In [29]:
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=finetune_model,
    sae=sae,
    streaming=True,
    # dataset=chess_dataset,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts

batch_size_prompts, batch_size_tokens

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23781 [00:00<?, ?it/s]



(8, 8192)

#### 4.2.1 L0 loss

In [30]:
from tqdm import tqdm

all_L0 = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Use the same sample to calculate the L0 loss.
    # Calculate the start and end indices for the current batch
    start_idx = k * batch_size_prompts
    end_idx = (k + 1) * batch_size_prompts

    # Get the corresponding batch of tokens from all_tokens
    tokens = all_tokens_L0[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Run the model and store the activations
    _, cache = finetune_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Get the activations from the cache at the sae_id
    original_activations = cache[sae_id]

    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM)

    feature_activations.to('cpu')

    # Store the encoded activations
    all_L0.append(L0_loss(feature_activations))

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

100%|██████████| 25/25 [00:28<00:00,  1.16s/it]


In [31]:
torch.tensor(all_L0).mean()

tensor(83.6431)

In [32]:
clear_cache()

#### 4.2.2 Substitution Loss

In [33]:
from tqdm import tqdm
sae_reconstruction_metric = R2Score().to(device)

all_SL_clean = []
all_SL_reconstructed = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Use the same sample to calculate the losses.
    # Calculate the start and end indices for the current batch
    start_idx = k * batch_size_prompts
    end_idx = (k + 1) * batch_size_prompts

    # Get the corresponding batch of tokens from all_tokens
    tokens = all_tokens_SL[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Store loss
    clean_loss, reconstructed_loss = get_substitution_loss(tokens, finetune_model, sae, sae_id, sae_reconstruction_metric)
    all_SL_clean.append(clean_loss)
    all_SL_reconstructed.append(reconstructed_loss)

100%|██████████| 25/25 [01:38<00:00,  3.93s/it]


In [34]:
print('Clean vs substitution loss:')
torch.tensor(all_SL_clean).mean().item(), torch.tensor(all_SL_reconstructed).mean().item()

Clean vs substitution loss:


(3.341796875, inf)

In [35]:
print('Varience explained by SAE: ')
sae_reconstruction_metric.compute().item()

Varience explained by SAE: 


1.0

In [36]:
loss_reconstructed_tensor = torch.tensor(all_SL_reconstructed)
loss_reconstructed_tensor.sort()

torch.return_types.sort(
values=tensor([7.2305, 7.4023, 7.4336, 7.4375, 7.4570, 7.4648, 7.4648, 7.5195, 7.5273,
        7.5312, 7.5469, 7.5820, 7.6055, 7.6172, 7.6328, 7.6445, 7.6914, 7.6992,
        7.7227, 7.7852, 7.7930, 7.8047, 7.8828,    inf,    inf],
       dtype=torch.float16),
indices=tensor([24, 15, 18, 13, 11, 12,  4,  6, 14, 21,  5,  3, 23,  7, 17,  0,  9, 22,
         1, 20, 19, 10,  8, 16,  2]))

#### 4.2.3 Feature activations histogram

In [37]:
from tqdm import tqdm

all_feature_acts = []

for k in tqdm(range(total_histogram_batches)):
    # Use the same sample to calculate the histogram.
    # Calculate the start and end indices for the current batch
    start_idx = k * batch_size_prompts
    end_idx = (k + 1) * batch_size_prompts

    # Get the corresponding batch of tokens from all_tokens
    tokens = all_histogram_tokens[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Run the model and store the activations
    _, cache = finetune_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Get the activations from the cache at the sae_id
    original_activations = cache[sae_id]  # [N_BATCH, N_CONTEXT, D_SAE]

    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM)
    feature_activations = feature_activations.flatten(0, 1).to('cpu')

    # Store the encoded activations
    all_feature_acts.append(feature_activations)

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

100%|██████████| 25/25 [00:27<00:00,  1.11s/it]


In [38]:
all_feature_acts = torch.cat(all_feature_acts)
plot_log10_hist(all_feature_acts, 'activations')

#### 4.2.4 Feature density histogram

In [39]:
total_tokens = total_histogram_batches * batch_size_tokens
n_features = sae.cfg.d_sae

density_plotter = FeatureDensityPlotter(n_features, total_tokens)

for k in tqdm(range(total_histogram_batches)):
    # Use the same sample to calculate the histogram.
    # Calculate the start and end indices for the current batch
    start_idx = k * batch_size_prompts
    end_idx = (k + 1) * batch_size_prompts

    # Get the corresponding batch of tokens from all_tokens
    tokens = all_histogram_tokens[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]

    # Run the model and store the activations
    _, cache = finetune_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                             names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]

    # Get the activations from the cache and convert to float32 for more accurate density computation
    original_activations = cache[sae_id].float()  # [N_BATCH, N_CONTEXT, D_SAE]

    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM)
    feature_activations = feature_activations.flatten(0, 1).to('cpu')
    assert feature_activations.dtype == torch.float32, str(feature_activations.dtype)

    # Update the density histogram data
    density_plotter.update(feature_activations)

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

100%|██████████| 25/25 [00:32<00:00,  1.30s/it]


In [40]:
density_plotter.plot(y_scalar=1.5, y_scale_bin=-1)

In [41]:
total_tokens

204800