In [1]:
import torch

In [2]:
grad_enabled = torch.is_grad_enabled()
print(f"Torch gradient enabled: {grad_enabled}")

Torch gradient enabled: True


In [3]:
import sae_lens
import torch
import datasets
from typing import List, Tuple
from tqdm import tqdm
def fetch_sae_and_model(
    sae_name: str,
) -> Tuple[List[sae_lens.SAE], sae_lens.HookedSAETransformer]:
    """ fetch the specified SAE and model given the SAE name

    Args:
        sae_name (str): the name of the SAE

    Returns:
        Tuple[List[sae_lens.SAE], sae_lens.HookedSAETransformer]: the SAEs and the model
    """    
    if sae_name == "llama3.1-8b":
        model_name = "meta-llama/Llama-3.1-8B"
        layers = 32
        release = "llama_scope_lxr_8x"
    elif sae_name == "pythia-70m-deduped":
        model_name = "EleutherAI/pythia-70m-deduped"
        layers = 6
        release = "pythia-70m-deduped-res-sm"
    elif sae_name == "gemma-2-2b":
        model_name = "gemma-2-2b"
        layers = 26
        release = "gemma-scope-2b-pt-res-canonical"
    saes = fetch_sae(release, layers)
    model = fetch_model(model_name)
    return saes, model


def fetch_sae(release: str, layers: int) -> sae_lens.SAE:
    saes = []
    for layer in tqdm(range(layers)):
        if release == "gemma-scope-2b-pt-res-canonical":
            sae_id = f"layer_{layer}/width_16k/canonical"
        elif release == "llama_scope_lxr_8x":
            sae_id = f"l{layer}r_8x"
        elif release == "pythia-70m-deduped-res-sm":
            sae_id = f"blocks.{layer}.hook_resid_post"
        sae = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
        sae.to(dtype=torch.bfloat16)
        saes.append(sae)
    return saes


def fetch_model(model_name: str) -> sae_lens.HookedSAETransformer:
    model = sae_lens.HookedSAETransformer.from_pretrained(
        model_name, dtype=torch.bfloat16
    )
    return model

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
saes, model = fetch_sae_and_model("pythia-70m-deduped")

100%|██████████| 6/6 [00:03<00:00,  1.85it/s]
The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [5]:
dataset = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")[
                "train"
            ]
text = "text"


In [6]:
ds_ratio = 1e-3
dataset_length = int(len(dataset) * ds_ratio)
print(f"Dataset length: {dataset_length}")

Dataset length: 36


In [77]:
for idy in range(10):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    # random_indices = torch.randint(0, 32577, (900,))
    # list(map(lambda idx: sae2.W_dec[idx, :].zero_(), random_indices))
    sae2.W_dec[idy*10:(idy+1)*10, :].zero_()
    cache = {}
    for idx in tqdm(range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:5]
    for key, _ in top_10_keys:
        print(f"Key: {key}, diff: {cache[key]}")
    

100%|██████████| 36/36 [00:02<00:00, 16.05it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 6.11964750289917
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 6.11964750289917
Key: blocks.1.ln1.hook_normalized, diff: 6.041692211626593e-14
Key: blocks.1.ln2.hook_normalized, diff: 6.041692211626593e-14
Key: blocks.2.ln1.hook_normalized, diff: 1.6854300547981642e-14


100%|██████████| 36/36 [00:02<00:00, 14.34it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.23993034660816193
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.23993034660816193
Key: blocks.1.ln1.hook_normalized, diff: 2.9389699106042695e-15
Key: blocks.1.ln2.hook_normalized, diff: 2.9389699106042695e-15
Key: blocks.2.ln1.hook_normalized, diff: 7.325880493497261e-16


100%|██████████| 36/36 [00:02<00:00, 15.99it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.0002373579773120582
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.00023735793365631253
Key: blocks.1.ln1.hook_normalized, diff: 3.4758309333979514e-16
Key: blocks.1.ln2.hook_normalized, diff: 3.4758309333979514e-16
Key: blocks.0.hook_resid_post.hook_sae_output, diff: 3.898777512670277e-17


100%|██████████| 36/36 [00:02<00:00, 16.23it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 2.27205228805542
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 2.27205228805542
Key: blocks.1.ln1.hook_normalized, diff: 1.3266534951757863e-13
Key: blocks.1.ln2.hook_normalized, diff: 1.3266534951757863e-13
Key: blocks.2.ln1.hook_normalized, diff: 2.423267241437408e-14


100%|██████████| 36/36 [00:02<00:00, 15.60it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.8473105430603027
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.8473105430603027
Key: blocks.1.ln1.hook_normalized, diff: 1.249702754203396e-14
Key: blocks.1.ln2.hook_normalized, diff: 1.249702754203396e-14
Key: blocks.2.ln1.hook_normalized, diff: 3.2475234727400402e-15


100%|██████████| 36/36 [00:02<00:00, 15.92it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 5.611774921417236
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 5.611774921417236
Key: blocks.1.ln1.hook_normalized, diff: 2.869266849396708e-14
Key: blocks.1.ln2.hook_normalized, diff: 2.869266849396708e-14
Key: blocks.2.ln1.hook_normalized, diff: 5.92818412157377e-15


100%|██████████| 36/36 [00:02<00:00, 15.75it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.8169901371002197
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.8169901371002197
Key: blocks.1.ln1.hook_normalized, diff: 4.137549948924485e-14
Key: blocks.1.ln2.hook_normalized, diff: 4.137549948924485e-14
Key: blocks.2.ln1.hook_normalized, diff: 8.52615572003064e-15


100%|██████████| 36/36 [00:02<00:00, 15.75it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 2.3215298652648926
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 2.3215298652648926
Key: blocks.1.ln1.hook_normalized, diff: 2.10092394056869e-14
Key: blocks.1.ln2.hook_normalized, diff: 2.10092394056869e-14
Key: blocks.2.ln1.hook_normalized, diff: 5.200932220973068e-15


100%|██████████| 36/36 [00:03<00:00, 10.72it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.45195791125297546
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.45195791125297546
Key: blocks.1.ln1.hook_normalized, diff: 1.6388720418193843e-14
Key: blocks.1.ln2.hook_normalized, diff: 1.6388720418193843e-14
Key: blocks.2.ln1.hook_normalized, diff: 2.583527780882053e-15


100%|██████████| 36/36 [00:02<00:00, 15.38it/s]

Key: blocks.5.attn.hook_q, diff: 48896.0
Key: blocks.5.attn.hook_rot_q, diff: 48896.0
Key: blocks.4.attn.hook_q, diff: 16128.0
Key: blocks.4.attn.hook_rot_q, diff: 16128.0
Key: blocks.3.attn.hook_q, diff: 57.0





In [78]:
for idy in range(1):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    # random_indices = torch.randint(0, 32577, (900,))
    # list(map(lambda idx: sae2.W_dec[idx, :].zero_(), random_indices))
    sae2.W_dec[0:100, :].zero_()
    cache = {}
    for idx in tqdm(range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:5]
    for key, _ in top_10_keys:
        print(f"Key: {key}, diff: {cache[key]}")
    

100%|██████████| 36/36 [00:02<00:00, 15.94it/s]

Key: blocks.5.attn.hook_q, diff: 48896.0
Key: blocks.5.attn.hook_rot_q, diff: 48896.0
Key: blocks.4.attn.hook_q, diff: 16128.0
Key: blocks.4.attn.hook_rot_q, diff: 16128.0
Key: blocks.3.attn.hook_q, diff: 57.0





# TODO:
不重要的消融只在局部有微小的影响
重要的消融会直接影响最后几层

In [82]:
for idy in range(1):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    # random_indices = torch.randint(0, 32577, (900,))
    # list(map(lambda idx: sae2.W_dec[idx, :].zero_(), random_indices))
    sae2.W_dec[0:50, :].zero_()
    cache = {}
    for idx in tqdm(range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:50]
    for key, _ in top_10_keys:
        print(f"Key: {key}, diff: {cache[key]}")
    

100%|██████████| 36/36 [00:02<00:00, 15.81it/s]

Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 9.481203079223633
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 9.481203079223633
Key: blocks.1.ln1.hook_normalized, diff: 1.9212249521317892e-13
Key: blocks.1.ln2.hook_normalized, diff: 1.9212249521317892e-13
Key: blocks.2.ln1.hook_normalized, diff: 4.2826673603930596e-14
Key: blocks.2.ln2.hook_normalized, diff: 4.2826673603930596e-14
Key: blocks.0.hook_resid_post.hook_sae_output, diff: 1.9265278188387337e-14
Key: blocks.1.hook_resid_pre, diff: 1.9265278188387337e-14
Key: blocks.3.ln1.hook_normalized, diff: 9.005632272351093e-15
Key: blocks.3.ln2.hook_normalized, diff: 9.005632272351093e-15
Key: blocks.1.hook_resid_post, diff: 7.984138490049666e-15
Key: blocks.2.hook_resid_pre, diff: 7.984138490049666e-15
Key: blocks.2.hook_resid_post, diff: 2.966923691929556e-15
Key: blocks.3.hook_resid_pre, diff: 2.966923691929556e-15
Key: blocks.4.ln1.hook_normalized, diff: 2.155272157915015e-15
Key: blocks.4.ln2.hook_normalized, diff: 2.15




In [88]:
high_diff_keys = []
low_diff_keys = []
for idy in tqdm(range(30)):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    # random_indices = torch.randint(0, 32577, (900,))
    # list(map(lambda idx: sae2.W_dec[idx, :].zero_(), random_indices))
    sae2.W_dec[idy*20:(idy+1)*20, :].zero_()
    cache = {}
    for idx in (range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
            
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:5]
    flag = False
    for key, _ in top_10_keys:
        if key.startswith("blocks.5"):
            high_diff_keys.append(idy)
            flag = True
            break
    if not flag:
        low_diff_keys.append(idy)

100%|██████████| 30/30 [01:30<00:00,  3.01s/it]


In [86]:
acts = torch.load("../res/acts/pythia_freqs_code.pt")

  acts = torch.load("../res/acts/pythia_freqs_code.pt")


In [89]:
print(f"High diff keys: {high_diff_keys}")
print(f"Low diff keys: {low_diff_keys}")

High diff keys: [4, 24]
Low diff keys: [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29]


In [98]:
for idy in high_diff_keys:
    print(f"Key: {idy}, freq: {idy*20,(idy+1)*20}, freq sum: {acts[0][idy*20:(idy+1)*20].sum()}")

Key: 4, freq: (80, 100), freq sum: 0.007100939285010099
Key: 24, freq: (480, 500), freq sum: 0.18820755183696747


In [97]:
for idy in low_diff_keys:
    print(f"Key: {idy}, freq: {idy*20,(idy+1)*20}, freq sum: {acts[0][idy*20:(idy+1)*20].sum()}")

Key: 0, freq: (0, 20), freq sum: 0.014347216114401817
Key: 1, freq: (20, 40), freq sum: 0.06822924315929413
Key: 2, freq: (40, 60), freq sum: 0.0171254500746727
Key: 3, freq: (60, 80), freq sum: 0.10168902575969696
Key: 5, freq: (100, 120), freq sum: 0.004115692339837551
Key: 6, freq: (120, 140), freq sum: 0.11486919969320297
Key: 7, freq: (140, 160), freq sum: 0.009952389635145664
Key: 8, freq: (160, 180), freq sum: 0.005195060279220343
Key: 9, freq: (180, 200), freq sum: 0.032935481518507004
Key: 10, freq: (200, 220), freq sum: 0.00926438719034195
Key: 11, freq: (220, 240), freq sum: 0.024498241022229195
Key: 12, freq: (240, 260), freq sum: 0.007468082010746002
Key: 13, freq: (260, 280), freq sum: 0.012368181720376015
Key: 14, freq: (280, 300), freq sum: 0.025637242943048477
Key: 15, freq: (300, 320), freq sum: 0.02468782290816307
Key: 16, freq: (320, 340), freq sum: 0.03239729255437851
Key: 17, freq: (340, 360), freq sum: 0.0021987585350871086
Key: 18, freq: (360, 380), freq sum: 0.

In [100]:
high_freq_ind = torch.topk(acts[0], 20).indices
high_freq_ind


tensor([ 2261, 15114, 31698, 23478, 23322,  5929, 31855,  5719,  6873, 22354,
         7851, 19673, 26619, 22183, 14871, 16769, 31778, 15127, 27452, 29834])

In [101]:
for idy in range(1):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    high_freq_ind = torch.topk(acts[0], 20).indices
    list(map(lambda idx: sae2.W_dec[idx, :].zero_(), high_freq_ind))
    cache = {}
    for idx in tqdm(range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:50]
    for key, _ in top_10_keys:
        print(f"Key: {key}, diff: {cache[key]}")
    

100%|██████████| 36/36 [00:02<00:00, 15.58it/s]


Key: blocks.0.hook_resid_post.hook_sae_error, diff: 16.531463623046875
Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 16.531461715698242
Key: blocks.1.ln1.hook_normalized, diff: 2.743424844936504e-12
Key: blocks.1.ln2.hook_normalized, diff: 2.743424844936504e-12
Key: blocks.2.ln1.hook_normalized, diff: 2.1938104544788617e-13
Key: blocks.2.ln2.hook_normalized, diff: 2.1938104544788617e-13
Key: blocks.0.hook_resid_post.hook_sae_output, diff: 1.2656073563287185e-13
Key: blocks.1.hook_resid_pre, diff: 1.2656073563287185e-13
Key: blocks.3.ln1.hook_normalized, diff: 4.480038573190931e-14
Key: blocks.3.ln2.hook_normalized, diff: 4.480038573190931e-14
Key: blocks.1.hook_resid_post, diff: 4.352352760963671e-14
Key: blocks.2.hook_resid_pre, diff: 4.352352760963671e-14
Key: blocks.2.hook_resid_post, diff: 1.480289686401487e-14
Key: blocks.3.hook_resid_pre, diff: 1.480289686401487e-14
Key: blocks.4.ln1.hook_normalized, diff: 1.1919824563424043e-14
Key: blocks.4.ln2.hook_normalized, diff: 1.1

In [103]:

for idy in range(10):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    high_freq_ind = torch.topk(acts[0], 200).indices[idy*20:(idy+1)*20]
    list(map(lambda idx: sae2.W_dec[idx, :].zero_(), high_freq_ind))
    cache = {}
    for idx in tqdm(range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:5]
    for key, _ in top_10_keys:
        print(f"Key: {key}, diff: {cache[key]}")
    

100%|██████████| 36/36 [00:02<00:00, 16.10it/s]


Key: blocks.0.hook_resid_post.hook_sae_error, diff: 16.531463623046875
Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 16.531461715698242
Key: blocks.1.ln1.hook_normalized, diff: 2.743424844936504e-12
Key: blocks.1.ln2.hook_normalized, diff: 2.743424844936504e-12
Key: blocks.2.ln1.hook_normalized, diff: 2.1938104544788617e-13


100%|██████████| 36/36 [00:02<00:00, 15.61it/s]


Key: blocks.5.attn.hook_q, diff: 518144.0
Key: blocks.5.attn.hook_rot_q, diff: 518144.0
Key: blocks.4.attn.hook_q, diff: 31488.0
Key: blocks.4.attn.hook_rot_q, diff: 31488.0
Key: blocks.5.attn.hook_z, diff: 1072.0


100%|██████████| 36/36 [00:02<00:00, 15.15it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 13.959872245788574
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 13.959872245788574
Key: blocks.1.ln1.hook_normalized, diff: 1.1705226631714138e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.1705226631714138e-12
Key: blocks.2.ln1.hook_normalized, diff: 2.2373341243900335e-13


100%|██████████| 36/36 [00:02<00:00, 15.68it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 23.384490966796875
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 23.384490966796875
Key: blocks.1.ln1.hook_normalized, diff: 2.567503067790744e-12
Key: blocks.1.ln2.hook_normalized, diff: 2.567503067790744e-12
Key: blocks.2.ln1.hook_normalized, diff: 1.7798745333488797e-13


100%|██████████| 36/36 [00:02<00:00, 16.54it/s]


Key: blocks.5.attn.hook_q, diff: 47872.0
Key: blocks.5.attn.hook_rot_q, diff: 47872.0
Key: blocks.4.attn.hook_q, diff: 16128.0
Key: blocks.4.attn.hook_rot_q, diff: 16128.0
Key: blocks.3.attn.hook_q, diff: 57.0


100%|██████████| 36/36 [00:03<00:00, 11.13it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 3.2196216583251953
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 3.2196216583251953
Key: blocks.1.ln1.hook_normalized, diff: 2.5134172087354356e-12
Key: blocks.1.ln2.hook_normalized, diff: 2.5134172087354356e-12
Key: blocks.2.ln1.hook_normalized, diff: 1.6868410422817004e-13


100%|██████████| 36/36 [00:02<00:00, 17.00it/s]


Key: blocks.5.attn.hook_q, diff: 328.0
Key: blocks.5.attn.hook_rot_q, diff: 328.0
Key: blocks.4.attn.hook_q, diff: 49.0
Key: blocks.4.attn.hook_rot_q, diff: 49.0
Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 4.957878589630127


100%|██████████| 36/36 [00:02<00:00, 15.65it/s]


Key: blocks.5.attn.hook_q, diff: 518144.0
Key: blocks.5.attn.hook_rot_q, diff: 518144.0
Key: blocks.4.attn.hook_q, diff: 31488.0
Key: blocks.4.attn.hook_rot_q, diff: 31488.0
Key: blocks.5.attn.hook_z, diff: 1072.0


100%|██████████| 36/36 [00:02<00:00, 16.50it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 2.0361340045928955
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 2.0361340045928955
Key: blocks.1.ln1.hook_normalized, diff: 1.951199618543953e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.951199618543953e-12
Key: blocks.2.ln1.hook_normalized, diff: 1.1108814741570341e-13


100%|██████████| 36/36 [00:02<00:00, 16.29it/s]

Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 9.984151840209961
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 9.984151840209961
Key: blocks.1.ln1.hook_normalized, diff: 1.2955205484777021e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.2955205484777021e-12
Key: blocks.2.ln1.hook_normalized, diff: 2.0608413250650798e-13





In [105]:
import numpy as np
acts = torch.load("../res/acts/pythia_freqs_code.pt"), torch.load("../res/acts/pythia_freqs_math.pt"), torch.load("../res/acts/pythia_freqs_wiki.pt")
code_acts = acts[0]
top_num = 600
math_acts = acts[1]
wiki_acts = acts[2]
top_index_code = torch.topk(code_acts[0], top_num).indices
top_index_math = torch.topk(math_acts[0], top_num).indices
top_index_wiki = torch.topk(wiki_acts[0], top_num).indices
top_index_mc = np.intersect1d(
    top_index_code.cpu().numpy(), top_index_math.cpu().numpy()
)
top_index_mw = np.intersect1d(
    top_index_math.cpu().numpy(), top_index_wiki.cpu().numpy()
)
top_index_cw = np.intersect1d(
    top_index_code.cpu().numpy(), top_index_wiki.cpu().numpy()
)
top_index = np.intersect1d(top_index_mc, top_index_mw)
top_index_mc = np.setdiff1d(top_index_mc, top_index)
top_index_mw = np.setdiff1d(top_index_mw, top_index)
top_index_cw = np.setdiff1d(top_index_cw, top_index)
top_index_wiki = np.setdiff1d(
    top_index_wiki.cpu().numpy(), np.union1d(top_index_cw, top_index_mw)
)
top_index_math = np.setdiff1d(
    top_index_math.cpu().numpy(), np.union1d(top_index_mc, top_index_mw)
)
top_index_code = np.setdiff1d(
    top_index_code.cpu().numpy(), np.union1d(top_index_mc, top_index_cw)
)


  acts = torch.load("../res/acts/pythia_freqs_code.pt"), torch.load("../res/acts/pythia_freqs_math.pt"), torch.load("../res/acts/pythia_freqs_wiki.pt")


In [107]:

for idy in range(10):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    # high_freq_ind = torch.topk(acts[0], 200).indices[idy*20:(idy+1)*20]
    list(map(lambda idx: sae2.W_dec[idx, :].zero_(), top_index_wiki[idy*20:(idy+1)*20]))
    cache = {}
    for idx in tqdm(range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:5]
    for key, _ in top_10_keys:
        print(f"Key: {key}, diff: {cache[key]}")
    

100%|██████████| 36/36 [00:02<00:00, 15.48it/s]


Key: blocks.5.attn.hook_q, diff: 956.0
Key: blocks.5.attn.hook_rot_q, diff: 956.0
Key: blocks.4.attn.hook_q, diff: 37.25
Key: blocks.4.attn.hook_rot_q, diff: 37.25
Key: blocks.5.attn.hook_z, diff: 17.625


100%|██████████| 36/36 [00:02<00:00, 14.62it/s]


Key: blocks.4.attn.hook_q, diff: 1114112.0
Key: blocks.4.attn.hook_rot_q, diff: 1114112.0
Key: blocks.5.attn.hook_q, diff: 557056.0
Key: blocks.5.attn.hook_rot_q, diff: 557056.0
Key: blocks.5.attn.hook_z, diff: 1344.0


100%|██████████| 36/36 [00:02<00:00, 15.27it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 10.239799499511719
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 10.239799499511719
Key: blocks.1.ln1.hook_normalized, diff: 1.8347571725807477e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.8347571725807477e-12
Key: blocks.2.ln1.hook_normalized, diff: 1.3173855994801087e-13


100%|██████████| 36/36 [00:02<00:00, 15.55it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 6.180593490600586
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 6.180593490600586
Key: blocks.1.ln1.hook_normalized, diff: 1.451221304585304e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.451221304585304e-12
Key: blocks.2.ln1.hook_normalized, diff: 1.9223759682631542e-13


100%|██████████| 36/36 [00:02<00:00, 15.36it/s]


Key: blocks.4.attn.hook_q, diff: 1114112.0
Key: blocks.4.attn.hook_rot_q, diff: 1114112.0
Key: blocks.5.attn.hook_q, diff: 790528.0
Key: blocks.5.attn.hook_rot_q, diff: 790528.0
Key: blocks.5.attn.hook_z, diff: 1544.0


100%|██████████| 36/36 [00:02<00:00, 15.82it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 3.2941243648529053
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 3.2941243648529053
Key: blocks.1.ln1.hook_normalized, diff: 9.386361046054281e-13
Key: blocks.1.ln2.hook_normalized, diff: 9.386361046054281e-13
Key: blocks.2.ln1.hook_normalized, diff: 1.5152428736644324e-13


100%|██████████| 36/36 [00:02<00:00, 14.99it/s]


Key: blocks.5.attn.hook_q, diff: 518144.0
Key: blocks.5.attn.hook_rot_q, diff: 518144.0
Key: blocks.4.attn.hook_q, diff: 31488.0
Key: blocks.4.attn.hook_rot_q, diff: 31488.0
Key: blocks.5.attn.hook_z, diff: 1088.0


100%|██████████| 36/36 [00:03<00:00, 10.05it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 13.944974899291992
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 13.944974899291992
Key: blocks.1.ln1.hook_normalized, diff: 8.709892616171055e-13
Key: blocks.1.ln2.hook_normalized, diff: 8.709892616171055e-13
Key: blocks.2.ln1.hook_normalized, diff: 1.3085825554658842e-13


100%|██████████| 36/36 [00:02<00:00, 15.45it/s]


Key: blocks.4.attn.hook_q, diff: 1114112.0
Key: blocks.4.attn.hook_rot_q, diff: 1114112.0
Key: blocks.5.attn.hook_q, diff: 557056.0
Key: blocks.5.attn.hook_rot_q, diff: 557056.0
Key: blocks.5.attn.hook_z, diff: 1344.0


100%|██████████| 36/36 [00:02<00:00, 15.82it/s]

Key: blocks.5.attn.hook_q, diff: 518144.0
Key: blocks.5.attn.hook_rot_q, diff: 518144.0
Key: blocks.4.attn.hook_q, diff: 31488.0
Key: blocks.4.attn.hook_rot_q, diff: 31488.0
Key: blocks.5.attn.hook_z, diff: 1072.0





In [108]:

for idy in range(10):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    # high_freq_ind = torch.topk(acts[0], 200).indices[idy*20:(idy+1)*20]
    list(map(lambda idx: sae2.W_dec[idx, :].zero_(), top_index[idy*20:(idy+1)*20]))
    cache = {}
    for idx in tqdm(range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:5]
    for key, _ in top_10_keys:
        print(f"Key: {key}, diff: {cache[key]}")
    

100%|██████████| 36/36 [00:03<00:00,  9.99it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 13.48265552520752
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 13.48265552520752
Key: blocks.1.ln1.hook_normalized, diff: 2.1958689207235427e-12
Key: blocks.1.ln2.hook_normalized, diff: 2.1958689207235427e-12
Key: blocks.2.ln1.hook_normalized, diff: 2.0715604253686293e-13


100%|██████████| 36/36 [00:02<00:00, 15.24it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 7.814596176147461
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 7.814596176147461
Key: blocks.1.ln1.hook_normalized, diff: 2.325803829042461e-12
Key: blocks.1.ln2.hook_normalized, diff: 2.325803829042461e-12
Key: blocks.2.ln1.hook_normalized, diff: 1.9443379740448352e-13


100%|██████████| 36/36 [00:02<00:00, 16.41it/s]


Key: blocks.4.attn.hook_q, diff: 1114112.0
Key: blocks.4.attn.hook_rot_q, diff: 1114112.0
Key: blocks.5.attn.hook_q, diff: 557056.0
Key: blocks.5.attn.hook_rot_q, diff: 557056.0
Key: blocks.5.attn.hook_z, diff: 1344.0


100%|██████████| 36/36 [00:02<00:00, 16.14it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 5.345946788787842
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 5.345946788787842
Key: blocks.1.ln1.hook_normalized, diff: 1.1913498616442086e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.1913498616442086e-12
Key: blocks.2.ln1.hook_normalized, diff: 2.1123085377182382e-13


100%|██████████| 36/36 [00:02<00:00, 16.15it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 12.04269790649414
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 12.04269790649414
Key: blocks.1.ln1.hook_normalized, diff: 1.5058000419482243e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.5058000419482243e-12
Key: blocks.2.ln1.hook_normalized, diff: 1.9243856725151276e-13


100%|██████████| 36/36 [00:02<00:00, 16.09it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 15.020439147949219
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 15.020439147949219
Key: blocks.1.ln1.hook_normalized, diff: 1.0391405617926619e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.0391405617926619e-12
Key: blocks.2.ln1.hook_normalized, diff: 2.0065345869221818e-13


100%|██████████| 36/36 [00:02<00:00, 16.27it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 14.865419387817383
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 14.865419387817383
Key: blocks.1.ln1.hook_normalized, diff: 1.2863965535153676e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.2863965535153676e-12
Key: blocks.2.ln1.hook_normalized, diff: 1.8092435373224386e-13


100%|██████████| 36/36 [00:02<00:00, 17.05it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 13.859122276306152
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 13.859122276306152
Key: blocks.1.ln1.hook_normalized, diff: 2.238063684031899e-12
Key: blocks.1.ln2.hook_normalized, diff: 2.238063684031899e-12
Key: blocks.2.ln1.hook_normalized, diff: 1.7136722000200139e-13


100%|██████████| 36/36 [00:02<00:00, 16.00it/s]


Key: blocks.5.attn.hook_q, diff: 518144.0
Key: blocks.5.attn.hook_rot_q, diff: 518144.0
Key: blocks.4.attn.hook_q, diff: 31488.0
Key: blocks.4.attn.hook_rot_q, diff: 31488.0
Key: blocks.5.attn.hook_z, diff: 1072.0


100%|██████████| 36/36 [00:02<00:00, 16.81it/s]

Key: blocks.5.attn.hook_q, diff: 47872.0
Key: blocks.5.attn.hook_rot_q, diff: 47872.0
Key: blocks.4.attn.hook_q, diff: 16128.0
Key: blocks.4.attn.hook_rot_q, diff: 16128.0
Key: blocks.3.attn.hook_q, diff: 57.0





In [109]:

for idy in range(20):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    # high_freq_ind = torch.topk(acts[0], 200).indices[idy*20:(idy+1)*20]
    list(map(lambda idx: sae2.W_dec[idx, :].zero_(), top_index_wiki[idy:(idy+1)]))
    cache = {}
    for idx in tqdm(range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:5]
    for key, _ in top_10_keys:
        print(f"Key: {key}, diff: {cache[key]}")
    

100%|██████████| 36/36 [00:02<00:00, 15.61it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.09385795146226883
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.09385795146226883
Key: blocks.1.ln1.hook_normalized, diff: 1.1380096355279729e-13
Key: blocks.1.ln2.hook_normalized, diff: 1.1380096355279729e-13
Key: blocks.2.ln1.hook_normalized, diff: 2.0568670464765626e-14


100%|██████████| 36/36 [00:02<00:00, 16.58it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.745274543762207
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.745274543762207
Key: blocks.1.ln1.hook_normalized, diff: 1.8457112194950748e-14
Key: blocks.1.ln2.hook_normalized, diff: 1.8457112194950748e-14
Key: blocks.2.ln1.hook_normalized, diff: 3.159188947769731e-15


100%|██████████| 36/36 [00:02<00:00, 16.79it/s]


Key: blocks.5.attn.hook_q, diff: 48896.0
Key: blocks.5.attn.hook_rot_q, diff: 48896.0
Key: blocks.4.attn.hook_q, diff: 16128.0
Key: blocks.4.attn.hook_rot_q, diff: 16128.0
Key: blocks.3.attn.hook_q, diff: 57.0


100%|██████████| 36/36 [00:02<00:00, 15.58it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.5018693208694458
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.5018693208694458
Key: blocks.1.ln1.hook_normalized, diff: 1.3636395615121422e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.3636395615121422e-12
Key: blocks.2.ln1.hook_normalized, diff: 4.0476282536176106e-14


100%|██████████| 36/36 [00:02<00:00, 16.86it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.35401612520217896
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.35401612520217896
Key: blocks.1.ln1.hook_normalized, diff: 2.446167996893872e-13
Key: blocks.1.ln2.hook_normalized, diff: 2.446167996893872e-13
Key: blocks.2.ln1.hook_normalized, diff: 5.1379168660489055e-14


100%|██████████| 36/36 [00:02<00:00, 16.98it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.12169844657182693
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.12169844657182693
Key: blocks.1.ln1.hook_normalized, diff: 3.4933431066483736e-14
Key: blocks.1.ln2.hook_normalized, diff: 3.4933431066483736e-14
Key: blocks.2.ln1.hook_normalized, diff: 3.954900168749893e-15


100%|██████████| 36/36 [00:02<00:00, 17.17it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.0328533761203289
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.0328533761203289
Key: blocks.1.ln1.hook_normalized, diff: 4.102128047509847e-14
Key: blocks.1.ln2.hook_normalized, diff: 4.102128047509847e-14
Key: blocks.2.ln1.hook_normalized, diff: 7.602154582925236e-15


100%|██████████| 36/36 [00:02<00:00, 17.05it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.0928812250494957
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.0928812250494957
Key: blocks.1.ln1.hook_normalized, diff: 2.924057344996442e-14
Key: blocks.1.ln2.hook_normalized, diff: 2.924057344996442e-14
Key: blocks.2.ln1.hook_normalized, diff: 7.519285114530611e-15


100%|██████████| 36/36 [00:02<00:00, 17.56it/s]


Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.029706424102187157
Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.02970641665160656
Key: blocks.1.ln1.hook_normalized, diff: 3.020013303019556e-14
Key: blocks.1.ln2.hook_normalized, diff: 3.020013303019556e-14
Key: blocks.2.ln1.hook_normalized, diff: 3.935749600845421e-15


100%|██████████| 36/36 [00:02<00:00, 17.36it/s]


Key: blocks.5.attn.hook_q, diff: 47872.0
Key: blocks.5.attn.hook_rot_q, diff: 47872.0
Key: blocks.4.attn.hook_q, diff: 16128.0
Key: blocks.4.attn.hook_rot_q, diff: 16128.0
Key: blocks.3.attn.hook_q, diff: 57.0


100%|██████████| 36/36 [00:02<00:00, 15.67it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.5567302703857422
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.5567302703857422
Key: blocks.1.ln1.hook_normalized, diff: 1.2950299469946525e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.2950299469946525e-12
Key: blocks.2.ln1.hook_normalized, diff: 4.117576234401871e-14


100%|██████████| 36/36 [00:02<00:00, 16.06it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.02535983920097351
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.02535983920097351
Key: blocks.1.ln1.hook_normalized, diff: 3.634435756364775e-14
Key: blocks.1.ln2.hook_normalized, diff: 3.634435756364775e-14
Key: blocks.2.ln1.hook_normalized, diff: 6.671486025975519e-15


100%|██████████| 36/36 [00:02<00:00, 16.29it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.022813327610492706
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.022813327610492706
Key: blocks.1.ln1.hook_normalized, diff: 2.435543798161733e-14
Key: blocks.1.ln2.hook_normalized, diff: 2.435543798161733e-14
Key: blocks.2.ln1.hook_normalized, diff: 3.0139293618209597e-15


100%|██████████| 36/36 [00:02<00:00, 16.17it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.012242822907865047
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.012242822907865047
Key: blocks.1.ln1.hook_normalized, diff: 1.0672251758269812e-14
Key: blocks.1.ln2.hook_normalized, diff: 1.0672251758269812e-14
Key: blocks.2.ln1.hook_normalized, diff: 1.1823406803038086e-15


100%|██████████| 36/36 [00:03<00:00, 10.42it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.1277359127998352
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.1277359127998352
Key: blocks.1.ln1.hook_normalized, diff: 2.1234999097118588e-14
Key: blocks.1.ln2.hook_normalized, diff: 2.1234999097118588e-14
Key: blocks.2.ln1.hook_normalized, diff: 2.7613917825530104e-15


100%|██████████| 36/36 [00:02<00:00, 16.07it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.13866767287254333
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.13866767287254333
Key: blocks.1.ln1.hook_normalized, diff: 8.207394182682681e-14
Key: blocks.1.ln2.hook_normalized, diff: 8.207394182682681e-14
Key: blocks.2.ln1.hook_normalized, diff: 2.0701270084526498e-14


100%|██████████| 36/36 [00:02<00:00, 16.50it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.02830544486641884
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.02830544486641884
Key: blocks.1.ln1.hook_normalized, diff: 1.5786461357971196e-14
Key: blocks.1.ln2.hook_normalized, diff: 1.5786461357971196e-14
Key: blocks.2.ln1.hook_normalized, diff: 1.3438445582359044e-15


100%|██████████| 36/36 [00:02<00:00, 15.88it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.20176245272159576
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.20176245272159576
Key: blocks.1.ln1.hook_normalized, diff: 2.007548553122681e-14
Key: blocks.1.ln2.hook_normalized, diff: 2.007548553122681e-14
Key: blocks.0.hook_resid_post.hook_sae_output, diff: 1.6352587263213396e-15


100%|██████████| 36/36 [00:02<00:00, 16.97it/s]


Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.042463518679142
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.042463518679142
Key: blocks.1.ln1.hook_normalized, diff: 4.653581732743002e-14
Key: blocks.1.ln2.hook_normalized, diff: 4.653581732743002e-14
Key: blocks.2.ln1.hook_normalized, diff: 1.0091667424134455e-14


100%|██████████| 36/36 [00:02<00:00, 16.77it/s]

Key: blocks.0.hook_resid_post.hook_sae_recons, diff: 0.12521332502365112
Key: blocks.0.hook_resid_post.hook_sae_error, diff: 0.12521332502365112
Key: blocks.1.ln1.hook_normalized, diff: 1.5102062395772053e-12
Key: blocks.1.ln2.hook_normalized, diff: 1.5102062395772053e-12
Key: blocks.0.hook_resid_post.hook_sae_output, diff: 1.2372499299227475e-14





In [111]:
high_diff_keys = []
low_diff_keys = []
for idy in tqdm(range(20)):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    # high_freq_ind = torch.topk(acts[0], 200).indices[idy*20:(idy+1)*20]
    list(map(lambda idx: sae2.W_dec[idx, :].zero_(), top_index_wiki[idy:(idy+1)]))
    cache = {}
    for idx in (range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:5]
    flag = False
    for key, _ in top_10_keys:
        if key.startswith("blocks.5"):
            high_diff_keys.append(idy)
            flag = True
            break
    if not flag:
        low_diff_keys.append(idy)
    

100%|██████████| 20/20 [00:56<00:00,  2.80s/it]


In [113]:
print(f"High diff keys: {high_diff_keys}")
print(f"Low diff keys: {low_diff_keys}")

High diff keys: [2, 9]
Low diff keys: [0, 1, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]


In [116]:
code_acts[0][top_index_wiki[2]], code_acts[0][top_index_wiki[9]]

(tensor(0.0020), tensor(0.1736))

In [119]:
import jaxtyping
def get_cosine_similarity(
    dict_elements_1: jaxtyping.Float[torch.Tensor, "d_sae d_llm"],
    dict_elements_2: jaxtyping.Float[torch.Tensor, "d_sae d_llm"],
    p: int = 2,
    dim: int = 1,
    normalized: bool = True,
) -> jaxtyping.Float[torch.Tensor, "d_llm d_llm"]:
    """Get the cosine similarity between the dictionary elements.

    Args:
        dict_elements_1: The first dictionary elements.
        dict_elements_2: The second dictionary elements.

    Returns:
        The cosine similarity between the dictionary elements.
    """
    # Compute cosine similarity in pytorch
    dict_elements_1 = dict_elements_1
    dict_elements_2 = dict_elements_2

    # Normalize the tensors
    if normalized:
        dict_elements_1 = torch.nn.functional.normalize(dict_elements_1, p=p, dim=dim)
        dict_elements_2 = torch.nn.functional.normalize(dict_elements_2, p=p, dim=dim)

    # Compute cosine similarity using matrix multiplication
    cosine_sim: jaxtyping.Float[torch.Tensor, "d_llm d_llm"] = torch.mm(
        dict_elements_1, dict_elements_2.T
    )
    # max_cosine_sim, _ = torch.max(cosine_sim, dim=1)
    return cosine_sim

In [144]:
def cs_2_dict(key1, key2):
    return torch.mm(
        torch.nn.functional.normalize(
            sae1.W_dec[top_index_wiki[key1].item(), :].unsqueeze(0)
        ),
        torch.nn.functional.normalize(
            sae1.W_dec[top_index_wiki[key2].item(), :].unsqueeze(0)
        ).T,
    )

In [149]:
for idx in range(len(low_diff_keys)):
    for idy in range(idx+1, len(low_diff_keys)):
        print(f"Key1: {low_diff_keys[idx]}, Key2: {low_diff_keys[idy]}, cosine similarity: {cs_2_dict(low_diff_keys[idx], low_diff_keys[idy])}")

Key1: 0, Key2: 1, cosine similarity: tensor([[-0.0873]], device='cuda:0')
Key1: 0, Key2: 3, cosine similarity: tensor([[0.0549]], device='cuda:0')
Key1: 0, Key2: 4, cosine similarity: tensor([[0.0655]], device='cuda:0')
Key1: 0, Key2: 5, cosine similarity: tensor([[0.1952]], device='cuda:0')
Key1: 0, Key2: 6, cosine similarity: tensor([[0.0752]], device='cuda:0')
Key1: 0, Key2: 7, cosine similarity: tensor([[0.1955]], device='cuda:0')
Key1: 0, Key2: 8, cosine similarity: tensor([[-0.0638]], device='cuda:0')
Key1: 0, Key2: 10, cosine similarity: tensor([[0.0563]], device='cuda:0')
Key1: 0, Key2: 11, cosine similarity: tensor([[0.0815]], device='cuda:0')
Key1: 0, Key2: 12, cosine similarity: tensor([[-0.0236]], device='cuda:0')
Key1: 0, Key2: 13, cosine similarity: tensor([[0.0669]], device='cuda:0')
Key1: 0, Key2: 14, cosine similarity: tensor([[-0.0653]], device='cuda:0')
Key1: 0, Key2: 15, cosine similarity: tensor([[0.1375]], device='cuda:0')
Key1: 0, Key2: 16, cosine similarity: ten

In [150]:
for key1 in high_diff_keys:
    for key2 in high_diff_keys:
        print(f"Key1: {key1}, Key2: {key2}, cosine similarity: {cs_2_dict(key1, key2)}")

Key1: 2, Key2: 2, cosine similarity: tensor([[1.]], device='cuda:0')
Key1: 2, Key2: 9, cosine similarity: tensor([[-0.2480]], device='cuda:0')
Key1: 9, Key2: 2, cosine similarity: tensor([[-0.2480]], device='cuda:0')
Key1: 9, Key2: 9, cosine similarity: tensor([[1.0000]], device='cuda:0')


In [156]:
high_diff_keys = []
low_diff_keys = []
for idy in tqdm(range(20)):
    model_name = "EleutherAI/pythia-70m-deduped"
    layers = 6
    torch.set_grad_enabled(False)
    release = "pythia-70m-deduped-res-sm"
    sae_id = "blocks.0.hook_resid_post"
    sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    ds_ratio = 1e-3
    dataset_length = int(len(dataset) * ds_ratio)
    use_error_term = True
    freqs = torch.zeros(sae1.cfg.d_sae)
    doc_len = 0
    sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
    # high_freq_ind = torch.topk(acts[0], 200).indices[idy*20:(idy+1)*20]
    list(map(lambda idx: sae2.W_dec[idx, :].zero_(), top_index_wiki[idy:(idy+1)]))
    cache = {}
    loss = 0
    for idx in (range(dataset_length)):
        example = dataset[idx]
        tokens = model.to_tokens([example[text]], prepend_bos=True)
        loss1, cache1 = model.run_with_cache_with_saes(
            tokens, saes=sae1, use_error_term=use_error_term
        )
        model.reset_saes()
        local_doc_len = cache1[f"blocks.0.hook_resid_post.hook_sae_acts_post"].shape[1]
        new_doc_len = doc_len + local_doc_len

        freq = (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"]
            > 1)[0].sum(0) / local_doc_len
        loss2, cache2 = model.run_with_cache_with_saes(
            tokens, saes=sae2, use_error_term=use_error_term
        )
        model.reset_saes()
        for keys in cache1.keys():
            if torch.isnan(cache1[keys]).any() or torch.isnan(cache2[keys]).any():
                continue
            res = ((cache1[keys] - cache2[keys]) ** 2).sum()
            if torch.isnan(res):
                if idx == 0:
                    cache[keys] = 0
                continue
            if idx == 0:
                cache[keys] = res
            else:
                cache[keys] = (
                    cache[keys] * doc_len / new_doc_len
                    + res * local_doc_len / new_doc_len
                )
        doc_len = new_doc_len
        loss += ((loss1 - loss2) ** 2).sum()
    def value_getter(item):
        return item[1]

    top_10_keys = sorted(cache.items(), key=value_getter, reverse=True)[:5]
    flag = False
    for key, _ in top_10_keys:
        if key.startswith("blocks.5"):
            high_diff_keys.append(idy)
            print(f"high diff, the loss diff is {loss}")
            flag = True
            break
    if not flag:
        low_diff_keys.append(idy)
        print(f"low diff, the loss diff is {loss}")
    

  5%|▌         | 1/20 [00:02<00:48,  2.56s/it]

low diff, the loss diff is 0.0


 10%|█         | 2/20 [00:05<00:48,  2.68s/it]

low diff, the loss diff is 0.0


 15%|█▌        | 3/20 [00:08<00:46,  2.76s/it]

high diff, the loss diff is 109568.0


 20%|██        | 4/20 [00:10<00:43,  2.69s/it]

low diff, the loss diff is 0.0


 25%|██▌       | 5/20 [00:13<00:39,  2.65s/it]

low diff, the loss diff is 0.0


 30%|███       | 6/20 [00:15<00:36,  2.62s/it]

low diff, the loss diff is 0.0


 35%|███▌      | 7/20 [00:18<00:32,  2.52s/it]

low diff, the loss diff is 0.0


 40%|████      | 8/20 [00:32<01:13,  6.16s/it]

low diff, the loss diff is 0.0


 45%|████▌     | 9/20 [00:34<00:55,  5.04s/it]

low diff, the loss diff is 0.0


 50%|█████     | 10/20 [00:37<00:42,  4.22s/it]

high diff, the loss diff is 90112.0


 55%|█████▌    | 11/20 [00:39<00:33,  3.71s/it]

low diff, the loss diff is 0.0


 60%|██████    | 12/20 [00:42<00:26,  3.29s/it]

low diff, the loss diff is 0.0


 65%|██████▌   | 13/20 [00:44<00:21,  3.02s/it]

low diff, the loss diff is 0.0


 70%|███████   | 14/20 [00:47<00:17,  2.91s/it]

low diff, the loss diff is 0.0


 75%|███████▌  | 15/20 [00:49<00:13,  2.78s/it]

low diff, the loss diff is 0.0


 80%|████████  | 16/20 [00:52<00:10,  2.69s/it]

low diff, the loss diff is 0.0


 85%|████████▌ | 17/20 [00:54<00:07,  2.61s/it]

low diff, the loss diff is 0.0


 90%|█████████ | 18/20 [00:56<00:05,  2.55s/it]

low diff, the loss diff is 0.0


 95%|█████████▌| 19/20 [00:59<00:02,  2.58s/it]

low diff, the loss diff is 0.0


100%|██████████| 20/20 [01:03<00:00,  3.17s/it]

low diff, the loss diff is 0.0





In [None]:
# TODO1: check the structure of the model and see the location of the hook
# TODO2: check the attribute methods and see its influence on the output and frequency patterns
# TODO3: cos sim and high dim vis

hook_embed: Likely used to capture or modify the embeddings of the input tokens before they are fed into the model.

blocks.0.hook_resid_pre: Captures or modifies the residual connection input before any processing in the first block.

blocks.0.ln1.hook_scale: Captures or modifies the scaling factor in the first layer normalization of the first block.

blocks.0.ln1.hook_normalized: Captures or modifies the normalized output in the first layer normalization of the first block.

blocks.0.attn.hook_q: Captures or modifies the query vectors in the attention mechanism of the first block.

blocks.0.attn.hook_k: Captures or modifies the key vectors in the attention mechanism of the first block.

blocks.0.attn.hook_v: Captures or modifies the value vectors in the attention mechanism of the first block.

blocks.0.attn.hook_rot_q: Captures or modifies the rotated query vectors, possibly for rotary positional embeddings.

blocks.0.attn.hook_rot_k: Captures or modifies the rotated key vectors, possibly for rotary positional embeddings.

blocks.0.attn.hook_attn_scores: Captures or modifies the attention scores before the softmax operation.

blocks.0.attn.hook_pattern: Captures or modifies the attention pattern after the softmax operation.

blocks.0.attn.hook_z: Captures or modifies the output of the attention mechanism.

blocks.0.hook_attn_out: Captures or modifies the output of the attention mechanism before it is added to the residual connection.

blocks.0.ln2.hook_scale: Captures or modifies the scaling factor in the second layer normalization of the first block.

blocks.0.ln2.hook_normalized: Captures or modifies the normalized output in the second layer normalization of the first block.

blocks.0.mlp.hook_pre: Captures or modifies the input to the multi-layer perceptron (MLP) in the first block.

blocks.0.mlp.hook_post: Captures or modifies the output of the MLP in the first block.

blocks.0.hook_mlp_out: Captures or modifies the output of the MLP before it is added to the residual connection.

blocks.0.hook_resid_post.hook_sae_input: Likely captures or modifies the input to a sub-module or sub-layer within the residual connection post-processing.

blocks.0.hook_resid_post.hook_sae_acts_pre: Likely captures or modifies the activations before some specific operation within the residual connection post-processing.

blocks.0.hook_resid_post.hook_sae_acts_post: Likely captures or modifies the activations after some specific operation within the residual connection post-processing.

blocks.0.hook_resid_post.hook_sae_recons: Likely captures or modifies the reconstructed output within the residual connection post-processing.

blocks.0.hook_resid_post.hook_sae_output: Likely captures or modifies the final output within the residual connection post-processing.