# Evaluating your SAE

Code based off Rob Graham's ([themachinefan](https://github.com/themachinefan)) SAE evaluation code.

In [53]:
import torch
import torchvision

import plotly.express as px

from tqdm import tqdm

import einops

import numpy as np
import os


# Setup

In [77]:
from dataclasses import dataclass
from vit_prisma.sae.config import VisionModelSAERunnerConfig

@dataclass
class EvalConfig(VisionModelSAERunnerConfig):
    sae_path: str = '/network/scratch/s/sonia.joseph/sae_checkpoints/1f89d99e-wkcn-TinyCLIP-ViT-40M-32-Text-19M-LAION400M-expansion-16/n_images_520028.pt'
    model_name: str = "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"
    model_type: str =  "clip"

    dataset_path = "/network/scratch/s/sonia.joseph/datasets/kaggle_datasets"
    dataset_train_path: str = "/network/scratch/s/sonia.joseph/datasets/kaggle_datasets/ILSVRC/Data/CLS-LOC/train"
    dataset_val_path: str = "/network/scratch/s/sonia.joseph/datasets/kaggle_datasets/ILSVRC/Data/CLS-LOC/val"

    verbose: bool = True

    device: bool = 'cuda'

    eval_max: int = 50_000
    batch_size: int = 32
   


cfg = EvalConfig()

n_tokens_per_buffer (millions): 0.032
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 31738
Total training images: 2600000
Total wandb updates: 1057
n_tokens_per_feature_sampling_window (millions): 20.48
n_tokens_per_dead_feature_window (millions): 1024.0
Using Ghost Grads.
We will reset the sparsity calculation 317 times.
Number tokens in sparsity calculation window: 4.10e+05


In [78]:
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x7efbb02fa890>

## Load model

In [68]:
from vit_prisma.models.base_vit import HookedViT

model_name = "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"
model = HookedViT.from_pretrained(model_name, is_timm=False, is_clip=True).to(cfg.device)
 

{'n_layers': 12, 'd_model': 512, 'd_head': 64, 'model_name': '', 'n_heads': 8, 'd_mlp': 2048, 'activation_name': 'gelu', 'eps': 1e-05, 'original_architecture': 'vit_clip_vision_encoder', 'initializer_range': 0.02, 'n_channels': 3, 'patch_size': 32, 'image_size': 224, 'n_classes': 512, 'n_params': None, 'layer_norm_pre': True, 'return_type': 'class_logits'}
LayerNorm folded.
Centered weights writing to residual stream
Loaded pretrained model wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M into HookedTransformer


## Load datasets

In [69]:
# load dataset
from vit_prisma.utils.data_utils.imagenet_utils import setup_imagenet_paths
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_transforms_clip, ImageNetValidationDataset


if cfg.model_type == 'clip':
    data_transforms = get_imagenet_transforms_clip(cfg.model_name)
else:
    raise ValueError("Invalid model type")
imagenet_paths = setup_imagenet_paths(cfg.dataset_path)
train_data = torchvision.datasets.ImageFolder(cfg.dataset_train_path, transform=data_transforms)
val_data = ImageNetValidationDataset(cfg.dataset_val_path, 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'], 
                                data_transforms
)
# print(f"Train data length: {len(train_data)}") if cfg.verbose else None
print(f"Validation data length: {len(val_data)}") if cfg.verbose else None


Validation data length: 50000


In [70]:
from vit_prisma.sae.training.activations_store import VisionActivationsStore

activations_loader = VisionActivationsStore(cfg, model, train_data, eval_dataset=val_data)


## Load pretrained SAE to evaluate

In [71]:
from vit_prisma.sae.sae import SparseAutoencoder
sparse_autoencoder = SparseAutoencoder(cfg).load_from_pretrained(cfg.sae_path)
sparse_autoencoder.to(cfg.device)
sparse_autoencoder.eval()  # prevents error if we're expecting a dead neuron mask for who 


SparseAutoencoder(
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
)

# Evaluate sparsity

## Average L0

In [79]:
# sparse_autoencoder.eval()  # prevents error if we're expecting a dead neuron mask for who 


with torch.no_grad():
    batch_tokens, labels = activations_loader.get_val_batch_tokens()
    _, cache = model.run_with_cache(batch_tokens, names_filter = sparse_autoencoder.cfg.hook_point)
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
        cache[sparse_autoencoder.cfg.hook_point].to(cfg.device)
    )
    del cache

    # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
    l0 = (feature_acts[:, :] > 0).float().sum(-1).detach()
    l0_cls = (feature_acts[:, :] > 0).float().sum(-1).mean(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

average l0 22.189374923706055


## Get feature probability

In [108]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

@torch.no_grad()
def get_feature_probability(images, model, sparse_autoencoder):
    _, cache = model.run_with_cache(images, names_filter=[sparse_autoencoder.cfg.hook_point])
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
        cache[sparse_autoencoder.cfg.hook_point]
    )

    # Flatten first two dimensions (batch, position) to get a 2D tensor of activations
    return (feature_acts.abs() > 0).float().flatten(0, 1)

def process_dataset(val_dataloader, model, sparse_autoencoder, cfg):
    total_acts = None
    total_tokens = 0
    
    for idx, batch in tqdm(enumerate(val_dataloader), total=cfg.eval_max//cfg.batch_size):
        images = batch[0]

        images = images.to(cfg.device)
        sae_activations = get_feature_probability(images, model, sparse_autoencoder)
        
        if total_acts is None:
            total_acts = sae_activations.sum(0)
        else:
            total_acts += sae_activations.sum(0)
        
        total_tokens += sae_activations.shape[0]
        
        if count * batch_size >= cfg.eval_max:
            break
    
    return total_acts, count * batch_size

def calculate_log_frequencies(total_acts, total_tokens, epsilon=1e-10):
    feature_probs = total_acts / total_tokens
    log_feature_probs = torch.log10(feature_probs + epsilon)
    return log_feature_probs.cpu().numpy()

# def plot_histogram(log_frequencies, num_bins=100): # Note: black edged histograms look great!
#     plt.figure(figsize=(12, 6))
#     plt.hist(log_frequencies, bins=num_bins, edgecolor='black')
#     plt.xlabel('Log10 Feature Frequency')
#     plt.ylabel('Count')
#     plt.title('Log Feature Density Histogram')
#     plt.grid(True, alpha=0.3)
#     plt.show()


# Main execution
total_acts, total_tokens = process_dataset(activations_loader.image_dataloader_eval, model, sparse_autoencoder, cfg)

log_frequencies = calculate_log_frequencies(total_acts, total_tokens)

print(f"Total tokens processed: {total_tokens}")
print(f"Average activations per token: {total_acts.sum().item() / total_tokens:.4f}")



 64%|██████▍   | 1002/1562 [03:19<04:25,  2.11it/s]

In [107]:
def plot_histogram_px(log_frequencies, num_bins=100):
    fig = px.histogram(
        x=log_frequencies,
        nbins=num_bins,
        labels={'x': 'Log10 Feature Frequency', 'y': 'Count'},
        title='Log Feature Density Histogram',
        opacity=0.7,
    )
    fig.update_layout(
        bargap=0.1,
        xaxis_title='Log10 Feature Frequency',
        yaxis_title='Count',
        plot_bgcolor='rgba(240, 240, 240, 0.8)',  # Light gray background
        xaxis=dict(showgrid=True, gridwidth=1, gridcolor='White'),
        yaxis=dict(showgrid=True, gridwidth=1, gridcolor='White'),
    )
    fig.show()

    
plot_histogram_px(log_frequencies, num_bins=240)

In [102]:
log_freq = torch.Tensor(log_frequencies)

# minimum and maximum log_freq
min_log_freq = log_freq.min().item()
max_log_freq = log_freq.max().item()

print(f"Minimum log frequency: {min_log_freq:.4f}")
print(f"Maximum log frequency: {max_log_freq:.4f}")

Minimum log frequency: -4.2217
Maximum log frequency: 1.1507


In [98]:
def visualize_sparsities(log_freq, conditions, condition_texts, name):
    # Visualise sparsities for each instance
    hist(
        log_freq,
        f"{name}_frequency_histogram",
        show=True,
        title=f"{name} Log Frequency of Features",
        labels={"x": "log<sub>10</sub>(freq)"},
        histnorm="percent",
        template="ggplot2"
    )

    #TODO these conditions need to be tuned to distribution of your data!


    for condition, condition_text in zip(conditions, condition_texts):
        percentage = (torch.count_nonzero(condition)/log_freq.shape[0]).item()*100
        if percentage == 0:
            continue
        percentage = int(np.round(percentage))
        rare_encoder_directions = sparse_autoencoder.W_enc[:, condition]
        rare_encoder_directions_normalized = rare_encoder_directions / rare_encoder_directions.norm(dim=0, keepdim=True)

        # Compute their pairwise cosine similarities & sample randomly from this N*N matrix of similarities
        cos_sims_rare = (rare_encoder_directions_normalized.T @ rare_encoder_directions_normalized).flatten()
        cos_sims_rare_random_sample = cos_sims_rare[torch.randint(0, cos_sims_rare.shape[0], (10000,))]

        # Plot results
        hist(
            cos_sims_rare_random_sample,
            f"{name}_low_prop_similarity_{condition_text}",
            show=True,
            marginal="box",
            title=f"{name} Cosine similarities of random {condition_text} encoder directions with each other ({percentage}% of features)",
            labels={"x": "Cosine sim"},
            histnorm="percent",
            template="ggplot2",
        )

conditions = [torch.logical_and(log_freq < -4,log_freq > -8), log_freq <-8, torch.logical_and(log_freq < -4,log_freq > -6.5),torch.logical_and(log_freq < -6.5,log_freq > -8), torch.logical_and(log_freq < -4,log_freq > -5),log_freq>-4]
condition_texts = ["logfreq_[-8,-4]", "logfreq_[-inf,-8]", "logfreq_[-6.5,-4]", "logfreq_[-8,-6.5]", "logfreq_[-5,-4]", "logfreq_[-4,inf]"]
visualize_sparsities(log_freq, conditions, condition_texts, "TOTAL")
conditions_class = [torch.logical_and(log_freq_class < -4,log_freq_class > -8), log_freq_class <-9, log_freq_class>-4]
condition_texts_class = ["logfreq_[-8,-4]", "logfreq_[-inf,-9]","logfreq_[-4,inf]"]
visualize_sparsities(log_freq_class, conditions_class, condition_texts_class,"CLS")