In [1]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias


import numpy as np
import pandas as pd
import torch as t
from datasets import load_dataset
import transformer_lens
import sae_lens

import einops
import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from tabulate import tabulate

from tqdm import tqdm
device = t.device("cuda" if t.cuda.is_available() else "cpu")
t.set_grad_enabled(False)
print(device)

# Hugging face: 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
# Wandb: 6b549d940e7a29c79c184f27f25606e94a48a966

cuda


In [2]:
model = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
model.load_state_dict(t.load('pruned_gpt2_magnitude.pth'))
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, model)

Loaded pretrained model gpt2-small into HookedTransformer
Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.82 Prob: 10.75% Token: | have|
Top 1th token. Logit: 14.21 Prob:  5.82% Token: | was|
Top 2th token. Logit: 14.16 Prob:  5.53% Token: |'m|
Top 3th token. Logit: 14.06 Prob:  4.99% Token: |'ve|
Top 4th token. Logit: 13.83 Prob:  3.99% Token: | am|
Top 5th token. Logit: 13.71 Prob:  3.53% Token: | would|
Top 6th token. Logit: 13.57 Prob:  3.08% Token: |.|
Top 7th token. Logit: 13.47 Prob:  2.78% Token: | will|
Top 8th token. Logit: 13.41 Prob:  2.61% Token: | just|
Top 9th token. Logit: 13.36 Prob:  2.49% Token: | think|


In [3]:
model = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
model.load_state_dict(t.load('pruned_gpt2_wanda.pth'))
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, model)

Loaded pretrained model gpt2-small into HookedTransformer
Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.94 Prob:  8.72% Token: |'m|
Top 1th token. Logit: 14.91 Prob:  8.47% Token: | have|
Top 2th token. Logit: 14.75 Prob:  7.18% Token: |'ve|
Top 3th token. Logit: 14.65 Prob:  6.49% Token: | was|
Top 4th token. Logit: 14.18 Prob:  4.09% Token: | am|
Top 5th token. Logit: 14.01 Prob:  3.45% Token: | don|
Top 6th token. Logit: 13.81 Prob:  2.83% Token: | think|
Top 7th token. Logit: 13.80 Prob:  2.78% Token: | know|
Top 8th token. Logit: 13.76 Prob:  2.68% Token: | had|
Top 9th token. Logit: 13.68 Prob:  2.48% Token: | can|


## GPT2

In [None]:
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
     

#### Attention SAE

In [19]:
attn_saes = {
    layer: sae_lens.SAE.from_pretrained(
        "gpt2-small-hook-z-kk",
        f"blocks.{layer}.hook_z",
        device=device,
    )[0]
    for layer in range(gpt2.cfg.n_layers)
}

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v5_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v5.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v4_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v4.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v7_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v7.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/302M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/302M [00:00<?, ?B/s]

(…)-05_rsanthropic_rie25000_nr4_v6_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-05_rsanthropic_rie25000_nr4_v6.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

(…)c1.00e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-05_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

(…)c1.00e-05_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-06_rsanthropic_rie25000_nr4_v9_cfg.json:   0%|          | 0.00/1.03k [00:00<?, ?B/s]

(…)c3.16e-06_rsanthropic_rie25000_nr4_v9.pt:   0%|          | 0.00/151M [00:00<?, ?B/s]

In [43]:
print(attn_saes)

{0: SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
), 1: SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
), 2: SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
), 3: SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
), 4: SAE(
  (activation_fn): ReLU()
  (hook_sae_i

In [20]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    release = sae_lens.toolkit.pretrained_saes_directory.get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))


layer = 9

display_dashboard(
    sae_release="gpt2-small-hook-z-kk",
    sae_id=f"blocks.{layer}.hook_z",
    latent_idx=2,  # or you can try `random.randint(0, attn_saes[layer].cfg.d_sae)`
)

https://neuronpedia.org/gpt2-small/9-att-kk/2?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


In [7]:
gpt2.cfg

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'num_experts': None,
 'original

#### Magnitude pruning

In [79]:
def prune_magnitude(W, sparse_ratio=0.5):
    W_abs = W.abs()
    k = int(W_abs.numel() * sparse_ratio)
    _, indices = W_abs.view(-1).topk(k)
    mask = t.zeros_like(W_abs)
    mask.view(-1)[indices] = 1
    return mask*W

def prune_model(model):
    wts = ['W_Q', 'W_K', 'W_V', 'W_O', 'W_in', 'W_out']
    for name, param in gpt2.named_parameters():
        # print(f"Layer: {name}, Shape: {param.shape}")
        if name.split('.')[-1] in wts:
            if param.dim() == 3:    
                for i in range(param.shape[0]):
                    param[i] = prune_magnitude(param[i])
            else:
                param = prune_magnitude(param)
    
    return model

gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
pruned_gpt2 = prune_model(gpt2)
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)


Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer


In [80]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.56 Prob:  9.34% Token: |'m|
Top 1th token. Logit: 14.51 Prob:  8.81% Token: |'ve|
Top 2th token. Logit: 14.17 Prob:  6.31% Token: | have|
Top 3th token. Logit: 13.88 Prob:  4.74% Token: | am|
Top 4th token. Logit: 13.88 Prob:  4.69% Token: | was|
Top 5th token. Logit: 13.39 Prob:  2.88% Token: | don|
Top 6th token. Logit: 13.29 Prob:  2.61% Token: | think|
Top 7th token. Logit: 13.15 Prob:  2.27% Token: | had|
Top 8th token. Logit: 13.09 Prob:  2.14% Token: | know|
Top 9th token. Logit: 12.82 Prob:  1.64% Token: | can|


In [81]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.82 Prob: 10.75% Token: | have|
Top 1th token. Logit: 14.21 Prob:  5.82% Token: | was|
Top 2th token. Logit: 14.16 Prob:  5.53% Token: |'m|
Top 3th token. Logit: 14.06 Prob:  4.99% Token: |'ve|
Top 4th token. Logit: 13.83 Prob:  3.99% Token: | am|
Top 5th token. Logit: 13.71 Prob:  3.53% Token: | would|
Top 6th token. Logit: 13.57 Prob:  3.08% Token: |.|
Top 7th token. Logit: 13.47 Prob:  2.78% Token: | will|
Top 8th token. Logit: 13.41 Prob:  2.61% Token: | just|
Top 9th token. Logit: 13.36 Prob:  2.49% Token: | think|


In [82]:
transformer_lens.evals.sanity_check(gpt2)

tensor(3.3225, device='cuda:0')

In [83]:
transformer_lens.evals.sanity_check(pruned_gpt2)

tensor(3.8446, device='cuda:0')

In [84]:
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


In [85]:
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(example_prompt, example_answer, pruned_gpt2, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 17.59 Prob: 31.61% Token: | the|
Top 1th token. Logit: 17.12 Prob: 19.72% Token: | a|
Top 2th token. Logit: 16.68 Prob: 12.66% Token: | his|
Top 3th token. Logit: 15.49 Prob:  3.86% Token: | Mary|
Top 4th token. Logit: 15.43 Prob:  3.64% Token: | one|
Top 5th token. Logit: 15.00 Prob:  2.36% Token: | her|
Top 6th token. Logit: 14.71 Prob:  1.78% Token: | their|
Top 7th token. Logit: 14.47 Prob:  1.40% Token: | an|
Top 8th token. Logit: 14.23 Prob:  1.10% Token: | them|
Top 9th token. Logit: 14.11 Prob:  0.97% Token: | another|


#### Wanda

In [4]:
gpt2.cfg

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'num_experts': None,
 'original

In [5]:
def prune_wanda(W, X_norm, sparse_ratio=0.5):
    W_metric = W.abs() * X_norm
    _, sorted_idx = W_metric.sort(dim=1)
    pruned_idx = sorted_idx[:, :int(W.shape[1] * sparse_ratio)]
    
    W_clone = W.detach().clone()    
    W_clone.scatter_(dim=1, index=pruned_idx, src=t.zeros_like(pruned_idx, dtype=W.dtype))
    return W_clone

W = t.tensor([
    [4, 0, 1, -1],
    [3, -2, -1, -3],
    [-3, 1, 0, 2]
])
X = t.tensor([
    [1, 2, 8, 3]
])

prune_wanda(W, X, sparse_ratio=0.5)

tensor([[ 4,  0,  1,  0],
        [ 0,  0, -1, -3],
        [-3,  0,  0,  2]])

In [None]:
def prune_model(model, hooks):
    text = "A quick brown fox jumps over the lazy dog."
    tokens = model.to_tokens(text)
    logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

    for hook in hooks:
        logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
        layer = int(hook.split('.')[1])

        if 'mlp' in hook:
            if 'out' in hook:
                X = cache[hook].norm(p=2, dim=0)
                model.W_out[layer] = t.nn.Parameter(prune_wanda(model.W_out[layer], X, sparse_ratio=0.5))
            else:
                X = cache[hook]
                X = X.norm(p=2, dim=0)
                model.W_in[layer] = t.nn.Parameter(prune_wanda(model.W_in[layer], X, sparse_ratio=0.5))
        else:
            for head in range(model.cfg.n_heads):
                X = cache[hook][:, head, :]  
                X = X.norm(p=2, dim=0)

                if 'q' in hook:
                    model.W_Q[layer, head] = t.nn.Parameter(prune_wanda(model.W_Q[layer, head], X, sparse_ratio=0.5)
)
                if 'k' in hook:
                    model.W_K[layer, head] = t.nn.Parameter(prune_wanda(model.W_K[layer, head], X, sparse_ratio=0.5))
                
                if 'v' in hook:
                    model.W_V[layer, head] = t.nn.Parameter(prune_wanda(model.W_V[layer, head], X, sparse_ratio=0.5))

    return model

In [22]:
gpt2_text = "A quick brown fox jumps over the lazy dog."
gpt2_tokens = gpt2.to_tokens(gpt2_text)
gpt2_logits, gpt2_cache = gpt2.run_with_cache(gpt2_tokens, remove_batch_dim=True)

hooks = []
for k, v in gpt2_cache.cache_dict.items():
    if 'block' in k and (v.dim() == 3 or 'mlp' in k):
        if 'score' not in k and 'pattern' not in k:
            hooks.append(k)

len(hooks)
pruned_gpt2 = prune_model(gpt2, hooks)
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [23]:
t.norm(pruned_gpt2.W_Q - gpt2.W_Q)

tensor(0., device='cuda:0')

In [25]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.56 Prob:  9.34% Token: |'m|
Top 1th token. Logit: 14.51 Prob:  8.81% Token: |'ve|
Top 2th token. Logit: 14.17 Prob:  6.31% Token: | have|
Top 3th token. Logit: 13.88 Prob:  4.74% Token: | am|
Top 4th token. Logit: 13.88 Prob:  4.69% Token: | was|
Top 5th token. Logit: 13.39 Prob:  2.88% Token: | don|
Top 6th token. Logit: 13.29 Prob:  2.61% Token: | think|
Top 7th token. Logit: 13.15 Prob:  2.27% Token: | had|
Top 8th token. Logit: 13.09 Prob:  2.14% Token: | know|
Top 9th token. Logit: 12.82 Prob:  1.64% Token: | can|


In [26]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.56 Prob:  9.34% Token: |'m|
Top 1th token. Logit: 14.51 Prob:  8.81% Token: |'ve|
Top 2th token. Logit: 14.17 Prob:  6.31% Token: | have|
Top 3th token. Logit: 13.88 Prob:  4.74% Token: | am|
Top 4th token. Logit: 13.88 Prob:  4.69% Token: | was|
Top 5th token. Logit: 13.39 Prob:  2.88% Token: | don|
Top 6th token. Logit: 13.29 Prob:  2.61% Token: | think|
Top 7th token. Logit: 13.15 Prob:  2.27% Token: | had|
Top 8th token. Logit: 13.09 Prob:  2.14% Token: | know|
Top 9th token. Logit: 12.82 Prob:  1.64% Token: | can|


In [41]:
X = gpt2_cache['blocks.0.attn.hook_q']
print(X.shape)

torch.Size([11, 12, 64])


In [None]:
X = gpt2_cache['blocks.0.attn.hook_q']
for name, param in gpt2.named_parameters():
    print(name, param.shape)
    if name == 'blocks.0.attn.W_Q':
        W = param
        param = prune_wanda(W, X, sparse_ratio=0.5)
    break


embed.W_E torch.Size([50257, 768])


#### Wanda new try

In [127]:
def prune_wanda(W, X_norm, sparse_ratio=0.5):
    W_metric = W.abs() * X_norm
    _, sorted_idx = W_metric.sort(dim=1)
    pruned_idx = sorted_idx[:, :int(W.shape[1] * sparse_ratio)]
    
    W_clone = W.detach().clone()    
    W_clone.scatter_(dim=1, index=pruned_idx, src=t.zeros_like(pruned_idx, dtype=W.dtype))
    return W_clone

def prune_model(model, tokens):
    wts_act = {
    'attn.W_Q': 'attn.hook_q',
    'attn.W_K': 'attn.hook_k',
    'attn.W_V': 'attn.hook_v',
    'attn.W_O': 'hook_attn_out',
    'mlp.W_in': 'mlp.hook_pre',
    'mlp.W_out': 'hook_mlp_out'
    }
    for layer in range(model.cfg.n_layers):
        logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
        for wt, act in wts_act.items():
            W = model.get_parameter(f'blocks.{layer}.{wt}')
            X = cache[f'blocks.{layer}.{act}']

            if W.dim() == 3:
                if 'W_O' in wt:
                    X_norm = X.norm(p=2, dim=0)
                    for head in range(W.shape[0]):
                        W[head] = prune_wanda(W[head], X_norm, sparse_ratio=0.5)
                        
                else:
                    for head in range(W.shape[0]):
                        X_norm = X[:, head, :].norm(p=2, dim=0)
                        W[head] = prune_wanda(W[head], X_norm, sparse_ratio=0.5)
            else:
                X_norm = X.norm(p=2, dim=0)
                W = prune_wanda(W, X_norm, sparse_ratio=0.5)
            
    return model

gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
gpt2_text = "A quick brown fox jumps over the lazy dog."
gpt2_tokens = gpt2.to_tokens(gpt2_text)
print(gpt2_tokens.shape)
pruned_gpt2 = prune_model(gpt2, gpt2_tokens)
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer
torch.Size([1, 11])
Loaded pretrained model gpt2-small into HookedTransformer


In [130]:
t.save(pruned_gpt2.state_dict(), 'pruned_gpt2.pth')

In [132]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 15.69 Prob:  9.48% Token: | have|
Top 1th token. Logit: 15.55 Prob:  8.25% Token: | was|
Top 2th token. Logit: 15.42 Prob:  7.22% Token: |'m|
Top 3th token. Logit: 15.28 Prob:  6.32% Token: |'ve|
Top 4th token. Logit: 14.87 Prob:  4.20% Token: | am|
Top 5th token. Logit: 14.68 Prob:  3.46% Token: | think|
Top 6th token. Logit: 14.64 Prob:  3.33% Token: | don|
Top 7th token. Logit: 14.57 Prob:  3.11% Token: | had|
Top 8th token. Logit: 14.46 Prob:  2.78% Token: | know|
Top 9th token. Logit: 14.32 Prob:  2.42% Token: | can|


In [133]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, model)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 15.69 Prob:  9.48% Token: | have|
Top 1th token. Logit: 15.55 Prob:  8.25% Token: | was|
Top 2th token. Logit: 15.42 Prob:  7.22% Token: |'m|
Top 3th token. Logit: 15.28 Prob:  6.32% Token: |'ve|
Top 4th token. Logit: 14.87 Prob:  4.20% Token: | am|
Top 5th token. Logit: 14.68 Prob:  3.46% Token: | think|
Top 6th token. Logit: 14.64 Prob:  3.33% Token: | don|
Top 7th token. Logit: 14.57 Prob:  3.11% Token: | had|
Top 8th token. Logit: 14.46 Prob:  2.78% Token: | know|
Top 9th token. Logit: 14.32 Prob:  2.42% Token: | can|


In [None]:
pruned_gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
dataset = transformer_lens.utils.get_dataset('openwebtext')

class OpenWebText(t.utils.data.Dataset):
    def __init__(self, dataset, max_length=1024):
        self.dataset = dataset
        self.max_length = 1024

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        tokens = gpt2.to_tokens(text)
        tokens = tokens[:self.max_length]
        return tokens
    
openwebtext = OpenWebText(dataset)

for batch in tqdm(openwebtext):
    # print(batch.shape)
    pruned_gpt2 = prune_model(pruned_gpt2, batch)

gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


TypeError: '>' not supported between instances of 'int' and 'NoneType'

In [87]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 14.56 Prob:  9.34% Token: |'m|
Top 1th token. Logit: 14.51 Prob:  8.81% Token: |'ve|
Top 2th token. Logit: 14.17 Prob:  6.31% Token: | have|
Top 3th token. Logit: 13.88 Prob:  4.74% Token: | am|
Top 4th token. Logit: 13.88 Prob:  4.69% Token: | was|
Top 5th token. Logit: 13.39 Prob:  2.88% Token: | don|
Top 6th token. Logit: 13.29 Prob:  2.61% Token: | think|
Top 7th token. Logit: 13.15 Prob:  2.27% Token: | had|
Top 8th token. Logit: 13.09 Prob:  2.14% Token: | know|
Top 9th token. Logit: 12.82 Prob:  1.64% Token: | can|


In [88]:
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

Tokenized prompt: ['<|endoftext|>', 'I']
Tokenized answer: [' am']


Top 0th token. Logit: 15.69 Prob:  9.48% Token: | have|
Top 1th token. Logit: 15.55 Prob:  8.25% Token: | was|
Top 2th token. Logit: 15.42 Prob:  7.22% Token: |'m|
Top 3th token. Logit: 15.28 Prob:  6.32% Token: |'ve|
Top 4th token. Logit: 14.87 Prob:  4.20% Token: | am|
Top 5th token. Logit: 14.68 Prob:  3.46% Token: | think|
Top 6th token. Logit: 14.64 Prob:  3.33% Token: | don|
Top 7th token. Logit: 14.57 Prob:  3.11% Token: | had|
Top 8th token. Logit: 14.46 Prob:  2.78% Token: | know|
Top 9th token. Logit: 14.32 Prob:  2.42% Token: | can|


In [89]:
transformer_lens.evals.sanity_check(gpt2)

tensor(3.3225, device='cuda:0')

In [90]:
transformer_lens.evals.sanity_check(pruned_gpt2)

tensor(3.6407, device='cuda:0')

In [91]:
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


In [92]:
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(example_prompt, example_answer, pruned_gpt2, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 17.87 Prob: 42.15% Token: | Mary|
Top 1th token. Logit: 16.95 Prob: 16.90% Token: | the|
Top 2th token. Logit: 15.95 Prob:  6.18% Token: | John|
Top 3th token. Logit: 15.79 Prob:  5.29% Token: | his|
Top 4th token. Logit: 15.70 Prob:  4.82% Token: | them|
Top 5th token. Logit: 15.56 Prob:  4.18% Token: | a|
Top 6th token. Logit: 15.49 Prob:  3.92% Token: | her|
Top 7th token. Logit: 14.36 Prob:  1.27% Token: | him|
Top 8th token. Logit: 14.31 Prob:  1.20% Token: | one|
Top 9th token. Logit: 14.26 Prob:  1.14% Token: | their|


## Gemma2

Issue: Pretrained SAE is not available for the model - for attention

In [None]:
gemma2b: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer
