In [2]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias


import numpy as np
import pandas as pd
import torch as t
from datasets import load_dataset
import transformer_lens
import sae_lens

import einops
import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from tabulate import tabulate


# Hugging face: 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
# Wandb: 6b549d940e7a29c79c184f27f25606e94a48a966

## GPT2

In [16]:
device = t.device("cuda" if t.cuda.is_available() else "cpu")
t.set_grad_enabled(False)
print(device)

cuda


In [4]:
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [19]:
attn_saes = {
    layer: sae_lens.SAE.from_pretrained(
        "gpt2-small-hook-z-kk",
        f"blocks.{layer}.hook_z",
        device=device,
    )[0]
    for layer in range(gpt2.cfg.n_layers)
}

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v5_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v5.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v4_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v4.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v7_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v7.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/302M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/302M [00:00<?, ?B/s]

(…)-05_rsanthropic_rie25000_nr4_v6_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-05_rsanthropic_rie25000_nr4_v6.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-05_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-05_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.03k [00:00<?, ?B/s]

(…)c3.16e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

In [20]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    release = sae_lens.toolkit.pretrained_saes_directory.get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))


layer = 9

display_dashboard(
    sae_release="gpt2-small-hook-z-kk",
    sae_id=f"blocks.{layer}.hook_z",
    latent_idx=2,  # or you can try `random.randint(0, attn_saes[layer].cfg.d_sae)`
)

https://neuronpedia.org/gpt2-small/9-att-kk/2?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [25]:
# Wanda pruning
def wanda_prune(W, X, s):
    # Compute the L2 norm of input features
    X_norm = t.norm(X, dim=0)
    
    # Compute the pruning scores
    scores = t.abs(W) * X_norm
    
    # Determine the pruning threshold
    n_prune = int(W.numel() * s)
    threshold = t.topk(scores.flatten(), n_prune, largest=False)[0][-1]
    
    # Create the mask
    mask = scores > threshold
    
    # Apply the mask to the weights
    W_pruned = W * mask
    
    return W_pruned


W = t.tensor([
    [4, 0, 1, -1], 
    [3, -2, -1, -3], 
    [-3, 1, 0, 2]
    ], dtype=t.float32)
X = t.tensor([1, 2, 8, 3], dtype=t.float32)

true = t.tensor([
    [4, 0, 1, 0], 
    [0, 0, -1, -3], 
    [-3, 0, 0, 2]
    ], dtype=t.float32)

computed = wanda_prune(W.clone(), X, 0.5)
computed

tensor([[ 4.,  0.,  0., -0.],
        [ 3., -2., -0., -3.],
        [-3.,  0.,  0.,  2.]])

## Gemma2

Issue: Pretrained SAE is not available for the model - for attention

In [19]:
gemma2b: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer
