# Notebook for Generating Investigative Datasets for SAEs + Copy Suppression


Components:
- Generation: load the data and model, then build the token DataFrame (Neel-style)
- Calculation of non-SAE interventions (e.g. ablation)
- Calculation of SAE-based interventions (e.g. reconstruction of the query)



# Set Up

In [None]:
import sys 
sys.path.append("..")

from importlib import reload

import joseph
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *


reload(joseph.analysis)
reload(joseph.visualisation)
reload(joseph.utils)
reload(joseph.data)

from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# Nothing in this notebook trains a model, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Attention head under investigation: layer 10, head 7.
LAYER_IDX = 10
HEAD_IDX = 7

In [None]:

# Load GPT-2 small as a HookedTransformer with LayerNorm folding and
# centring of the writing / unembedding weights requested.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    # refactor_factored_attn_matrices=True,
)
# Enable the split-QKV-input and per-head attention-result options.
# Presumably required by the per-head hooks used later in the notebook — TODO confirm.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)



# path = "checkpoints/ikig1wjm/final_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_32768.pkl"
# path="../artifacts/sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096:v15/final_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096.pkl"#
# path="../artifacts/sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096:v16/final_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096.pkl"
# path="../artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576:v56/final_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576.pkl"
# hacky solution to saved with cuda load on mps:
# path = "artifacts/sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096:v13/final_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096.pkl"
# sparse_autoencoder = SparseAutoencoder.load_from_pretrained(path)


from sae_training.sparse_autoencoder import SparseAutoencoder
from sae_training.config import LanguageModelSAERunnerConfig

# path = "checkpoints/peu1onjp/132669440_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_8192.pkl"
# path = "checkpoints/g2zrx9ho/final_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_8192.pkl"
path = "artifacts/sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536:v28/1076002816_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pkl" 


# SECURITY NOTE(review): unpickling can execute arbitrary code, so only load
# checkpoints from a trusted source. CPU_Unpickler presumably remaps tensors
# that were saved on CUDA onto the CPU — TODO confirm against its definition.
with open(path, 'rb') as file:
    state_dict = CPU_Unpickler(file).load()

# `cfg` aliases the pickled config object's attribute dict, so the edits
# below mutate that object in place.
cfg = state_dict["cfg"].__dict__
# NOTE(review): the device is hard-coded to Apple "mps"; change it on other hardware.
cfg["device"] = "mps"
cfg["hook_point_layer"] = 10
# These keys are dropped before rebuilding the config — presumably because
# LanguageModelSAERunnerConfig does not accept them as constructor arguments; TODO confirm.
del cfg["d_sae"]
del cfg["tokens_per_buffer"]
cfg = LanguageModelSAERunnerConfig(**cfg)
sparse_autoencoder = SparseAutoencoder(cfg)
sparse_autoencoder.load_state_dict(state_dict["state_dict"])
# The raw checkpoint and the temporary config are not needed once the
# weights are loaded.
del state_dict
del cfg

### Test Jacob's SAE

In [None]:
# from transformer_lens.utils import download_file_from_hf
# from dataclasses import dataclass
# point, layer = "resid_pre", 10
# dic = utils.download_file_from_hf("jacobcd52/gpt2-small-sparse-autoencoders", f"gpt2-small_6144_{point}_{layer}.pt", force_is_torch=True)
# # sparse_autoencoder.load_state_dict(dic)
# dic.keys()


# @dataclass
# class SparseAutoencoderConfig:
#     d_sae: int
#     d_in: int
#     l1_coefficient: float
#     dtype: str
#     seed: int
#     device: str
#     model_batch_size: int
#     hook_point: str = "blocks.10.hook_resid_pre"
#     hook_point_layer: int = 10
    
# cfg = {
#     "d_sae": 6144,
#     "d_in": 768,
#     "l1_coefficient": 0.001,
#     "dtype": torch.float32,
#     "seed": 0,
#     "device": "mps",
#     "model_batch_size": 1028,
# }

# sparse_autoencoder_cfg = SparseAutoencoderConfig(**cfg)
# sparse_autoencoder = SparseAutoencoder(sparse_autoencoder_cfg)


# point, layer = "resid_pre", 10
# dic = download_file_from_hf("jacobcd52/gpt2-small-sparse-autoencoders", f"gpt2-small_6144_{point}_{layer}.pt", force_is_torch=True)
# sparse_autoencoder.load_state_dict(dic)


## Define Hooks

# Main Dataset Generation Func

In [None]:

# Sanity check on a single synthetic prompt built from random, repeated tokens.
random_tokens, random_token_groups = generate_random_token_prompt(
    model,
    n_random_tokens=LENGTH_RANDOM_TOKS,
    n_repeat_tokens=3,
    token_of_interest=" Mary",
)
prompt = model.to_string(random_tokens)
print(prompt)

# Judging by the names, eval_prompt returns the per-token table plus the
# caches of the original run and of the run with the reconstructed query.
token_df, original_cache, cache_reconstructed_query = eval_prompt(
    [prompt], model, sparse_autoencoder
)
print(token_df.columns)

# Columns shown in the summary table.
filter_cols = [
    "str_tokens",
    "unique_token",
    "context",
    "batch",
    "pos",
    "label",
    "loss",
    "loss_diff",
    "mse_loss",
    "num_active_features",
    "explained_variance",
    "kl_divergence",
]
# The styled table is the cell's last expression so the notebook renders it.
token_df[filter_cols].style.background_gradient(
    cmap="coolwarm",
    subset=[
        "loss_diff",
        "mse_loss",
        "explained_variance",
        "num_active_features",
        "kl_divergence",
    ],
)


In [None]:
# Compare head (LAYER_IDX, HEAD_IDX) between the original forward pass and the
# pass that used the SAE-reconstructed query, for batch element 0.

# 1) Attention scores ("attn_scores" activation).
patterns_original = original_cache[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
patterns_reconstructed = cache_reconstructed_query[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
both_patterns = torch.stack([patterns_original, patterns_reconstructed])
# Titled differently from the pattern plot below so the two figures can be told apart.
plot_attn(both_patterns, token_df, title="Original and Reconstructed Attention Scores")

# 2) Attention pattern ("pattern" activation).
patterns_original = original_cache[utils.get_act_name("pattern", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
patterns_reconstructed = cache_reconstructed_query[utils.get_act_name("pattern", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
both_patterns = torch.stack([patterns_original, patterns_reconstructed])
plot_attn(both_patterns, token_df, title="Original and Reconstructed Attention Distribution")

In [None]:
# Fetch the activation the SAE was trained on. Hook points other than
# resid_pre are additionally indexed by head on the third axis.
original_act = original_cache[sparse_autoencoder.cfg.hook_point]
if "resid_pre" not in sparse_autoencoder.cfg.hook_point:
    original_act = original_act[:, :, HEAD_IDX]
sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(original_act)

# Inspect the final sequence position.
POS_INTEREST = feature_acts.shape[1] - 1
plot_line_with_top_10_labels(feature_acts[0, POS_INTEREST], "", 30)

# Ten most active features at that position (activation values and indices).
vals, inds = torch.topk(feature_acts[0,POS_INTEREST],10)

print(inds)
plot_attn_score_by_feature(model, sparse_autoencoder, inds, original_cache, token_df, pos_interest=POS_INTEREST, vals=vals)
plot_unembed_score_by_feature(model, sparse_autoencoder, inds, token_df, vals=vals)

In [None]:
# plot_feature_unembed_bar(46076, sparse_autoencoder, feature_name = "")
# Plot hard-coded feature 49077, highlighting the tokens of the current prompt.
# NOTE(review): this call passes `model` first; the function of the same name
# defined later in this notebook takes no `model` argument, so this cell only
# works before that definition has been executed.
plot_qk_via_feature(model, 49077, sparse_autoencoder, feature_name = "", highlight_tokens=token_df.str_tokens.tolist())

## Realistic Example

In [None]:
# Hand-written prompt and the completion we expect the model to give.
prompt = "When John and Mary went to the shops, John gave the bag to"
answer = " Mary"
# prompt = "All's fair in love and"
# answer = " war"
# prompt = " The cat is cute. The dog is"
# prompt = " Alice, with her keen intelligence and artistic talent, discussed philosophy with Bob, who shared her intellect and also possessed remarkable culinary skills, while"
# answer = " cute"
# Clear any hooks left over from earlier cells, then score the answer on the
# unmodified model.
model.reset_hooks()
utils.test_prompt(prompt, answer, model)

# Repeat the check with hook_to_ablate_head attached at HEAD_HOOK_RESULT_NAME
# (both names are defined outside this cell); the hook is removed on exit.
with model.hooks(fwd_hooks=[(HEAD_HOOK_RESULT_NAME, hook_to_ablate_head)]):
    utils.test_prompt(prompt, answer, model)

In [None]:
# Evaluate the prompt with its answer appended, so the answer token is part
# of the sequence.
token_df, original_cache, cache_reconstructed_query = eval_prompt(
    [prompt + answer], model, sparse_autoencoder
)
print(token_df.columns)

# Columns shown in the summary table.
filter_cols = [
    "str_tokens",
    "unique_token",
    "context",
    "batch",
    "pos",
    "label",
    "loss",
    "loss_diff",
    "mse_loss",
    "num_active_features",
    "explained_variance",
    "kl_divergence",
]
# The styled table is the cell's last expression so the notebook renders it.
token_df[filter_cols].style.background_gradient(
    cmap="coolwarm",
    subset=[
        "loss_diff",
        "mse_loss",
        "explained_variance",
        "num_active_features",
        "kl_divergence",
    ],
)


In [None]:
# Compare head (LAYER_IDX, HEAD_IDX) between the original forward pass and the
# pass that used the SAE-reconstructed query, for batch element 0.

# 1) Attention scores ("attn_scores" activation).
patterns_original = original_cache[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
patterns_reconstructed = cache_reconstructed_query[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
# Rebuild the stacked tensor from THIS prompt's scores; without this line the
# plot would show whatever `both_patterns` was left over from an earlier cell.
both_patterns = torch.stack([patterns_original, patterns_reconstructed])
plot_attn(both_patterns, token_df, title="Original and Reconstructed Attention Scores")

# 2) Attention pattern ("pattern" activation).
patterns_original = original_cache[utils.get_act_name("pattern", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
patterns_reconstructed = cache_reconstructed_query[utils.get_act_name("pattern", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
both_patterns = torch.stack([patterns_original, patterns_reconstructed])
plot_attn(both_patterns, token_df, title="Original and Reconstructed Attention Distribution")

In [None]:
# Fetch the activation the SAE was trained on. Hook points other than
# resid_pre are additionally indexed by head on the third axis (same rule as
# the random-token example earlier in the notebook).
original_act = original_cache[sparse_autoencoder.cfg.hook_point]
if "resid_pre" not in sparse_autoencoder.cfg.hook_point:
    original_act = original_act[:, :, HEAD_IDX]
sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(original_act)

# Second-to-last position. The evaluated text was `prompt + answer`, so this is
# presumably the position that predicts the answer (assumes a one-token answer).
POS_INTEREST = feature_acts.shape[1]-2
print(POS_INTEREST)
plot_line_with_top_10_labels(feature_acts[0, POS_INTEREST], "")

In [None]:
# NOTE(review): `feature_name` is assigned but not used in this cell.
feature_name = "All features"
# Ten most active features at POS_INTEREST (activation values and indices).
vals, inds = torch.topk(feature_acts[0,POS_INTEREST],10)
print(inds)
# Attention-score breakdown for those features at the position of interest.
plot_attn_score_by_feature(model, sparse_autoencoder, inds, original_cache, token_df, POS_INTEREST)

In [None]:
# Unembed-score view of the same top features, passing their activation values.
plot_unembed_score_by_feature(model, sparse_autoencoder, inds, token_df, vals=vals)

In [None]:
# Plot hard-coded feature 23287, highlighting the tokens of the current prompt.
# NOTE(review): uses the model-first call signature; see the same-named
# function defined later in this notebook, which takes no `model` argument.
plot_qk_via_feature(model, 23287, sparse_autoencoder, feature_name = "", highlight_tokens=token_df.str_tokens.tolist())

In [None]:
from pathlib import Path
import json
from typing import  Dict
from sae_analysis.visualizer import data_fns, html_fns
from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

In [None]:

# Build an id -> display-string map from the tokenizer vocabulary: the "Ġ"
# marker becomes a space and newlines are shown as a literal "\n".
vocab_dict = {
    token_id: token_str.replace("Ġ", " ").replace("\n", "\\n")
    for token_str, token_id in model.tokenizer.vocab.items()
}

# Write the map next to the notebook once; an existing file is left untouched.
vocab_dict_filepath = Path.cwd() / "vocab_dict.json"
if not vocab_dict_filepath.exists():
    with vocab_dict_filepath.open("w") as f:
        json.dump(vocab_dict, f)
        

In [None]:
# Display the feature indices currently selected for the dashboards below.
inds

In [None]:
from importlib import reload
from sae_analysis.visualizer import data_fns
reload(data_fns)

# Build per-feature HTML dashboards for the features in `inds` from a small
# OpenWebText sample.
dataset="stas/openwebtext-10k"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# data = get_webtext()
raw_dataset = load_dataset(dataset)
train_dataset = raw_dataset["train"]
# Tokenize into fixed-length sequences of 128 tokens and shuffle with seed 42.
tokenized_data = utils.tokenize_and_concatenate(train_dataset, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]


# Currently, don't think much more time can be squeezed out of it. Maybe the best saving would be to
# make the entire sequence indexing parallelized, but that's possibly not worth it right now.

max_batch_size = 32
total_batch_size = 512 * 10
# NOTE(review): `inds` must still be a tensor here; the Drill Down cell later
# converts it to a list, after which `.flatten()` would fail.
feature_idx = list(inds.flatten().cpu().numpy())
# max_batch_size = 512
# total_batch_size = 16384
# feature_idx = list(range(1000))

# Only the first `total_batch_size` sequences are used.
tokens = all_tokens[:total_batch_size]

feature_data: Dict[int, FeatureData] = data_fns.get_feature_data(
    encoder=sparse_autoencoder,
    # encoder_B=sparse_autoencoder,
    model=model,
    hook_point=sparse_autoencoder.cfg.hook_point,
    # NOTE(review): the layer passed is one less than the SAE's configured
    # hook layer — confirm this offset is what get_feature_data expects.
    hook_point_layer=sparse_autoencoder.cfg.hook_point_layer - 1,
    hook_point_head_index=sparse_autoencoder.cfg.hook_point_head_index,
    tokens=tokens,
    feature_idx=feature_idx,
    max_batch_size=max_batch_size,
    left_hand_k = 3,
    buffer = (5, 5),
    n_groups = 10,
    first_group_size = 20,
    other_groups_size = 5,
    verbose = True,
)


# Write one HTML file per feature into the current working directory.
for test_idx in list(inds.flatten().cpu().numpy()):
    html_str = feature_data[test_idx].get_all_html()
    with open(f"data_{test_idx:04}.html", "w") as f:
        f.write(html_str)

## Run on whole dataset

In [None]:
# Release cached MPS memory before the full-dataset run.
# NOTE(review): this targets the Apple "mps" backend only; use the matching
# call for other devices.
torch.mps.empty_cache()

In [None]:
# Evaluate the first NUM_PROMPTS prompts of `data` and collect the per-token
# results into a single DataFrame `df`.
str_token_list = []
loss_list = []
ablated_loss_list = []

NUM_PROMPTS = 10
MAX_PROMPT_LEN = 100
# BATCH_SIZE = 10
dataframe_list = []
# NOTE(review): `data` is not defined anywhere in this notebook (the only
# assignment, `data = get_webtext()`, is commented out in an earlier cell);
# it must be created before running this cell.
with torch.no_grad():
    for i in tqdm(range(NUM_PROMPTS)):

        # Get Token Data
        prompt = data[i]
        # new_str = data[BATCH_SIZE * i: BATCH_SIZE * (i + 1)]

        # Call eval_prompt the same way as the single-prompt cells do (list of
        # prompts, model, SAE). The result gets its own name so the global
        # `token_df` of the worked example, which the Drill Down cell reads
        # together with `feature_acts`, is not overwritten.
        prompt_token_df, _, _ = eval_prompt([prompt], model, sparse_autoencoder)
        dataframe_list.append(prompt_token_df)

df = pd.concat(dataframe_list)

print(df.shape)
print(df.columns)
df.head()

In [None]:
# Show the ten rows with the smallest loss_diff, shading loss_diff and mse_loss.
df.sort_values(by="loss_diff", ascending=True).head(10).style.background_gradient(cmap='viridis', subset=["loss_diff", "mse_loss"])

In [None]:
# Keep rows where `max_idx_tok` and `rec_q_max_idx_tok` agree and that token
# is not <|endoftext|>.
tmp = df.query("max_idx_tok == rec_q_max_idx_tok").query("max_idx_tok != '<|endoftext|>'")
# Report the full row count and how many rows survive the filter.
print(df.shape)
print(tmp.shape[0])
# Of those, keep the 50 rows with the fewest active features.
tmp = tmp.sort_values("num_active_features", ascending=True).head(50)
tmp#.style.background_gradient(cmap='viridis', subset=["loss_diff", "num_active_features", "mse_loss", "kl_divergence", "q_norm", "rec_q_norm"])

In [None]:
# Filtered rows: number of active features vs. loss difference, with marginal histograms.
px.scatter(tmp, x="num_active_features", y="loss_diff", hover_data=["max_idx_tok", "max_idx_tok_value"], marginal_x="histogram", marginal_y="histogram")

In [None]:
# Filtered rows: `max_idx_tok_value` against its reconstructed-query counterpart.
px.scatter(tmp, x="max_idx_tok_value", y="rec_q_max_idx_tok_value", hover_data=["max_idx_tok", "max_idx_tok_value"], marginal_x="histogram", marginal_y="histogram")

# Plot Results

In [None]:
# Distribution of loss_diff over every token in the dataset frame.
print(df.shape)
px.histogram(df, x="loss_diff", nbins=100, log_y=False, title="Loss Difference (Ablated - Original)")

In [None]:
# Distribution of the number of active SAE features per token. The title now
# names the plotted column instead of repeating the loss-difference title.
px.histogram(df, x="num_active_features", nbins=100, title="Number of Active Features")

In [None]:
# Log-log scatter of reconstruction error against the number of active
# features, with marginal histograms on both axes.
px.scatter(df,
           marginal_x="histogram",
           marginal_y="histogram",
           x="num_active_features", y="mse_loss",
           log_y=True,
           log_x=True,
           title="Query Reconstruction Loss (MSE) vs. Number of Active Features")

# Drill Down

In [None]:
# For the ten most active features at POS_INTEREST, compare how strongly each
# feature's decoder direction projects (via the head's W_Q) onto the unembed
# vector of the following token with the feature's attention-score contribution.
vals, inds = torch.topk(feature_acts[0,POS_INTEREST],10)
# Token that follows the position of interest.
tok_of_interest = token_df["str_tokens"][POS_INTEREST+1]
print(tok_of_interest)
# Only the first token id is used if the string encodes to several ids.
tok_id = model.tokenizer.encode(tok_of_interest)[0]
# Use the notebook-wide head constants rather than a hard-coded (10, 7).
projection_love = sparse_autoencoder.W_dec[inds] @ model.W_Q[LAYER_IDX,HEAD_IDX].T @  model.W_U[:,tok_id]
# projection_love = sparse_autoencoder.W_enc[:,inds].T @ model.W_Q[10,7].T @  model.W_U[:,love_id]
# From here on `inds` is a plain Python list of feature ids.
inds = inds.tolist()

print(projection_love.shape)
# NOTE(review): `score_contributions` is not defined anywhere in this
# notebook; it must be computed before running this cell.
# The frame gets its own name so the dataset-level `df` is not overwritten.
feature_df = pd.DataFrame(dict(
    projection_love=projection_love.detach().cpu().numpy(),
    # Contribution at the position of interest relative to position 0.
    score_contributions = score_contributions[:,POS_INTEREST].detach().cpu().numpy() - score_contributions[:,0].detach().cpu().numpy(),
    feature_id=inds,
    activation=vals.detach().cpu().numpy(),
    rank=range(len(inds)),
))

fig = px.scatter(
    feature_df,
    x="projection_love",
    y = "score_contributions",
    # color="feature_id",
    # color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    color="activation",
    hover_data=["feature_id"],
    # `labels` is keyed by column name, so the colour bar label must be keyed
    # by "activation" (the former key "y" matched no column).
    labels=dict(projection_love="Feature-Token Unembed Proj", activation="Activation", score_contributions="Attention Score Contribution"),
    template="plotly",
)

# add a black border around all points
fig.update_traces(marker=dict(line=dict(width=1, color="Black")))
# make all points slightly larger
fig.update_traces(marker=dict(size=12))
# increase font size
fig.update_layout(font=dict(size=18))
fig.update_layout(
    width=800,
    height=800,
)
fig.show()


In [None]:

def plot_qk_via_feature(feature_id, sparse_autoencoder, feature_name = "", highlight_tokens = None):
    """Scatter every vocabulary token by its key-side and unembed-side weight for one SAE feature.

    x-axis: the effective token embedding (W_E plus block 0's MLP output)
    mapped through W_K of head (LAYER_IDX, HEAD_IDX) and projected onto the
    feature's decoder direction.
    y-axis: the feature's decoder direction mapped through the transpose of
    the same head's W_Q and then through the unembedding W_U.
    The six highest-scoring tokens on each axis are labelled on the figure.

    NOTE(review): this notebook-local definition reads the global `model` and
    takes no `model` argument, so it shadows any imported function of the same
    name that earlier cells call as `plot_qk_via_feature(model, ...)`.

    Args:
        feature_id: Index of the feature's row in `sparse_autoencoder.W_dec`.
        sparse_autoencoder: Object exposing a `W_dec` matrix.
        feature_name: Optional label appended to the plot title.
        highlight_tokens: Token strings to colour separately. ``None`` (the
            default) highlights nothing.
    """
    # Avoid a shared mutable default argument.
    highlight_tokens = [] if highlight_tokens is None else highlight_tokens

    eff_embed = model.W_E + model.blocks[0].mlp(model.blocks[0].ln2(model.W_E[None] + model.blocks[0].attn.b_O))
    eff_emb_in_key_space =  eff_embed @ model.W_K[LAYER_IDX,HEAD_IDX] @ sparse_autoencoder.W_dec[feature_id]
    feature_unembed = sparse_autoencoder.W_dec[feature_id] @ model.W_Q[LAYER_IDX,HEAD_IDX].T @  model.W_U
    # feature_unembed = sparse_autoencoder.W_enc[:,feature_id] @ model.W_Q[LAYER_IDX,HEAD_IDX].T @  model.W_U

    # One row per vocabulary entry; the size is taken from the unembed
    # projection instead of a hard-coded 50257.
    n_vocab = feature_unembed.shape[-1]
    df = pd.DataFrame(dict(
        # `W_E[None]` gave eff_embed a leading singleton axis; index it away.
        eff_emb_in_key_space=eff_emb_in_key_space[0].detach().cpu().numpy(),
        feature_unembed = feature_unembed.detach().cpu().numpy(),
        token = [model.tokenizer.decode(i) for i in range(n_vocab)],
    ))

    df["token_of_interest"] = df["token"].isin(highlight_tokens)

    # Rows to label: the six largest values on each axis. The token string is
    # read straight from each row, so no column is written onto a slice of `df`.
    points_to_annotate = pd.concat([
        df.sort_values("eff_emb_in_key_space", ascending=False).head(6),
        df.sort_values("feature_unembed", ascending=False).head(6),
    ])

    fig = px.scatter(
        df,
        x="eff_emb_in_key_space",
        y = "feature_unembed",
        color="token_of_interest",
        color_continuous_scale="RdBu",
        # color="score_contributions",
        # text="text",
        # opacity=0.3,
        hover_data=["token"],
        labels=dict(eff_emb_in_key_space="Token to Feature Virtual Weight", feature_unembed="Unembed to Feature Virtual Weight"),
        title=f"Feature {feature_id} {feature_name}",
        template="plotly",
        marginal_x="histogram",
        marginal_y="histogram",
    )

    for _, row in points_to_annotate.iterrows():
        fig.add_annotation(x=row['eff_emb_in_key_space'], y=row['feature_unembed'],
                           text=f"{row['token']}", showarrow=False, arrowhead=1,
                           ax=20, ay=-40)

    fig.update_layout(
        width=1200,
        height=1200,
    )
    fig.show()
    
# plot_qk_via_feature(inds[0], sparse_autoencoder, feature_name = "What's not to suppress here?")
# Plot hard-coded feature 3985, highlighting the tokens of the first sequence
# in `tokens` with its first token skipped.
plot_qk_via_feature(3985, sparse_autoencoder, feature_name = "", highlight_tokens=model.to_str_tokens(tokens[0,1:]))
# plot_qk_via_feature(1102, sparse_autoencoder, feature_name = "", highlight_tokens=model.to_str_tokens(tokens[0,1:]))
# plot_qk_via_feature(1664, sparse_autoencoder, feature_name = "", highlight_tokens=model.to_str_tokens(tokens[0,1:]))
# plot_qk_via_feature(3017, sparse_autoencoder, feature_name = "", highlight_tokens=model.to_str_tokens(tokens[0,1:]))
# plot_qk_via_feature(2282, sparse_autoencoder, feature_name = "", highlight_tokens=model.to_str_tokens(tokens[0,1:]))

In [None]:
# Plot the second entry of `inds` (the second-ranked feature), with no highlighted tokens.
plot_qk_via_feature(inds[1], sparse_autoencoder, feature_name = "")

In [None]:
# Plot hard-coded feature 2433, with no highlighted tokens.
plot_qk_via_feature(2433, sparse_autoencoder, feature_name = "")

In [None]:
# Plot hard-coded feature 2688, with no highlighted tokens.
plot_qk_via_feature(2688, sparse_autoencoder, feature_name = "")

In [None]:
# Effective token embedding: W_E plus block 0's MLP applied to W_E with the
# attention output bias added. It is computed at notebook scope here because
# the plotting functions only build it as a local variable, and its shape is
# printed explicitly because a bare expression in the middle of a cell is
# not displayed.
eff_embed = model.W_E + model.blocks[0].mlp(model.blocks[0].ln2(model.W_E[None] + model.blocks[0].attn.b_O))
print(eff_embed.shape)


def plot_proj_onto_embed_via_key(feature_id, sparse_autoencoder, feature_name = ""):
    """Bar-chart the 20 tokens whose effective embedding projects most onto a feature via W_K.

    The effective embedding (W_E plus block 0's MLP output) is mapped through
    W_K of head (LAYER_IDX, HEAD_IDX) and projected onto row `feature_id` of
    `sparse_autoencoder.W_dec`. Reads the global `model`.

    Args:
        feature_id: Index of the feature's row in `sparse_autoencoder.W_dec`.
        sparse_autoencoder: Object exposing a `W_dec` matrix.
        feature_name: Used as the name of the value column and as the x-axis field.
    """
    first_block = model.blocks[0]
    mlp_input = first_block.ln2(model.W_E[None] + first_block.attn.b_O)
    eff_embed = model.W_E + first_block.mlp(mlp_input)

    key_weights = model.W_K[LAYER_IDX,HEAD_IDX]
    feature_direction = sparse_autoencoder.W_dec[feature_id]
    eff_emb_in_key_space = eff_embed @ key_weights @ feature_direction

    # One row per token id, indexed by the decoded token string.
    token_strings = [model.tokenizer.decode(token_id) for token_id in list(range(50257))]
    feature_unembed_df = pd.DataFrame(
        eff_emb_in_key_space.T.detach().cpu().numpy(),
        columns = [feature_name],
        index = token_strings,
    )

    # Rank all tokens from highest to lowest, then move the token strings out
    # of the index into a 'token' column.
    ranked = feature_unembed_df.sort_values(feature_name, ascending=False)
    feature_unembed_df = ranked.reset_index().rename(columns={'index': 'token'})

    # Plot the top 20 in ascending order.
    top_twenty = feature_unembed_df.head(20).sort_values(feature_name, ascending=True)
    fig = px.bar(
        top_twenty,
        color_continuous_midpoint=0,
        color_continuous_scale="RdBu",
        y = 'token',
        x = feature_name,
        orientation='h',
        color = feature_name,
        hover_data=[feature_name],
    )

    fig.update_layout(width=500, height=600)

    # fig.write_image(f"figures/{str(feature_id)}_{feature_name}.png")
    fig.show()
# Two bar charts for hard-coded feature 2688 under the same label: the
# key-side projection defined in this cell, then plot_feature_unembed_bar.
plot_proj_onto_embed_via_key(2688, sparse_autoencoder, feature_name = "Famous People you Love")
plot_feature_unembed_bar(2688, sparse_autoencoder, feature_name = "Famous People you Love")