In [1]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias


import numpy as np
import pandas as pd
import torch as t
from datasets import load_dataset
import transformer_lens
import sae_lens

import einops
import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from tabulate import tabulate
from tqdm import tqdm

Issues:

* Pretrained SAEs are not available for attention.

In [2]:
# Baseline (unpruned) Gemma-2-2B, wrapped as a HookedSAETransformer and placed on the first GPU.
gemma2b: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained(
    "gemma-2-2b",
    device="cuda:0",
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [3]:
# Put the baseline model in evaluation mode. Because this call is the cell's last expression,
# its return value (the module itself) is echoed, which is what prints the architecture
# summary shown in the cell output.
gemma2b.eval()

HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-25): 26 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
      

In [4]:
# Quick sanity check of the baseline model: for each (prompt, expected next token) pair,
# print the tokenization and the model's top next-token predictions.
# The loop variables keep the names `prompt` / `answer`, so after the cell runs they hold
# the last pair, exactly as when the two checks were written out one after the other.
for prompt, answer in (
    ("Mitigating the risk of extinction from AI should be a global", " priority"),
    ("Mary and John went to the park to play. Mary gave the ball to", " John"),
):
    transformer_lens.utils.test_prompt(prompt, answer, gemma2b)

Tokenized prompt: ['<bos>', 'M', 'itig', 'ating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit:  8.54 Prob: 96.44% Token: | priority|
Top 1th token. Logit:  4.57 Prob:  1.81% Token: | imperative|
Top 2th token. Logit:  3.13 Prob:  0.43% Token: | challenge|
Top 3th token. Logit:  2.82 Prob:  0.31% Token: | goal|
Top 4th token. Logit:  2.59 Prob:  0.25% Token: | effort|
Top 5th token. Logit:  1.91 Prob:  0.13% Token: | concern|
Top 6th token. Logit:  1.21 Prob:  0.06% Token: | security|
Top 7th token. Logit:  1.21 Prob:  0.06% Token: | public|
Top 8th token. Logit:  0.92 Prob:  0.05% Token: |,|
Top 9th token. Logit:  0.84 Prob:  0.04% Token: | activity|


Tokenized prompt: ['<bos>', 'Mary', ' and', ' John', ' went', ' to', ' the', ' park', ' to', ' play', '.', ' Mary', ' gave', ' the', ' ball', ' to']
Tokenized answer: [' John']


Top 0th token. Logit: 27.41 Prob: 85.00% Token: | John|
Top 1th token. Logit: 24.90 Prob:  6.86% Token: | her|
Top 2th token. Logit: 23.82 Prob:  2.34% Token: | the|
Top 3th token. Logit: 23.37 Prob:  1.48% Token: | john|
Top 4th token. Logit: 22.98 Prob:  1.01% Token: | a|
Top 5th token. Logit: 22.09 Prob:  0.42% Token: | Mary|
Top 6th token. Logit: 21.95 Prob:  0.36% Token: | Johnny|
Top 7th token. Logit: 21.60 Prob:  0.25% Token: | him|
Top 8th token. Logit: 21.35 Prob:  0.20% Token: | Jack|
Top 9th token. Logit: 20.59 Prob:  0.09% Token: | Joe|


In [5]:
# Sanity-check the Wanda-pruned checkpoint on the same two prompts used for the baseline model.
# A fresh Gemma-2-2B is instantiated on the second GPU and its weights are then overwritten
# with the pruned state dict.
pruned_gemma2b: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gemma-2-2b", device="cuda:1")

# map_location="cpu": without it, torch.load restores every tensor onto the device it was
#   saved from (possibly cuda:0, which already hosts the baseline model). load_state_dict
#   then copies the values into the existing parameters on cuda:1.
# weights_only=True: restricts unpickling to tensors and plain containers, so a tampered
#   .pth file cannot execute arbitrary code on load.
pruned_state_dict = t.load("pruned/gemma2b_wanda.pth", map_location="cpu", weights_only=True)
pruned_gemma2b.load_state_dict(pruned_state_dict)
del pruned_state_dict  # the CPU copy is no longer needed once the weights are in the model

# Loop variables keep the names `prompt` / `answer`, so they end up holding the last pair.
for prompt, answer in (
    ("Mitigating the risk of extinction from AI should be a global", " priority"),
    ("Mary and John went to the park to play. Mary gave the ball to", " John"),
):
    transformer_lens.utils.test_prompt(prompt, answer, pruned_gemma2b)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer
Tokenized prompt: ['<bos>', 'M', 'itig', 'ating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit: 24.92 Prob: 43.13% Token: | priority|
Top 1th token. Logit: 23.44 Prob:  9.77% Token: | effort|
Top 2th token. Logit: 23.02 Prob:  6.44% Token: | goal|
Top 3th token. Logit: 22.70 Prob:  4.68% Token: | concern|
Top 4th token. Logit: 22.47 Prob:  3.71% Token: | responsibility|
Top 5th token. Logit: 22.26 Prob:  3.02% Token: | endeavor|
Top 6th token. Logit: 22.21 Prob:  2.86% Token: |,|
Top 7th token. Logit: 22.08 Prob:  2.51% Token: | public|
Top 8th token. Logit: 21.74 Prob:  1.79% Token: | imperative|
Top 9th token. Logit: 21.48 Prob:  1.38% Token: | focus|


Tokenized prompt: ['<bos>', 'Mary', ' and', ' John', ' went', ' to', ' the', ' park', ' to', ' play', '.', ' Mary', ' gave', ' the', ' ball', ' to']
Tokenized answer: [' John']


Top 0th token. Logit: 26.82 Prob: 77.71% Token: | John|
Top 1th token. Logit: 24.56 Prob:  8.12% Token: | her|
Top 2th token. Logit: 24.33 Prob:  6.44% Token: | the|
Top 3th token. Logit: 23.09 Prob:  1.87% Token: | a|
Top 4th token. Logit: 22.24 Prob:  0.80% Token: | john|
Top 5th token. Logit: 21.89 Prob:  0.56% Token: | Johnny|
Top 6th token. Logit: 21.62 Prob:  0.43% Token: | him|
Top 7th token. Logit: 21.62 Prob:  0.43% Token: | Mary|
Top 8th token. Logit: 21.49 Prob:  0.38% Token: | |
Top 9th token. Logit: 21.37 Prob:  0.33% Token: | one|


In [6]:
# Sanity-check the magnitude-pruned checkpoint on the same two prompts used for the baseline model.

# Drop this notebook's reference to any previously loaded pruned model *before* building the
# next one. Otherwise the old model stays alive on cuda:1 until the assignment below
# completes, so two full copies of the model occupy that GPU at once.
# (Memory is only reclaimed if nothing else still references the old model.)
pruned_gemma2b = None
gc.collect()
t.cuda.empty_cache()

pruned_gemma2b: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gemma-2-2b", device="cuda:1")

# map_location="cpu": without it, torch.load restores every tensor onto the device it was
#   saved from (possibly cuda:0, which already hosts the baseline model). load_state_dict
#   then copies the values into the existing parameters on cuda:1.
# weights_only=True: restricts unpickling to tensors and plain containers, so a tampered
#   .pth file cannot execute arbitrary code on load.
pruned_state_dict = t.load("pruned/gemma2b_magnitude.pth", map_location="cpu", weights_only=True)
pruned_gemma2b.load_state_dict(pruned_state_dict)
del pruned_state_dict  # the CPU copy is no longer needed once the weights are in the model

# Loop variables keep the names `prompt` / `answer`, so they end up holding the last pair.
for prompt, answer in (
    ("Mitigating the risk of extinction from AI should be a global", " priority"),
    ("Mary and John went to the park to play. Mary gave the ball to", " John"),
):
    transformer_lens.utils.test_prompt(prompt, answer, pruned_gemma2b)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer
Tokenized prompt: ['<bos>', 'M', 'itig', 'ating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit: 17.32 Prob: 67.70% Token: | priority|
Top 1th token. Logit: 14.93 Prob:  6.16% Token: | effort|
Top 2th token. Logit: 14.22 Prob:  3.05% Token: | goal|
Top 3th token. Logit: 14.07 Prob:  2.61% Token: | public|
Top 4th token. Logit: 13.77 Prob:  1.95% Token: | concern|
Top 5th token. Logit: 13.75 Prob:  1.90% Token: | imperative|
Top 6th token. Logit: 13.44 Prob:  1.39% Token: | responsibility|
Top 7th token. Logit: 13.44 Prob:  1.39% Token: | challenge|
Top 8th token. Logit: 13.21 Prob:  1.11% Token: |,|
Top 9th token. Logit: 12.99 Prob:  0.88% Token: | undertaking|


Tokenized prompt: ['<bos>', 'Mary', ' and', ' John', ' went', ' to', ' the', ' park', ' to', ' play', '.', ' Mary', ' gave', ' the', ' ball', ' to']
Tokenized answer: [' John']


Top 0th token. Logit: 26.79 Prob: 78.71% Token: | John|
Top 1th token. Logit: 24.55 Prob:  8.42% Token: | her|
Top 2th token. Logit: 23.89 Prob:  4.33% Token: | the|
Top 3th token. Logit: 23.19 Prob:  2.16% Token: | john|
Top 4th token. Logit: 22.51 Prob:  1.09% Token: | a|
Top 5th token. Logit: 21.97 Prob:  0.64% Token: | Mary|
Top 6th token. Logit: 21.70 Prob:  0.49% Token: | him|
Top 7th token. Logit: 21.56 Prob:  0.42% Token: | Johnny|
Top 8th token. Logit: 21.10 Prob:  0.27% Token: | Jack|
Top 9th token. Logit: 20.88 Prob:  0.21% Token: | me|


In [1]:
import sae_lens
from sae_lens import SAE

# Identify the pretrained SAE to fetch: a release name plus the id of one SAE inside it.
# Per the config displayed in the next cell, this SAE targets model 'gemma-2-2b' at hook
# 'blocks.10.hook_resid_post' with d_sae = 16384.
release, sae_id = "gemma-scope-2b-pt-res", "layer_10/width_16k/average_l0_21"

# from_pretrained returns a 3-tuple here: the SAE, its config dict, and a sparsity value.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release,
    sae_id,
)

In [2]:
# Display the config dict returned alongside the SAE (bare last expression, so it is rendered as the cell output).
cfg_dict

{'architecture': 'jumprelu',
 'd_in': 2304,
 'd_sae': 16384,
 'dtype': 'float32',
 'model_name': 'gemma-2-2b',
 'hook_name': 'blocks.10.hook_resid_post',
 'hook_layer': 10,
 'hook_head_index': None,
 'activation_fn_str': 'relu',
 'finetuning_scaling_factor': False,
 'sae_lens_training_version': None,
 'prepend_bos': True,
 'dataset_path': 'monology/pile-uncopyrighted',
 'context_size': 1024,
 'dataset_trust_remote_code': True,
 'apply_b_dec_to_input': False,
 'normalize_activations': None,
 'device': 'cpu',
 'neuronpedia_id': None}