In [1]:
import torch

# Evaluate vision SAE

In [2]:
# Load a pretrained SAE from the Hugging Face Hub
from huggingface_hub import hf_hub_download, list_repo_files
from vit_prisma.sae import SparseAutoencoder

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)


def load_sae(repo_id, file_name, config_name):
    """Download a pretrained SAE from the Hugging Face Hub and load it.

    Args:
        repo_id: Hugging Face repository that hosts the SAE.
        file_name: Name of the weights file inside that repository.
        config_name: Name of the config file inside that repository.

    Returns:
        Whatever ``SparseAutoencoder.load_from_pretrained`` returns for the
        downloaded weights file.
    """
    # Fetch both artifacts from Hugging Face. Only the weights path is used
    # afterwards; the return value of the config download is discarded.
    weights_path = hf_hub_download(repo_id, file_name)
    hf_hub_download(repo_id, config_name)

    print(f"Loading SAE from {weights_path}...")

    # load_from_pretrained is expected to pick up the config.json sitting in
    # the same folder as the weights and turn it into a VisionSAERunnerConfig.
    return SparseAutoencoder.load_from_pretrained(weights_path)

# Files to fetch from the repo. The weights file is usually weights.pt but may
# have slight naming variation; see the chosen HF repo for the exact file name.
config_name = "config.json"
file_name = "weights.pt"

# Hugging Face repo of the SAE to evaluate; change it to the repo of your
# chosen SAE. See /docs for a list of SAEs.
repo_id = "Prisma-Multimodal/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-1e-05"

sae = load_sae(repo_id=repo_id, file_name=file_name, config_name=config_name)
# Last expression of the cell, so the notebook displays the module repr.
sae.to(DEVICE)

Loading SAE from /home/mila/s/sonia.joseph/.cache/huggingface/hub/models--Prisma-Multimodal--sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-1e-05/snapshots/b46d6e9d6c114a53364c5c172ea5023e0e52d4e2/weights.pt...




StandardSparseAutoencoder(
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
  (activation_fn): ReLU()
)

In [3]:
# Load model
from vit_prisma.models.model_loader import load_hooked_model


# Read the model name stored in the SAE's config and load the hooked model by
# that name. NOTE(review): presumably this is the model the SAE was trained
# on — confirm against the SAE repo's config.json.
model_name = sae.cfg.model_name
model = load_hooked_model(model_name)
# Last expression of the cell, so the notebook also displays the model's repr.
model.to(DEVICE) # Move to device

HookedViT(
  (embed): PatchEmbedding(
    (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
  )
  (hook_embed): HookPoint()
  (pos_embed): PosEmbedding()
  (hook_pos_embed): HookPoint()
  (hook_full_embed): HookPoint()
  (ln_pre): LayerNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (hook_ln_pre): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): Ho

In [4]:
# Evaluate vision SAE
from vit_prisma.sae import SparsecoderEval

# Build the evaluator from the SAE and the hooked model loaded in the cells above.
eval_runner = SparsecoderEval(sae, model)

# Run the evaluation; a progress bar and summary metrics are printed as cell
# output, and the return value is kept in `metrics`.
# NOTE(review): is_clip=True presumably enables CLIP-specific handling (the
# SAE repo above is a CLIP-B/32 SAE) — confirm against SparsecoderEval.
metrics = eval_runner.run_eval(is_clip=True)

Evaluating: 100%|██████████| 391/391 [01:46<00:00,  3.68it/s, L0=3093.3625, Cosine Sim=0.991870]   

Average L0 (features activated): 792.035522
Average L0 (features activated) per CLS token: 965.520325
Average L0 (features activated) per image: 3086.346780
Average Cosine Similarity: 0.9917
Average Loss: 6.762171
Average Reconstruction Loss: 6.762009
Average Zero Ablation Loss: 6.772939
Average CE Score: 1.015572
% CE recovered: 101.496382



