In [20]:
import torch
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f20a8ce4200>

In [21]:
grad_enabled = torch.is_grad_enabled()
print(f"Torch gradient enabled: {grad_enabled}")

Torch gradient enabled: False


In [23]:
import sae_lens
import torch
import datasets
from typing import List, Tuple
from tqdm import tqdm
def fetch_sae_and_model(
    sae_name: str,
) -> Tuple[List[sae_lens.SAE], sae_lens.HookedSAETransformer]:
    """ fetch the specified SAE and model given the SAE name

    Args:
        sae_name (str): the name of the SAE

    Returns:
        Tuple[List[sae_lens.SAE], sae_lens.HookedSAETransformer]: the SAEs and the model
    """    
    if sae_name == "llama3.1-8b":
        model_name = "meta-llama/Llama-3.1-8B"
        layers = 32
        release = "llama_scope_lxr_8x"
    elif sae_name == "pythia-70m-deduped":
        model_name = "EleutherAI/pythia-70m-deduped"
        layers = 6
        release = "pythia-70m-deduped-res-sm"
    elif sae_name == "gemma-2-2b":
        model_name = "gemma-2-2b"
        layers = 26
        release = "gemma-scope-2b-pt-res-canonical"
    saes = fetch_sae(release, layers)
    model = fetch_model(model_name)
    return saes, model


def fetch_sae(release: str, layers: int) -> sae_lens.SAE:
    saes = []
    for layer in tqdm(range(layers)):
        if release == "gemma-scope-2b-pt-res-canonical":
            sae_id = f"layer_{layer}/width_16k/canonical"
        elif release == "llama_scope_lxr_8x":
            sae_id = f"l{layer}r_8x"
        elif release == "pythia-70m-deduped-res-sm":
            sae_id = f"blocks.{layer}.hook_resid_post"
        sae = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
        sae.to(dtype=torch.bfloat16)
        saes.append(sae)
    return saes


def fetch_model(model_name: str) -> sae_lens.HookedSAETransformer:
    model = sae_lens.HookedSAETransformer.from_pretrained(
        model_name, dtype=torch.bfloat16
    )
    return model

In [31]:
saes, model = fetch_sae_and_model("pythia-70m-deduped")

100%|██████████| 6/6 [00:01<00:00,  3.63it/s]


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


In [28]:
dataset = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")[
                "train"
            ]
text = "text"


In [30]:
ds_ratio = 1e-3
dataset_length = int(len(dataset) * ds_ratio)
print(f"Dataset length: {dataset_length}")

Dataset length: 36


In [45]:
model_name = "EleutherAI/pythia-70m-deduped"
layers = 6
release = "pythia-70m-deduped-res-sm"
sae_id = "blocks.0.hook_resid_post"
sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
sae2 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
sae2.W_dec[:300, :].zero_()


for idx in tqdm(range(dataset_length)):
    example = dataset[idx]
    tokens = model.to_tokens([example[text]], prepend_bos=True)
    loss1, cache1 = model.run_with_cache_with_saes(
        tokens, saes=sae1, use_error_term=False
    )
    model.reset_saes()
    loss2, cache2 = model.run_with_cache_with_saes(
        tokens, saes=sae2, use_error_term=False
    )
    model.reset_saes()
    

100%|██████████| 36/36 [00:01<00:00, 33.23it/s]


In [47]:
differences = []
keys = []
for k, _ in cache1.items():
    keys.append(k)
    difference = (cache1[k] - cache2[k]).sum()
    differences.append(difference)
print(f"Max difference: {max(differences)}")

Max difference: 1096.0


In [62]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Convert differences to numpy arrays for plotting
differences_np = [diff.to(torch.float32).cpu().numpy() for diff in differences]

# Create a DataFrame for plotting
df_stat = []
for k, diff in zip(keys, differences_np):
    if diff > 50:
        df_stat.append(
            pd.DataFrame(
                {
                    "Layer": [k],
                    "Difference": [diff],
                }
            )
        )
print(pd.concat(df_stat))
# # Plot the differences
# plt.figure(figsize=(12, 6))
# sns.barplot(x="Layer", y="Difference", data=pd.concat(df_stat))
# plt.xticks(rotation=90)
# plt.title("Mean Absolute Differences Across Layers")
# plt.xlabel("Layer")
# plt.ylabel("Mean Absolute Difference")
# plt.show()

                    Layer Difference
0   blocks.1.mlp.hook_pre      728.0
0   blocks.2.mlp.hook_pre      284.0
0  blocks.2.mlp.hook_post       72.0
0   blocks.3.mlp.hook_pre       67.5
0  blocks.3.mlp.hook_post      53.25
0    blocks.4.attn.hook_k       66.5
0   blocks.4.mlp.hook_pre      556.0
0    blocks.5.attn.hook_k      217.0
0    blocks.5.attn.hook_v      121.0
0   blocks.5.mlp.hook_pre     1096.0


In [53]:
for k, diff in zip(keys, differences):
    if diff > 50:
        print(f"Key: {k}, Difference: {diff}")

Key: blocks.1.mlp.hook_pre, Difference: 728.0
Key: blocks.2.mlp.hook_pre, Difference: 284.0
Key: blocks.2.mlp.hook_post, Difference: 72.0
Key: blocks.3.mlp.hook_pre, Difference: 67.5
Key: blocks.3.mlp.hook_post, Difference: 53.25
Key: blocks.4.attn.hook_k, Difference: 66.5
Key: blocks.4.mlp.hook_pre, Difference: 556.0
Key: blocks.5.attn.hook_k, Difference: 217.0
Key: blocks.5.attn.hook_v, Difference: 121.0
Key: blocks.5.mlp.hook_pre, Difference: 1096.0


In [63]:
for k, v in cache1.items():
    print(f"Key: {k}")

Key: hook_embed
Key: blocks.0.hook_resid_pre
Key: blocks.0.ln1.hook_scale
Key: blocks.0.ln1.hook_normalized
Key: blocks.0.attn.hook_q
Key: blocks.0.attn.hook_k
Key: blocks.0.attn.hook_v
Key: blocks.0.attn.hook_rot_q
Key: blocks.0.attn.hook_rot_k
Key: blocks.0.attn.hook_attn_scores
Key: blocks.0.attn.hook_pattern
Key: blocks.0.attn.hook_z
Key: blocks.0.hook_attn_out
Key: blocks.0.ln2.hook_scale
Key: blocks.0.ln2.hook_normalized
Key: blocks.0.mlp.hook_pre
Key: blocks.0.mlp.hook_post
Key: blocks.0.hook_mlp_out
Key: blocks.0.hook_resid_post.hook_sae_input
Key: blocks.0.hook_resid_post.hook_sae_acts_pre
Key: blocks.0.hook_resid_post.hook_sae_acts_post
Key: blocks.0.hook_resid_post.hook_sae_recons
Key: blocks.0.hook_resid_post.hook_sae_output
Key: blocks.1.hook_resid_pre
Key: blocks.1.ln1.hook_scale
Key: blocks.1.ln1.hook_normalized
Key: blocks.1.attn.hook_q
Key: blocks.1.attn.hook_k
Key: blocks.1.attn.hook_v
Key: blocks.1.attn.hook_rot_q
Key: blocks.1.attn.hook_rot_k
Key: blocks.1.attn.hook

In [None]:
# TODO1: check the structure of the model and see the location of the hook
# TODO2: check the attribute methods and see its influence on the output and frequency patterns
# TODO3: cos sim and high dim vis

hook_embed: Likely used to capture or modify the embeddings of the input tokens before they are fed into the model.

blocks.0.hook_resid_pre: Captures or modifies the residual connection input before any processing in the first block.

blocks.0.ln1.hook_scale: Captures or modifies the scaling factor in the first layer normalization of the first block.

blocks.0.ln1.hook_normalized: Captures or modifies the normalized output in the first layer normalization of the first block.

blocks.0.attn.hook_q: Captures or modifies the query vectors in the attention mechanism of the first block.

blocks.0.attn.hook_k: Captures or modifies the key vectors in the attention mechanism of the first block.

blocks.0.attn.hook_v: Captures or modifies the value vectors in the attention mechanism of the first block.

blocks.0.attn.hook_rot_q: Captures or modifies the rotated query vectors, possibly for rotary positional embeddings.

blocks.0.attn.hook_rot_k: Captures or modifies the rotated key vectors, possibly for rotary positional embeddings.

blocks.0.attn.hook_attn_scores: Captures or modifies the attention scores before the softmax operation.

blocks.0.attn.hook_pattern: Captures or modifies the attention pattern after the softmax operation.

blocks.0.attn.hook_z: Captures or modifies the output of the attention mechanism.

blocks.0.hook_attn_out: Captures or modifies the output of the attention mechanism before it is added to the residual connection.

blocks.0.ln2.hook_scale: Captures or modifies the scaling factor in the second layer normalization of the first block.

blocks.0.ln2.hook_normalized: Captures or modifies the normalized output in the second layer normalization of the first block.

blocks.0.mlp.hook_pre: Captures or modifies the input to the multi-layer perceptron (MLP) in the first block.

blocks.0.mlp.hook_post: Captures or modifies the output of the MLP in the first block.

blocks.0.hook_mlp_out: Captures or modifies the output of the MLP before it is added to the residual connection.

blocks.0.hook_resid_post.hook_sae_input: Likely captures or modifies the input to a sub-module or sub-layer within the residual connection post-processing.

blocks.0.hook_resid_post.hook_sae_acts_pre: Likely captures or modifies the activations before some specific operation within the residual connection post-processing.

blocks.0.hook_resid_post.hook_sae_acts_post: Likely captures or modifies the activations after some specific operation within the residual connection post-processing.

blocks.0.hook_resid_post.hook_sae_recons: Likely captures or modifies the reconstructed output within the residual connection post-processing.

blocks.0.hook_resid_post.hook_sae_output: Likely captures or modifies the final output within the residual connection post-processing.

In [64]:
import jaxtyping

@torch.no_grad()
def get_cosine_similarity(
    dict_elements_1: jaxtyping.Float[torch.Tensor, "d_sae d_llm"],
    dict_elements_2: jaxtyping.Float[torch.Tensor, "d_sae d_llm"],
    p: int = 2,
    dim: int = 1,
    normalized: bool = True,
) -> jaxtyping.Float[torch.Tensor, "d_llm d_llm"]:
    """Get the cosine similarity between the dictionary elements.

    Args:
        dict_elements_1: The first dictionary elements.
        dict_elements_2: The second dictionary elements.

    Returns:
        The cosine similarity between the dictionary elements.
    """
    # Compute cosine similarity in pytorch
    dict_elements_1 = dict_elements_1
    dict_elements_2 = dict_elements_2

    # Normalize the tensors
    if normalized:
        dict_elements_1 = torch.nn.functional.normalize(dict_elements_1, p=p, dim=dim)
        dict_elements_2 = torch.nn.functional.normalize(dict_elements_2, p=p, dim=dim)

    # Compute cosine similarity using matrix multiplication
    cosine_sim: jaxtyping.Float[torch.Tensor, "d_llm d_llm"] = torch.mm(
        dict_elements_1, dict_elements_2.T
    )
    # max_cosine_sim, _ = torch.max(cosine_sim, dim=1)
    return cosine_sim

In [65]:
cos_sim = get_cosine_similarity(sae1.W_dec, sae2.W_dec)


In [70]:
values, indices = cos_sim.max(1)

In [99]:
indices = torch.where(cos_sim.fill_diagonal_(0) >0.9)
print(indices)

(tensor([    5,   458,  1453,  1992,  2020,  2278,  2463,  2914,  2979,  4667,
         4670,  4964,  5462,  5950,  6517,  6564,  6682,  6847,  7228,  7352,
         7415,  8234,  8304, 11573, 11772, 12553, 13534, 14263, 14278, 14568,
        16004, 17103, 17192, 18215, 18283, 18353, 18833, 19410, 19603, 19627,
        20284, 20570, 21315, 21357, 21660, 21793, 22545, 22865, 23125, 23302,
        23404, 23548, 23835, 23969, 24259, 25613, 25900, 26452, 26511, 26511,
        26793, 26824, 27340, 27882, 29236, 29325, 29326, 29463, 29813, 29921,
        31394, 32139, 32164], device='cuda:0'), tensor([ 9124,  2278, 18833,  5950, 29813,   458, 11573, 13534, 26824, 11772,
        23404, 27340, 21315,  1992, 21357,  7415, 14263, 23835, 24259, 26452,
         6564, 32139, 14278,  2463,  4667, 18353,  2914,  6682,  8304, 32164,
        29463, 23548, 22865, 19603, 20570, 12553,  1453, 29921, 18215, 23969,
        29326, 18283,  5462,  6517, 26511, 23125, 25613, 17192, 21793, 26793,
         4670, 

In [100]:
cos_sim[indices[0][0], indices[1][0]]

tensor(0.9009, device='cuda:0')

tensor([[ 7614],
        [13276],
        [ 6193],
        [31791],
        [ 8700],
        [ 2773],
        [ 1576],
        [ 6794],
        [ 2412],
        [ 1303]], device='cuda:0')

In [123]:
model_name = "EleutherAI/pythia-70m-deduped"
layers = 6
release = "pythia-70m-deduped-res-sm"
sae_id = "blocks.0.hook_resid_post"
sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
diffs = []
for idx in tqdm(range(dataset_length*100)):
    example = dataset[idx]
    tokens = model.to_tokens([example[text]], prepend_bos=True)
    loss1, cache1 = model.run_with_cache_with_saes(
        tokens, saes=sae1, use_error_term=False
    )
    diffs.append(
        (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"][:, :, indices[0][0]]
            - cache1["blocks.0.hook_resid_post.hook_sae_acts_post"][:, :, indices[1][0]]
        ).abs().sum()
    )
    model.reset_saes()

100%|██████████| 3600/3600 [00:53<00:00, 67.32it/s]


In [124]:
count = 0
for diff in diffs:
    if diff != 0:
        count += 1
print(f"Count: {count}")

Count: 151


In [115]:
mask = (cos_sim.fill_diagonal_(-100) < 0.1) & (cos_sim.fill_diagonal_(-100) > -0.1)
min_indices = torch.where(mask)


In [116]:
min_indices

(tensor([    0,     0,     0,  ..., 32767, 32767, 32767], device='cuda:0'),
 tensor([    1,     2,     3,  ..., 32762, 32763, 32766], device='cuda:0'))

In [125]:
model_name = "EleutherAI/pythia-70m-deduped"
layers = 6
release = "pythia-70m-deduped-res-sm"
sae_id = "blocks.0.hook_resid_post"
sae1 = sae_lens.SAE.from_pretrained(release, sae_id, device="cuda")[0]
diffs = []
for idx in tqdm(range(dataset_length*100)):
    example = dataset[idx]
    tokens = model.to_tokens([example[text]], prepend_bos=True)
    loss1, cache1 = model.run_with_cache_with_saes(
        tokens, saes=sae1, use_error_term=False
    )
    diffs.append(
        (
            cache1["blocks.0.hook_resid_post.hook_sae_acts_post"][:, :, min_indices[0][0]]
            - cache1["blocks.0.hook_resid_post.hook_sae_acts_post"][:, :, min_indices[1][0]]
        ).abs().sum()
    )
    model.reset_saes()

100%|██████████| 3600/3600 [00:57<00:00, 62.91it/s]


In [126]:
count = 0
for diff in diffs:
    if diff != 0:
        count += 1
print(f"Count: {count}")

Count: 139
