In [1]:
import sae_lens
import torch
import jaxtyping
import random
import datasets
import plotly.colors as pc
import plotly.express as px
import seaborn as sns
import numpy as np
import pandas as pd
from typing import List, Tuple
from tqdm import tqdm

def obtain_data() -> (
    Tuple[List[sae_lens.SAE], torch.nn.Module, torch.utils.data.Dataset]
):
    """
    load sae, model and dataset
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layers = 12
    saes = []
    from sae_lens import SAE

    release = "gpt2-small-res-jb"

    model_name = "gpt2-small"
    for layer in tqdm(range(layers)):
        sae_id = f"blocks.{layer}.hook_resid_pre"
        saes.append(
            sae_lens.SAE.from_pretrained(release=release, sae_id=sae_id, device=device)[
                0
            ]
        )

    model = sae_lens.HookedSAETransformer.from_pretrained(model_name)
    ds = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")["train"]

    return saes, model, ds

saes, model, ds = obtain_data()


  from .autonotebook import tqdm as notebook_tqdm
This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)
100%|██████████| 12/12 [00:51<00:00,  4.32s/it]


Loaded pretrained model gpt2-small into HookedTransformer


In [2]:
math_dataset = datasets.load_dataset("TIGER-Lab/MathInstruct")
CodeXGlue_dataset = datasets.load_dataset("BAAI/TACO")["train"]

Generating train split: 100%|██████████| 262039/262039 [00:01<00:00, 260029.59 examples/s]


In [5]:
math_dataset[0]

KeyError: "Invalid key: 0. Please first select a split. For example: `my_dataset_dictionary['train'][0]`. Available splits: ['train']"

In [10]:
nz_all = []
doc_len = 0
layers = 12
device = "cuda" if torch.cuda.is_available() else "cpu"
freqs = torch.zeros(layers, saes[0].cfg.d_sae).to(device)
abl_layer = 0
abl_times = 20
length_ds = 100
ds_ratio = 1e-2
length_ds = int(len(ds) * ds_ratio)
for idx in tqdm(range(length_ds)):
    example = ds[idx]
    tokens = model.to_tokens([example["text"]], prepend_bos=True)
    _, cache1 = model.run_with_cache_with_saes(tokens, saes=saes, use_error_term=False)
    local_doc_len = cache1[f"blocks.0.hook_resid_pre.hook_sae_acts_post"].shape[1]
    freq = torch.zeros_like(freqs)
    new_doc_len = doc_len + local_doc_len
    for layer in range(layers):
        prompt2 = f"blocks.{layer}.hook_resid_pre.hook_sae_acts_post"
        freq[layer] = (cache1[prompt2] > 1e-3)[0].sum(0) / local_doc_len    
    if idx == 0:
        freqs = freq
    else:
        freqs = (
            freqs * doc_len / new_doc_len
            + freq * local_doc_len / new_doc_len
        )
    doc_len = new_doc_len

100%|██████████| 367/367 [00:14<00:00, 25.39it/s]


In [12]:
torch.save(freqs, "../res/acts/wiki-gpt2-small-res-all12-acts.pt")

In [13]:
nz_all = []
doc_len = 0
layers = 12
device = "cuda" if torch.cuda.is_available() else "cpu"
freqs = torch.zeros(layers, saes[0].cfg.d_sae).to(device)
abl_layer = 0
abl_times = 20
length_ds = 100
ds_ratio = 1e-2
length_ds = int(len(CodeXGlue_dataset) * ds_ratio)
for idx in tqdm(range(length_ds)):
    example = CodeXGlue_dataset[idx]
    tokens = model.to_tokens([example["solutions"]], prepend_bos=True)
    _, cache1 = model.run_with_cache_with_saes(tokens, saes=saes, use_error_term=False)
    local_doc_len = cache1[f"blocks.0.hook_resid_pre.hook_sae_acts_post"].shape[1]
    freq = torch.zeros_like(freqs)
    new_doc_len = doc_len + local_doc_len
    for layer in range(layers-1):
        prompt2 = f"blocks.{layer}.hook_resid_pre.hook_sae_acts_post"
        freq[layer] = (cache1[prompt2] > 1e-3)[0].sum(0) / local_doc_len    
    if idx == 0:
        freqs = freq
    else:
        freqs = (
            freqs * doc_len / new_doc_len
            + freq * local_doc_len / new_doc_len
        )
    doc_len = new_doc_len

100%|██████████| 254/254 [00:16<00:00, 15.30it/s]


In [14]:
torch.save(freqs, "../res/acts/code-gpt2-small-res-all12-acts.pt")

In [18]:
math_dataset['train'][123]

{'source': 'data/PoT/mathqa.json',
 'output': 'n0 = 2.3\nn1 = 60.0\nn2 = 3.0\nn3 = 75.0\nt0 = n1 / 2.0\nt1 = n3 - n1\nt2 = t0 / t1\nanswer = n2 + t2\nprint(answer)',
 'instruction': 'a thief steals a car at 2.30 pm and drives it at 60 kmph . the theft is discovered at 3 pm and the owner sets off in another car at 75 kmph when will he overtake the thief ? Please respond by writing a program in Python.'}

In [21]:
nz_all = []
doc_len = 0
layers = 12
device = "cuda" if torch.cuda.is_available() else "cpu"
freqs = torch.zeros(layers, saes[0].cfg.d_sae).to(device)
abl_layer = 0
abl_times = 20
length_ds = 100
ds_ratio = 1e-3
length_ds = int(len(math_dataset['train']) * ds_ratio)
for idx in tqdm(range(length_ds)):
    example = math_dataset['train'][idx]
    tokens = model.to_tokens([example["output"]], prepend_bos=True)
    _, cache1 = model.run_with_cache_with_saes(tokens, saes=saes, use_error_term=False)
    local_doc_len = cache1[f"blocks.0.hook_resid_pre.hook_sae_acts_post"].shape[1]
    freq = torch.zeros_like(freqs)
    new_doc_len = doc_len + local_doc_len
    for layer in range(layers-1):
        prompt2 = f"blocks.{layer}.hook_resid_pre.hook_sae_acts_post"
        freq[layer] = (cache1[prompt2] > 1e-3)[0].sum(0) / local_doc_len    
    if idx == 0:
        freqs = freq
    else:
        freqs = (
            freqs * doc_len / new_doc_len
            + freq * local_doc_len / new_doc_len
        )
    doc_len = new_doc_len

100%|██████████| 262/262 [00:10<00:00, 23.90it/s]


In [22]:
torch.save(freqs, "../res/acts/math-gpt2-small-res-all12-acts.pt")