In [1]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias


import numpy as np
import pandas as pd
import torch as t
from datasets import load_dataset
import transformer_lens
import sae_lens

import einops
import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from tabulate import tabulate


# Hugging face: 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
# Wandb: 6b549d940e7a29c79c184f27f25606e94a48a966

## GPT2

In [2]:
device = t.device("cuda" if t.cuda.is_available() else "cpu")
t.set_grad_enabled(False)
print(device)

cuda


In [3]:
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


#### Attention SAE

In [19]:
attn_saes = {
    layer: sae_lens.SAE.from_pretrained(
        "gpt2-small-hook-z-kk",
        f"blocks.{layer}.hook_z",
        device=device,
    )[0]
    for layer in range(gpt2.cfg.n_layers)
}

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v5_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v5.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v4_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v4.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v7_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v7.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/302M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/302M [00:00<?, ?B/s]

(…)-05_rsanthropic_rie25000_nr4_v6_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-05_rsanthropic_rie25000_nr4_v6.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-05_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-05_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.03k [00:00<?, ?B/s]

(…)c3.16e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

In [43]:
print(attn_saes)

{0: SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
), 1: SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
), 2: SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
), 3: SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
), 4: SAE(
  (activation_fn): ReLU()
  (hook_sae_i

In [20]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    release = sae_lens.toolkit.pretrained_saes_directory.get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))


layer = 9

display_dashboard(
    sae_release="gpt2-small-hook-z-kk",
    sae_id=f"blocks.{layer}.hook_z",
    latent_idx=2,  # or you can try `random.randint(0, attn_saes[layer].cfg.d_sae)`
)

https://neuronpedia.org/gpt2-small/9-att-kk/2?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [7]:
gpt2.cfg

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'num_experts': None,
 'original

#### Magnitude pruning

In [6]:
for name, param in gpt2.named_parameters():
    print(f"Layer: {name}, Shape: {param.shape}")

Layer: embed.W_E, Shape: torch.Size([50257, 768])
Layer: pos_embed.W_pos, Shape: torch.Size([1024, 768])
Layer: blocks.0.attn.W_Q, Shape: torch.Size([12, 768, 64])
Layer: blocks.0.attn.W_O, Shape: torch.Size([12, 64, 768])
Layer: blocks.0.attn.b_Q, Shape: torch.Size([12, 64])
Layer: blocks.0.attn.b_O, Shape: torch.Size([768])
Layer: blocks.0.attn.W_K, Shape: torch.Size([12, 768, 64])
Layer: blocks.0.attn.W_V, Shape: torch.Size([12, 768, 64])
Layer: blocks.0.attn.b_K, Shape: torch.Size([12, 64])
Layer: blocks.0.attn.b_V, Shape: torch.Size([12, 64])
Layer: blocks.0.mlp.W_in, Shape: torch.Size([768, 3072])
Layer: blocks.0.mlp.b_in, Shape: torch.Size([3072])
Layer: blocks.0.mlp.W_out, Shape: torch.Size([3072, 768])
Layer: blocks.0.mlp.b_out, Shape: torch.Size([768])
Layer: blocks.1.attn.W_Q, Shape: torch.Size([12, 768, 64])
Layer: blocks.1.attn.W_O, Shape: torch.Size([12, 64, 768])
Layer: blocks.1.attn.b_Q, Shape: torch.Size([12, 64])
Layer: blocks.1.attn.b_O, Shape: torch.Size([768])
Laye

In [7]:
def prune_magnitude(W, sparse_ratio=0.5):
    W_abs = W.abs()
    k = int(W_abs.numel() * sparse_ratio)
    _, indices = W_abs.view(-1).topk(k)
    mask = t.zeros_like(W_abs)
    mask.view(-1)[indices] = 1
    return mask*W

def prune_model(model):
    for name, param in gpt2.named_parameters():
        # print(f"Layer: {name}, Shape: {param.shape}")
        if 'W' in name:
            for i in range(param.shape[0]):
                with t.no_grad():
                    param[i] = prune_magnitude(param[i])
    
    return model

pruned_gpt2 = prune_model(gpt2)
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)


Loaded pretrained model gpt2-small into HookedTransformer


In [71]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.56 Prob:  9.34% Token: |'m|
Top 1th token. Logit: 14.51 Prob:  8.81% Token: |'ve|
Top 2th token. Logit: 14.17 Prob:  6.31% Token: | have|
Top 3th token. Logit: 13.88 Prob:  4.74% Token: | am|
Top 4th token. Logit: 13.88 Prob:  4.69% Token: | was|
Top 5th token. Logit: 13.39 Prob:  2.88% Token: | don|
Top 6th token. Logit: 13.29 Prob:  2.61% Token: | think|
Top 7th token. Logit: 13.15 Prob:  2.27% Token: | had|
Top 8th token. Logit: 13.09 Prob:  2.14% Token: | know|
Top 9th token. Logit: 12.82 Prob:  1.64% Token: | can|


In [74]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.23 Prob:  9.70% Token: |b|
Top 1th token. Logit: 13.81 Prob:  6.36% Token: | was|
Top 2th token. Logit: 13.80 Prob:  6.31% Token: |van|
Top 3th token. Logit: 13.48 Prob:  4.54% Token: |on|
Top 4th token. Logit: 12.97 Prob:  2.73% Token: | have|
Top 5th token. Logit: 12.81 Prob:  2.32% Token: |nex|
Top 6th token. Logit: 12.73 Prob:  2.16% Token: |e|
Top 7th token. Logit: 12.66 Prob:  2.01% Token: |-|
Top 8th token. Logit: 12.61 Prob:  1.90% Token: |'m|
Top 9th token. Logit: 12.30 Prob:  1.40% Token: |'ve|


In [46]:
transformer_lens.evals.sanity_check(gpt2)

tensor(3.3225, device='cuda:0')

In [48]:
transformer_lens.evals.sanity_check(pruned_gpt2)

tensor(3.8446, device='cuda:0')

In [52]:
dl = transformer_lens.evals.make_pile_data_loader(gpt2.tokenizer)
transformer_lens.evals.evaluate_on_dataset(gpt2, dl, truncate=1000, device=device)

10000


  0%|          | 0/2119 [00:00<?, ?it/s]

2.9206282215994914

In [53]:
transformer_lens.evals.evaluate_on_dataset(pruned_gpt2, dl, truncate=1000,  device=device)

  0%|          | 0/2119 [00:00<?, ?it/s]

3.793479422589282

#### Wanda

In [4]:
gpt2.cfg

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'num_experts': None,
 'original

In [5]:
def prune_wanda(W, X_norm, sparse_ratio=0.5):
    W_metric = W.abs() * X_norm
    _, sorted_idx = W_metric.sort(dim=1)
    pruned_idx = sorted_idx[:, :int(W.shape[1] * sparse_ratio)]
    
    W_clone = W.detach().clone()    
    W_clone.scatter_(dim=1, index=pruned_idx, src=t.zeros_like(pruned_idx, dtype=W.dtype))
    return W_clone

W = t.tensor([
    [4, 0, 1, -1],
    [3, -2, -1, -3],
    [-3, 1, 0, 2]
])
X = t.tensor([
    [1, 2, 8, 3]
])

prune_wanda(W, X, sparse_ratio=0.5)

tensor([[ 4,  0,  1,  0],
        [ 0,  0, -1, -3],
        [-3,  0,  0,  2]])

In [None]:
def prune_model(model, hooks):
    text = "A quick brown fox jumps over the lazy dog."
    tokens = model.to_tokens(text)
    logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

    for hook in hooks:
        logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
        layer = int(hook.split('.')[1])

        if 'mlp' in hook:
            if 'out' in hook:
                X = cache[hook].norm(p=2, dim=0)
                model.W_out[layer] = t.nn.Parameter(prune_wanda(model.W_out[layer], X, sparse_ratio=0.5))
            else:
                X = cache[hook]
                X = X.norm(p=2, dim=0)
                model.W_in[layer] = t.nn.Parameter(prune_wanda(model.W_in[layer], X, sparse_ratio=0.5))
        else:
            for head in range(model.cfg.n_heads):
                X = cache[hook][:, head, :]  
                X = X.norm(p=2, dim=0)

                if 'q' in hook:
                    model.W_Q[layer, head] = t.nn.Parameter(prune_wanda(model.W_Q[layer, head], X, sparse_ratio=0.5)
)
                if 'k' in hook:
                    model.W_K[layer, head] = t.nn.Parameter(prune_wanda(model.W_K[layer, head], X, sparse_ratio=0.5))
                
                if 'v' in hook:
                    model.W_V[layer, head] = t.nn.Parameter(prune_wanda(model.W_V[layer, head], X, sparse_ratio=0.5))

    return model

In [22]:
gpt2_text = "A quick brown fox jumps over the lazy dog."
gpt2_tokens = gpt2.to_tokens(gpt2_text)
gpt2_logits, gpt2_cache = gpt2.run_with_cache(gpt2_tokens, remove_batch_dim=True)

hooks = []
for k, v in gpt2_cache.cache_dict.items():
    if 'block' in k and (v.dim() == 3 or 'mlp' in k):
        if 'score' not in k and 'pattern' not in k:
            hooks.append(k)

len(hooks)
pruned_gpt2 = prune_model(gpt2, hooks)
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [23]:
t.norm(pruned_gpt2.W_Q - gpt2.W_Q)

tensor(0., device='cuda:0')

In [25]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.56 Prob:  9.34% Token: |'m|
Top 1th token. Logit: 14.51 Prob:  8.81% Token: |'ve|
Top 2th token. Logit: 14.17 Prob:  6.31% Token: | have|
Top 3th token. Logit: 13.88 Prob:  4.74% Token: | am|
Top 4th token. Logit: 13.88 Prob:  4.69% Token: | was|
Top 5th token. Logit: 13.39 Prob:  2.88% Token: | don|
Top 6th token. Logit: 13.29 Prob:  2.61% Token: | think|
Top 7th token. Logit: 13.15 Prob:  2.27% Token: | had|
Top 8th token. Logit: 13.09 Prob:  2.14% Token: | know|
Top 9th token. Logit: 12.82 Prob:  1.64% Token: | can|


In [26]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.56 Prob:  9.34% Token: |'m|
Top 1th token. Logit: 14.51 Prob:  8.81% Token: |'ve|
Top 2th token. Logit: 14.17 Prob:  6.31% Token: | have|
Top 3th token. Logit: 13.88 Prob:  4.74% Token: | am|
Top 4th token. Logit: 13.88 Prob:  4.69% Token: | was|
Top 5th token. Logit: 13.39 Prob:  2.88% Token: | don|
Top 6th token. Logit: 13.29 Prob:  2.61% Token: | think|
Top 7th token. Logit: 13.15 Prob:  2.27% Token: | had|
Top 8th token. Logit: 13.09 Prob:  2.14% Token: | know|
Top 9th token. Logit: 12.82 Prob:  1.64% Token: | can|


In [41]:
X = gpt2_cache['blocks.0.attn.hook_q']
print(X.shape)

torch.Size([11, 12, 64])


In [None]:
X = gpt2_cache['blocks.0.attn.hook_q']
for name, param in gpt2.named_parameters():
    print(name, param.shape)
    if name == 'blocks.0.attn.W_Q':
        W = param
        param = prune_wanda(W, X, sparse_ratio=0.5)
    break


embed.W_E torch.Size([50257, 768])


#### Wanda new try

In [24]:
for name, param in pruned_gpt2.named_parameters():
    print(name, param.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
blocks.1.attn.W_Q torch.Size([12, 768, 64])
blocks.1.attn.W_O torch.Size([12, 64, 768])
blocks.1.attn.b_Q torch.Size([12, 64])
blocks.1.attn.b_O torch.Size([768])
blocks.1.attn.W_K torch.Size([12, 768, 64])
blocks.1.attn.W_V torch.Size([12, 768, 64])
blocks.1.attn.b_K torch.Size([12, 64])
blocks.1.attn.b_V torch.Size([12, 64])
blocks.1.mlp.W_in torch.Size([768, 3072])
blocks.1.mlp.b_in torch.Size([3072])
blocks.1.mlp.W_out torch.Size

In [21]:
gpt2_text = "A quick brown fox jumps over the lazy dog."
gpt2_tokens = gpt2.to_tokens(gpt2_text)
gpt2_logits, gpt2_cache = gpt2.run_with_cache(gpt2_tokens, remove_batch_dim=True)

for k, v in gpt2_cache.cache_dict.items():
    print(k, v.shape)

hook_embed torch.Size([11, 768])
hook_pos_embed torch.Size([11, 768])
blocks.0.hook_resid_pre torch.Size([11, 768])
blocks.0.ln1.hook_scale torch.Size([11, 1])
blocks.0.ln1.hook_normalized torch.Size([11, 768])
blocks.0.attn.hook_q torch.Size([11, 12, 64])
blocks.0.attn.hook_k torch.Size([11, 12, 64])
blocks.0.attn.hook_v torch.Size([11, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([12, 11, 11])
blocks.0.attn.hook_pattern torch.Size([12, 11, 11])
blocks.0.attn.hook_z torch.Size([11, 12, 64])
blocks.0.hook_attn_out torch.Size([11, 768])
blocks.0.hook_resid_mid torch.Size([11, 768])
blocks.0.ln2.hook_scale torch.Size([11, 1])
blocks.0.ln2.hook_normalized torch.Size([11, 768])
blocks.0.mlp.hook_pre torch.Size([11, 3072])
blocks.0.mlp.hook_post torch.Size([11, 3072])
blocks.0.hook_mlp_out torch.Size([11, 768])
blocks.0.hook_resid_post torch.Size([11, 768])
blocks.1.hook_resid_pre torch.Size([11, 768])
blocks.1.ln1.hook_scale torch.Size([11, 1])
blocks.1.ln1.hook_normalized torch.Size

In [29]:
acts = ['attn.hook_q','attn.hook_k','attn.hook_v','hook_attn_out','mlp.hook_pre','hook_mlp_out']
wts = ['attn.W_Q', 'attn.W_K', 'attn.W_V', 'mlp.W_in', 'mlp.W_out']

wts_act = {
    'attn.W_Q': 'attn.hook_q',
    'attn.W_K': 'attn.hook_k',
    'attn.W_V': 'attn.hook_v',
    'attn.W_O': 'hook_attn_out',
    'mlp.W_in': 'mlp.hook_pre',
    'mlp.W_out': 'hook_mlp_out'
}

for wt, act in wts_act.items():
    print(wt, act)

attn.W_Q attn.hook_q
attn.W_K attn.hook_k
attn.W_V attn.hook_v
attn.W_O hook_attn_out
mlp.W_in mlp.hook_pre
mlp.W_out hook_mlp_out


In [34]:
gpt2.get_parameter('blocks.0.attn.W_Q').shape

torch.Size([12, 768, 64])

In [41]:


def prune_wanda(W, X_norm, sparse_ratio=0.5):
    W_metric = W.abs() * X_norm
    _, sorted_idx = W_metric.sort(dim=1)
    pruned_idx = sorted_idx[:, :int(W.shape[1] * sparse_ratio)]
    
    W_clone = W.detach().clone()    
    W_clone.scatter_(dim=1, index=pruned_idx, src=t.zeros_like(pruned_idx, dtype=W.dtype))
    return W_clone

def prune_model(model, tokens):
    wts_act = {
    'attn.W_Q': 'attn.hook_q',
    'attn.W_K': 'attn.hook_k',
    'attn.W_V': 'attn.hook_v',
    'attn.W_O': 'hook_attn_out',
    'mlp.W_in': 'mlp.hook_pre',
    'mlp.W_out': 'hook_mlp_out'
    }
    for layer in range(model.cfg.n_layers):
        logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
        for wt, act in wts_act.items():
            W = model.get_parameter(f'blocks.{layer}.{wt}')
            X = cache[f'blocks.{layer}.{act}']

            if W.dim() == 3:
                if 'W_O' in wt:
                    X_norm = X.norm(p=2, dim=0)
                    for head in range(W.shape[0]):
                        W[head] = prune_wanda(W[head], X_norm, sparse_ratio=0.5)
                        
                else:
                    for head in range(W.shape[0]):
                        X_norm = X[:, head, :].norm(p=2, dim=0)
                        W[head] = prune_wanda(W[head], X_norm, sparse_ratio=0.5)
            else:
                X_norm = X.norm(p=2, dim=0)
                W = prune_wanda(W, X_norm, sparse_ratio=0.5)

    return model

gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
gpt2_text = "A quick brown fox jumps over the lazy dog."
gpt2_tokens = gpt2.to_tokens(gpt2_text)
pruned_gpt2 = prune_model(gpt2, gpt2_tokens)
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer


In [42]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.56 Prob:  9.34% Token: |'m|
Top 1th token. Logit: 14.51 Prob:  8.81% Token: |'ve|
Top 2th token. Logit: 14.17 Prob:  6.31% Token: | have|
Top 3th token. Logit: 13.88 Prob:  4.74% Token: | am|
Top 4th token. Logit: 13.88 Prob:  4.69% Token: | was|
Top 5th token. Logit: 13.39 Prob:  2.88% Token: | don|
Top 6th token. Logit: 13.29 Prob:  2.61% Token: | think|
Top 7th token. Logit: 13.15 Prob:  2.27% Token: | had|
Top 8th token. Logit: 13.09 Prob:  2.14% Token: | know|
Top 9th token. Logit: 12.82 Prob:  1.64% Token: | can|


In [43]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 15.69 Prob:  9.48% Token: | have|
Top 1th token. Logit: 15.55 Prob:  8.25% Token: | was|
Top 2th token. Logit: 15.42 Prob:  7.22% Token: |'m|
Top 3th token. Logit: 15.28 Prob:  6.32% Token: |'ve|
Top 4th token. Logit: 14.87 Prob:  4.20% Token: | am|
Top 5th token. Logit: 14.68 Prob:  3.46% Token: | think|
Top 6th token. Logit: 14.64 Prob:  3.33% Token: | don|
Top 7th token. Logit: 14.57 Prob:  3.11% Token: | had|
Top 8th token. Logit: 14.46 Prob:  2.78% Token: | know|
Top 9th token. Logit: 14.32 Prob:  2.42% Token: | can|


## Gemma2

Issue: Pretrained SAE is not available for the model - for attention

In [None]:
gemma2b: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer
