In [1]:
import sae_lens
import torch
import jaxtyping
import random
import datasets
import plotly.colors as pc
import plotly.express as px
import seaborn as sns
import numpy as np
import pandas as pd
from typing import List, Tuple
from tqdm import tqdm

def obtain_data() -> (
    Tuple[List[sae_lens.SAE], torch.nn.Module, torch.utils.data.Dataset]
):
    """
    load sae, model and dataset
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layers = 6
    saes = []
    release = "pythia-70m-deduped-res-sm"
    model_name = "pythia-70m-deduped"
    for layer in tqdm(range(layers)):
        sae_id = f"blocks.{layer}.hook_resid_post"
        saes.append(
            sae_lens.SAE.from_pretrained(release=release, sae_id=sae_id, device=device)[
                0
            ]
        )

    model = sae_lens.HookedSAETransformer.from_pretrained(model_name)
    ds = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")["train"]

    return saes, model, ds

saes, model, ds = obtain_data()

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 6/6 [00:03<00:00,  1.92it/s]
The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [2]:
code_acts = torch.load("../res/acts/BAAI-TACO-pythia70m-res-all6-acts.pt")
math_acts = torch.load("../res/acts/math-pythia70m-res-all6-acts.pt")
wiki_acts = torch.load("../res/acts/wiki-pythia70m-res-all6-acts.pt")

top_num = 325
top_index_code = torch.topk(code_acts, top_num).indices

  code_acts = torch.load("../res/acts/BAAI-TACO-pythia70m-res-all6-acts.pt")
  math_acts = torch.load("../res/acts/math-pythia70m-res-all6-acts.pt")
  wiki_acts = torch.load("../res/acts/wiki-pythia70m-res-all6-acts.pt")


In [3]:
len(ds)

36718

In [4]:
import copy
nz_all = []
doc_len = 0
freq_mean_global = 0
layers = 6
device = "cuda" if torch.cuda.is_available() else "cpu"
freqs = torch.zeros(saes[0].cfg.d_sae).to(device)
abl_layer = 0
abl_times = 10
for layer in range(layers-1):
    nz_freqs = []
    abl_layer = layer
    for idx in tqdm(range(abl_times)):
        saes2 = copy.deepcopy(saes)
        with torch.no_grad():
            abl_num = 29
            list(map(lambda idy: saes2[abl_layer].W_dec[idy, :].zero_(), top_index_code[0][abl_num*idx:abl_num*(idx+1)]))
        ds_ratio = 1e-3
        length_ds = int(len(ds) * ds_ratio)
        for idx in range(length_ds):
            # loop begin, fuck indent
            example = ds[idx]
            tokens = model.to_tokens([example["text"]], prepend_bos=True)
            _, cache1 = model.run_with_cache_with_saes(tokens, saes=saes, use_error_term=False)
            model.reset_saes()
            _, cache2 = model.run_with_cache_with_saes(tokens, saes=saes2, use_error_term=False)
            local_doc_len = cache1["blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
            freq = torch.zeros_like(freqs)
            
            prompt2 = f"blocks.{abl_layer + 1}.hook_resid_post.hook_sae_acts_post"
            freq = (((cache1[prompt2] > 1e-3) + 0 + cache2[prompt2] > 1e-3)==1)[0].sum(
                0
            ) / local_doc_len
            # freq[layer] = (cache[prompt2] > 1e-3)[0].sum(0) / local_doc_len
            new_doc_len = doc_len + local_doc_len
            if idx == 0:
                freq_mean_global = freq
            else:
                freq_mean_global = (
                    freq_mean_global * doc_len / new_doc_len
                    + freq * local_doc_len / new_doc_len
                )
            doc_len = new_doc_len
        nz_freqs.append(freq_mean_global)
    nz_all.append(nz_freqs)

100%|██████████| 10/10 [00:14<00:00,  1.47s/it]
100%|██████████| 10/10 [00:13<00:00,  1.36s/it]
100%|██████████| 10/10 [00:13<00:00,  1.33s/it]
100%|██████████| 10/10 [00:12<00:00,  1.21s/it]
100%|██████████| 10/10 [00:12<00:00,  1.21s/it]


In [19]:
res_ = [[], [], [], [], []]

for layer in range(layers-1):
    for idx in range(len(nz_freqs)):
        res = np.intersect1d(top_index_code[layer+1].cpu().numpy(), (nz_all[layer][idx]>1e-3).nonzero().view(-1).cpu().numpy())
        print(f'common iou: {res.shape} abl_layer: {layer} influence freq: {nz_all[layer][idx].nonzero().view(-1).shape}')
        res_[layer].append(res)
last_iou = res_[4][0]
for idx in range(len(res_[4]) - 1):
    iou = np.intersect1d(last_iou, res_[4][idx])
    last_iou = iou
    print(f"iou with last abl: {iou.shape}")

common iou: (293,) abl_layer: 0 influence freq: torch.Size([8181])
common iou: (287,) abl_layer: 0 influence freq: torch.Size([8297])
common iou: (278,) abl_layer: 0 influence freq: torch.Size([8135])
common iou: (273,) abl_layer: 0 influence freq: torch.Size([8203])
common iou: (265,) abl_layer: 0 influence freq: torch.Size([8099])
common iou: (258,) abl_layer: 0 influence freq: torch.Size([8169])
common iou: (250,) abl_layer: 0 influence freq: torch.Size([8135])
common iou: (246,) abl_layer: 0 influence freq: torch.Size([8132])
common iou: (238,) abl_layer: 0 influence freq: torch.Size([8145])
common iou: (232,) abl_layer: 0 influence freq: torch.Size([8113])
common iou: (286,) abl_layer: 1 influence freq: torch.Size([6998])
common iou: (282,) abl_layer: 1 influence freq: torch.Size([7003])
common iou: (280,) abl_layer: 1 influence freq: torch.Size([6999])
common iou: (280,) abl_layer: 1 influence freq: torch.Size([7004])
common iou: (279,) abl_layer: 1 influence freq: torch.Size([69

In [13]:
res_ = [[], [], [], [], []]
for layer in range(layers-1):
    for idx in range(len(nz_freqs)):
        res = np.intersect1d(top_index_code[layer+1].cpu().numpy(), nz_all[layer][idx].nonzero().view(-1).cpu().numpy())
        print(f'common iou: {res.shape} abl_layer: {layer} influence freq: {nz_all[layer][idx].nonzero().view(-1).shape}')
        res_[layer].append(res)
last_iou = res_[4][0]
for idx in range(len(res_) - 1):
    iou = np.intersect1d(last_iou, res_[4][idx])
    last_iou = iou
    print(f"iou with last abl: {iou.shape}")

common iou: (312,) abl_layer: 0 influence freq: torch.Size([8181])
common iou: (312,) abl_layer: 0 influence freq: torch.Size([8297])
common iou: (312,) abl_layer: 0 influence freq: torch.Size([8135])
common iou: (312,) abl_layer: 0 influence freq: torch.Size([8203])
common iou: (312,) abl_layer: 0 influence freq: torch.Size([8099])
common iou: (312,) abl_layer: 0 influence freq: torch.Size([8169])
common iou: (312,) abl_layer: 0 influence freq: torch.Size([8135])
common iou: (313,) abl_layer: 0 influence freq: torch.Size([8132])
common iou: (312,) abl_layer: 0 influence freq: torch.Size([8145])
common iou: (313,) abl_layer: 0 influence freq: torch.Size([8113])
common iou: (321,) abl_layer: 1 influence freq: torch.Size([6998])
common iou: (321,) abl_layer: 1 influence freq: torch.Size([7003])
common iou: (321,) abl_layer: 1 influence freq: torch.Size([6999])
common iou: (321,) abl_layer: 1 influence freq: torch.Size([7004])
common iou: (321,) abl_layer: 1 influence freq: torch.Size([69

In [7]:
cache1[prompt2][0].nonzero()

tensor([[    0,   899],
        [    1,  1333],
        [    1,  2653],
        ...,
        [   79, 26770],
        [   79, 28835],
        [   79, 31976]], device='cuda:0')

In [8]:
cache2[prompt2][0].nonzero()

tensor([[    0,   899],
        [    1,  1333],
        [    1,  2653],
        ...,
        [   79, 26770],
        [   79, 28835],
        [   79, 31976]], device='cuda:0')

In [9]:
freq = (((cache1[prompt2] > 1e-3) + 0 + cache2[prompt2] > 1e-3)==1)[0].sum(
                0
            ) / local_doc_len
(freq>1e-3).nonzero().shape

torch.Size([4138, 1])