# Evaluating your SAE

## Set Up

In [None]:
import os
import sys
import torch
import wandb
import json
import plotly.express as px
from transformer_lens import utils
from datasets import load_dataset
from typing import  Dict
from pathlib import Path

from functools import partial

from sae_training.utils import LMSparseAutoencoderSessionloader
from sae_analysis.visualizer import data_fns, html_fns
from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

# Pick the compute device in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# This notebook only evaluates, so autograd is switched off globally.
torch.set_grad_enabled(False)

def imshow(x, **kwargs):
    """Show `x` as a Plotly heatmap.

    `x` is converted with `utils.to_numpy`; any extra keyword arguments are
    forwarded unchanged to `px.imshow`. The figure is displayed immediately
    and nothing is returned.
    """
    figure = px.imshow(utils.to_numpy(x), **kwargs)
    figure.show()
    

# Load your Autoencoder



In [None]:
from sae_training.sparse_autoencoder import SparseAutoencoder
# Load model from Huggingface
# run = wandb.init()
# artifact = run.use_artifact('jbloom/mats_sae_training_gpt2_small/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144:v2', type='model')
# artifact_dir = artifact.download()

# Load in Model
# First SAE: `load_session_from_pretrained` returns the language model, the
# autoencoder and an activations loader from a single checkpoint path.
path = "checkpoints/bu20al09/lilac_plant_final_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152.pt"
model, sparse_autoencoder_10M, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)
# Second SAE: only the autoencoder itself is loaded. `path` is reused, so after
# this cell it refers to the second checkpoint, not the first.
path = "overnight_sae_resid_pre_10_gpt_2_small.pt"
sparse_autoencoder_200M = SparseAutoencoder.load_from_pretrained(path)


## Test the Autoencoder

### L0 Test and Reconstruction Test

In [None]:
def _summarise_sae(label, sparse_autoencoder, cache):
    """Run one SAE on its own hook point in `cache` and report norm / L0 statistics.

    Args:
        label: Text printed as a header so the output of several SAEs can be told apart.
        sparse_autoencoder: Autoencoder to evaluate; its input is read from
            `cache[sparse_autoencoder.cfg.hook_point]`.
        cache: Activation cache returned by `model.run_with_cache`.

    Returns:
        The five values returned by the autoencoder call, unchanged:
        `(sae_out, feature_acts, loss, mse_loss, l1_loss)`.
    """
    sae_input = cache[sparse_autoencoder.cfg.hook_point]
    sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(sae_input)

    print(f"--- {label} ---")
    # Position 0 is excluded from both norms (presumably because it is the BOS
    # token added by prepend_bos=True -- confirm).
    l2_norms_of_input = torch.norm(sae_input[:, 1:], dim=-1)
    l2_norms_of_sae_out = torch.norm(sae_out[:, 1:], dim=-1)
    print("l2_norms_of_input", l2_norms_of_input.mean().item())
    print("l2_norms_of_sae_out", l2_norms_of_sae_out.mean().item())

    # L0 = number of strictly positive feature activations at each position.
    l0 = (feature_acts > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

    return sae_out, feature_acts, loss, mse_loss, l1_loss


with torch.no_grad():
    batch_tokens = activations_loader.get_batch_tokens()
    print(batch_tokens.shape)
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    activations = cache[sparse_autoencoder_10M.cfg.hook_point]

    # Each SAE is compared against the activations at its *own* hook point.
    sae_out_10M, feature_acts_10M, loss, mse_loss, l1_loss = _summarise_sae(
        "sparse_autoencoder_10M", sparse_autoencoder_10M, cache
    )
    # NOTE: the `_100M` suffix is historical; these are the outputs of
    # `sparse_autoencoder_200M`. The names are kept because later cells use them.
    sae_out_100M, feature_acts_100M, loss, mse_loss, l1_loss = _summarise_sae(
        "sparse_autoencoder_200M", sparse_autoencoder_200M, cache
    )


# Monday Stuff

In [None]:

def reconstr_hook(mlp_out, hook, new_mlp_out):
    """Forward hook that ignores the incoming activation and returns `new_mlp_out`.

    `mlp_out` and `hook` are accepted only to satisfy the hook signature; bind
    `new_mlp_out` with `functools.partial` before registering the hook.
    """
    replacement = new_mlp_out
    return replacement

def zero_abl_hook(mlp_out, hook):
    """Forward hook that zero-ablates: returns zeros shaped like `mlp_out`.

    `hook` is unused and exists only to satisfy the hook signature.
    """
    ablated = torch.zeros_like(mlp_out)
    return ablated

with torch.no_grad():
    # Every intervention below patches the same activation.
    resid_pre_hook_name = utils.get_act_name("resid_pre", 10)

    # Baseline: unmodified model loss on the batch.
    print("Orig", model(batch_tokens, return_type="loss").item())

    # Loss when the activation is replaced by each SAE's reconstruction. The
    # labels name the SAE so the two numbers cannot be confused.
    for label, reconstruction in (
        ("reconstr (sparse_autoencoder_10M)", sae_out_10M),
        ("reconstr (sparse_autoencoder_200M)", sae_out_100M),
    ):
        print(
            label,
            model.run_with_hooks(
                batch_tokens,
                fwd_hooks=[
                    (
                        resid_pre_hook_name,
                        partial(reconstr_hook, new_mlp_out=reconstruction),
                    )
                ],
                return_type="loss",
            ).item(),
        )

    # Reference point: loss when the activation is replaced with zeros.
    print(
        "Zero",
        model.run_with_hooks(
            batch_tokens,
            return_type="loss",
            fwd_hooks=[(resid_pre_hook_name, zero_abl_hook)],
        ).item(),
    )

In [None]:
# # cache.apply_ln_to_stack(x_reconstruct[0],layer=10).mean()
# example_batch = torch.randint(0,32,(1,)).item(); example_position = torch.randint(0, 10, (1,)).item()
# print(example_batch, example_position)
# print(model.to_str_tokens(batch_tokens[example_batch])[max(example_position-5,0):min(example_position+3,128)])
# px.line(feature_acts[example_batch,example_position].cpu().numpy()).show()
# lnd_activations = cache.apply_ln_to_stack(activations, layer=10)
# _, feature_acts_after_ln, _, _, _ = sparse_autoencoder(
#         lnd_activations
#     )
# px.line(feature_acts_after_ln[example_batch,example_position].cpu().numpy()).show()
# vals, inds = torch.topk(feature_acts[example_batch,example_position].detach(), 10)
# px.bar(
#     x=utils.to_numpy(vals),
#     y=[str(i.item()) for i in inds],
#     orientation="h",
# ).show()
# utils.test_prompt(
#     prompt = model.to_string(batch_tokens[example_batch][1:example_position+1]),
#     answer = model.to_string(batch_tokens[example_batch][example_position+1]),
#     model = model)


import pandas as pd

def plot_feature_unembed_bar(feature_id, sparse_autoencoder, feature_name = ""):
    """Plot the 20 tokens most promoted by one SAE feature through the unembedding.

    The feature's decoder row is multiplied by the global `model`'s `W_U`, and
    the 20 largest resulting values are shown as a horizontal bar chart.

    Args:
        feature_id: Row index into `sparse_autoencoder.W_dec`.
        sparse_autoencoder: Autoencoder providing the decoder matrix `W_dec`.
        feature_name: Name of the value column; also used as the x-axis label.

    Returns:
        None. The figure is displayed with `fig.show()`.
    """
    feature_unembed = sparse_autoencoder.W_dec[feature_id] @ model.W_U

    # Take the vocabulary size from the unembedding itself rather than
    # hard-coding GPT-2's 50257, so the index always matches the data length.
    token_labels = [
        model.tokenizer.decode(token_id)
        for token_id in range(feature_unembed.shape[-1])
    ]

    feature_unembed_df = pd.DataFrame(
        feature_unembed.detach().cpu().numpy(),
        columns = [feature_name],
        index = token_labels,
    )

    # Highest values first; the former index becomes a regular 'token' column.
    feature_unembed_df = (
        feature_unembed_df
        .sort_values(feature_name, ascending=False)
        .reset_index()
        .rename(columns={'index': 'token'})
    )

    # Re-sort the top 20 ascending so the largest bar is drawn at the top of
    # the horizontal chart.
    top_tokens = feature_unembed_df.head(20).sort_values(feature_name, ascending=True)
    fig = px.bar(
        top_tokens,
        color_continuous_midpoint=0,
        color_continuous_scale="RdBu",
        y = 'token',
        x = feature_name,
        orientation='h',
        color = feature_name,
        hover_data=[feature_name],
    )

    fig.update_layout(
        width=500,
        height=600,
    )

    # fig.write_image(f"figures/{str(feature_id)}_{feature_name}.png")
    fig.show()


plot_feature_unembed_bar(14076, sparse_autoencoder_200M, feature_name = str(14076))
# for i in inds:
#     plot_feature_unembed_bar(int(i), sparse_autoencoder, feature_name = str(i.item()))

## Specific Capability Test

When studying a specific task, it is important to check that the model still performs that task when the original activation is replaced by the SAE's reconstruction.

In [None]:
# Baseline behaviour of the unmodified model on the example prompt.
example_prompt = "When Chris and David went to the play, David handed a club to"
example_answer = " Chris"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

# Cache the activations for the prompt and run the first SAE on its hook point.
logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder_10M(
    cache[sparse_autoencoder_10M.cfg.hook_point]
)

def reconstr_hook(mlp_out, hook, new_mlp_out):
    """Hook that replaces the incoming activation with `new_mlp_out`."""
    return new_mlp_out

def reconstr_key_hook(mlp_out, hook, reconstructed_key):
    """Hook that replaces the incoming activation with `reconstructed_key` (not used in this cell)."""
    return reconstructed_key

def reconstr_query_hook(mlp_out, hook, reconstructed_query):
    """Hook that replaces the incoming activation with `reconstructed_query` (not used in this cell)."""
    return reconstructed_query


def zero_abl_hook(mlp_out, hook):
    """Hook that replaces the incoming activation with zeros of the same shape."""
    return torch.zeros_like(mlp_out)

# Loss on the prompt: unmodified, with the SAE reconstruction patched in, and
# with the activation zero-ablated.
print("Orig", model(tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        tokens,
        fwd_hooks=[
            (
                utils.get_act_name("resid_pre", 10),
                partial(reconstr_hook, new_mlp_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(utils.get_act_name("resid_pre", 10), zero_abl_hook)],
    ).item(),
)


# Repeat the prompt test while the reconstruction hook is active, for
# comparison with the unhooked `test_prompt` call at the top of this cell.
with model.hooks(
    fwd_hooks=[
        (
            utils.get_act_name("resid_pre", 10),
            partial(reconstr_hook, new_mlp_out=sae_out),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

In [None]:
# Number of features with a strictly positive activation at the final position.
print((feature_acts[0,-1].detach() > 0 ).float().sum())
px.line(feature_acts[0,-1].detach().cpu().numpy()).show()
# Ten strongest features at the final position.
vals, inds = torch.topk(feature_acts[0,-1].detach().cpu(),10)
# `.item()` gives plain integer labels; str() on a 0-d tensor would render as "tensor(123)".
px.bar(x=[str(i.item()) for i in inds], y=vals.numpy()).show()

In [None]:
# Apply the layer-10 layer norm (via the cache) to the SAE's input activation
# and feed the normalised activation through the same SAE.
post_ln_activation = cache.apply_ln_to_stack(cache[sparse_autoencoder_10M.cfg.hook_point], layer=10)
_, feature_acts_post_ln_activations, _, _, _ = sparse_autoencoder_10M(
    post_ln_activation
)

# Number of features with a strictly positive activation at the final position.
print((feature_acts_post_ln_activations[0,-1].detach() > 0 ).float().sum())
px.line(feature_acts_post_ln_activations[0,-1].detach().cpu().numpy()).show()
vals, inds = torch.topk(feature_acts_post_ln_activations[0,-1].detach().cpu(),10)
# `.item()` gives plain integer labels; str() on a 0-d tensor would render as "tensor(123)".
px.bar(x=[str(i.item()) for i in inds], y=vals.numpy()).show()

In [None]:
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

logits_original, cache_original = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder_10M(
    cache_original[sparse_autoencoder_10M.cfg.hook_point]
)

def reconstr_hook(mlp_out, hook, new_mlp_out):
    """Overwrite only the final sequence position of `mlp_out` with that of `new_mlp_out`.

    The incoming tensor is modified in place and returned; all other positions
    are left untouched. Both shapes are printed first as a debugging aid.
    """
    print(mlp_out.shape, new_mlp_out.shape)
    final_position_source = new_mlp_out[:, -1, :]
    mlp_out[:, -1, :].copy_(final_position_source)
    return mlp_out

def mean_ablation_hook(mlp_out, hook, new_mlp_out):
    """Replace the final sequence position of `mlp_out` with the mean of `new_mlp_out` over positions.

    The incoming tensor is modified in place and returned; all other positions
    are left untouched. Both shapes are printed first as a debugging aid.

    The previous expression, `new_mlp_out[:, -1, :].mean(dim=1)`, averaged over
    the model dimension. That yields one scalar per batch row, which was then
    broadcast across the whole residual vector and failed for batch sizes > 1.

    NOTE(review): averaging over sequence positions (dim=1) is the assumed
    intent of "mean ablation" here -- confirm the reference distribution.
    """
    print(mlp_out.shape, new_mlp_out.shape)
    # new_mlp_out.mean(dim=1) averages over axis 1, giving one vector per batch row.
    mlp_out[:,-1,:] = new_mlp_out.mean(dim=1)
    return mlp_out

def reconstr_key_hook(mlp_out, hook, reconstructed_key):
    return reconstructed_key

def reconstr_query_hook(mlp_out, hook, reconstructed_query):
    return reconstructed_query

model.reset_hooks()
with model.hooks(
    fwd_hooks=[
        (
            utils.get_act_name("resid_pre", 10),
            partial(mean_ablation_hook, new_mlp_out=sae_out),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
    logits_reconstructed, cache_reconstructed_res_stream = model.run_with_cache(example_prompt, prepend_bos=True)

In [None]:
from circuitsvis.attention import attention_patterns
patterns = cache_original["blocks.10.attn.hook_pattern"][0]
attention_patterns(tokens=model.to_str_tokens(example_prompt), attention=patterns)

In [None]:
patterns = cache_reconstructed_res_stream["blocks.10.attn.hook_pattern"][0]
attention_patterns(tokens=model.to_str_tokens(example_prompt), attention=patterns)

In [None]:
# NOTE: the `chris_*` / `david_*` names are leftovers from the Chris/David
# prompt; the tokens looked up here are " John" and " Mary".
chris_token = model.to_single_token(" John")
david_token = model.to_single_token(" Mary")
# Difference between the two tokens' columns of the unembedding matrix.
chris_david_dir = model.W_U[:,chris_token] - model.W_U[:,david_token]
print(chris_david_dir.shape)

In [None]:
cache_original[sparse_autoencoder_10M.cfg.hook_point][0, -1] @ chris_david_dir

In [None]:
sae_out[0, -1] @ chris_david_dir

In [None]:
logits_original[0, -1, chris_token] - logits_original[0, -1, david_token]

In [None]:
logits_reconstructed[0, -1, chris_token] - logits_reconstructed[0, -1, david_token]

To do:
- Do a basic decomposition of the QK circuit on the John/Mary prompt.
- Build the reversed IOI example and check whether it points to specific features.

In [None]:
px.imshow(cache_original["blocks.10.attn.hook_attn_scores"][0,7].detach().cpu())

In [None]:
px.imshow(cache_reconstructed_res_stream["blocks.10.attn.hook_attn_scores"][0,7].detach().cpu())

In [None]:
cache_reconstructed_res_stream["blocks.10.attn.hook_attn_scores"][0,7,-1,4]

In [None]:
sae_out_keys.shape

In [None]:
import numpy as np
# Attention score for (query position, key position) is x_query @ QK @ x_key,
# with QK = W_Q @ W_K^T, so the query vector must be on the left.
# NOTE(review): `sae_out_keys` / `sae_out_query` are defined in a later cell; run that first.
(sae_out_query @ model.blocks[10].attn.QK[7] @ sae_out_keys) # / np.sqrt(model.cfg.d_head)

In [None]:
example = 0
# One SAE forward pass per position. Element 0 of the returned tuple is the
# reconstruction and element 1 the feature activations.
key_sae_outputs = sparse_autoencoder_200M(cache_original["blocks.10.hook_resid_pre"][example,4])
query_sae_outputs = sparse_autoencoder_200M(cache_original["blocks.10.hook_resid_pre"][example,-1])
sae_out_keys = key_sae_outputs[0]
sae_out_query = query_sae_outputs[0]
key_resid_pre_feature_acts = key_sae_outputs[1]
query_resid_pre_feature_acts = query_sae_outputs[1]


# Indices of the features that fire (activation > 0) at each position.
firing_key_features = (key_resid_pre_feature_acts > 0).float()
firing_query_features = (query_resid_pre_feature_acts > 0).float()
indices_keys = torch.nonzero(firing_key_features).flatten()
print(indices_keys)
indices_queries = torch.nonzero(firing_query_features).flatten()
print(indices_queries)

# One row per firing feature: activation * decoder direction.
decomposed_keys = key_resid_pre_feature_acts[indices_keys, None] * sparse_autoencoder_200M.W_dec[indices_keys]
decomposed_queries = query_resid_pre_feature_acts[indices_queries, None] * sparse_autoencoder_200M.W_dec[indices_queries]
print(decomposed_keys.shape, decomposed_queries.shape)

# score[query, key] = x_query @ QK @ x_key^T with QK = W_Q @ W_K^T, so the query
# side goes on the left. The result is transposed so that rows index key
# features and columns index query features, as the plots below expect.
# NOTE(review): biases, layer norm and the 1/sqrt(d_head) scale are not applied here.
sae_qk_circuit_instance = decomposed_queries @ model.blocks[10].attn.QK[7] @ decomposed_keys.T
sae_qk_circuit_instance = sae_qk_circuit_instance.AB.T.detach().cpu()
print(sae_qk_circuit_instance.shape)
print((sae_qk_circuit_instance > 0).sum())
print("Score Sum", sae_qk_circuit_instance.sum().item())

fig = px.imshow(sae_qk_circuit_instance.numpy(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

# add xtick and y tick labels with the key and query indices
fig.update_xaxes(
    tickvals=list(range(len(indices_queries))),
    ticktext=[str(i.item()) for i in indices_queries],
)
fig.update_yaxes(
    tickvals=list(range(len(indices_keys))),
    ticktext=[str(i.item()) for i in indices_keys],
)

fig.show()

# Total contribution per query feature (column sums) and per key feature (row sums).
px.bar(x=[str(i.item()) for i in indices_queries],  y = sae_qk_circuit_instance.sum(0).numpy()).show()
px.bar(x=[str(i.item()) for i in indices_keys],  y = sae_qk_circuit_instance.sum(1).numpy()).show()



In [None]:
example = 0
# One SAE forward pass per position. Element 0 of the returned tuple is the
# reconstruction and element 1 the feature activations.
key_sae_outputs = sparse_autoencoder_200M(cache_original["blocks.10.hook_resid_pre"][example,4])
query_sae_outputs = sparse_autoencoder_200M(cache_original["blocks.10.hook_resid_pre"][example,-1])
sae_out_keys = key_sae_outputs[0]
sae_out_query = query_sae_outputs[0]
key_resid_pre_feature_acts = key_sae_outputs[1]
query_resid_pre_feature_acts = query_sae_outputs[1]

# Indices of the features that fire (activation > 0) at each position.
firing_key_features = (key_resid_pre_feature_acts > 0).float()
firing_query_features = (query_resid_pre_feature_acts > 0).float()
indices_keys = torch.nonzero(firing_key_features).flatten()
print(indices_keys)
indices_queries = torch.nonzero(firing_query_features).flatten()
print(indices_queries)

# One row per firing feature: activation * decoder direction.
decomposed_keys = key_resid_pre_feature_acts[indices_keys, None] * sparse_autoencoder_200M.W_dec[indices_keys]
decomposed_queries = query_resid_pre_feature_acts[indices_queries, None] * sparse_autoencoder_200M.W_dec[indices_queries]
print(decomposed_keys.shape, decomposed_queries.shape)

# score[query, key] = x_query @ QK @ x_key^T with QK = W_Q @ W_K^T, so the query
# side goes on the left. The result is transposed so that rows index key
# features and columns index query features, as the plots below expect.
# NOTE(review): biases, layer norm and the 1/sqrt(d_head) scale are not applied here.
sae_qk_circuit_instance = decomposed_queries @ model.blocks[10].attn.QK[7] @ decomposed_keys.T
sae_qk_circuit_instance = sae_qk_circuit_instance.AB.T.detach().cpu()
print(sae_qk_circuit_instance.shape)
print((sae_qk_circuit_instance > 0).sum())

fig = px.imshow(sae_qk_circuit_instance.numpy(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

# add xtick and y tick labels with the key and query indices
fig.update_xaxes(
    tickvals=list(range(len(indices_queries))),
    ticktext=[str(i.item()) for i in indices_queries],
)
fig.update_yaxes(
    tickvals=list(range(len(indices_keys))),
    ticktext=[str(i.item()) for i in indices_keys],
)

fig.show()

# Total contribution per query feature (column sums) and per key feature (row sums).
px.bar(x=[str(i.item()) for i in indices_queries],  y = sae_qk_circuit_instance.sum(0).numpy()).show()
px.bar(x=[str(i.item()) for i in indices_keys],  y = sae_qk_circuit_instance.sum(1).numpy()).show()


# Largest (key feature, query feature) interactions. `sae_qk_circuit_instance`
# is already a plain tensor here, so it is flattened directly, and k is capped
# at the number of entries available.
top_k = min(25, sae_qk_circuit_instance.numel())
values, indices = torch.topk(sae_qk_circuit_instance.flatten(), top_k)
d_enc = sparse_autoencoder_200M.cfg.d_sae  # kept for any later cell that reads it
# The matrix is (firing key features x firing query features), so flat indices
# are unravelled with its own column count, then mapped back to SAE feature ids.
n_query_features = sae_qk_circuit_instance.shape[1]
start_topk_ind = indices_keys.cpu()[indices // n_query_features]
end_topk_ind = indices_queries.cpu()[indices % n_query_features]


## Generating Feature Interfaces

In [None]:
# Ten strongest features at the final position of the first sequence.
vals, inds = torch.topk(feature_acts[0,-1].detach().cpu(),10)
# `.item()` gives plain integer labels; str() on a 0-d tensor would render as "tensor(123)".
px.bar(x=[str(i.item()) for i in inds], y=vals.numpy()).show()

In [None]:
# Build an id -> token-string map from the tokenizer vocabulary, replacing the
# "Ġ" marker with a space and escaping newlines.
vocab_dict = model.tokenizer.vocab
vocab_dict = {v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items()}

# Write the map next to the notebook once; an existing file is left untouched.
vocab_dict_filepath = Path(os.getcwd()) / "vocab_dict.json"
if not vocab_dict_filepath.exists():
    with open(vocab_dict_filepath, "w") as f:
        json.dump(vocab_dict, f)
        

os.environ["TOKENIZERS_PARALLELISM"] = "false"
data = load_dataset("NeelNanda/c4-code-20k", split="train") # currently use this dataset to avoid dealing with tokenization while streaming
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]


# Currently, don't think much more time can be squeezed out of it. Maybe the best saving would be to
# make the entire sequence indexing parallelized, but that's possibly not worth it right now.

max_batch_size = 512
total_batch_size = 4096*5
# NOTE(review): `inds` comes from an earlier cell's torch.topk over `feature_acts`;
# confirm those activations were produced by `sparse_autoencoder_200M`, the
# encoder passed to `get_feature_data` below.
feature_idx = list(inds.flatten().cpu().numpy())
# max_batch_size = 512
# total_batch_size = 16384
# feature_idx = list(range(1000))

# Only the first `total_batch_size` token sequences are used.
tokens = all_tokens[:total_batch_size]

feature_data: Dict[int, FeatureData] = get_feature_data(
    encoder=sparse_autoencoder_200M,
    # encoder_B=sparse_autoencoder,
    model=model,
    hook_point=sparse_autoencoder_200M.cfg.hook_point,
    hook_point_layer=sparse_autoencoder_200M.cfg.hook_point_layer,
    tokens=tokens,
    feature_idx=feature_idx,
    max_batch_size=max_batch_size,
    left_hand_k = 3,
    buffer = (5, 5),
    n_groups = 10,
    first_group_size = 20,
    other_groups_size = 5,
    verbose = True,
)


# Write one HTML dashboard per selected feature into the working directory.
for test_idx in list(inds.flatten().cpu().numpy()):
    html_str = feature_data[test_idx].get_all_html()
    with open(f"data_{test_idx:04}.html", "w") as f:
        f.write(html_str)

This produces one HTML file per selected feature, each containing a dashboard of that feature's activations on the sample data. Only a small amount of data is processed at present, so the dashboards are of limited use.

# Tuesday Stuff

In [None]:
from transformer_lens import HookedTransformer
model = HookedTransformer.from_pretrained("gpt2-small")

In [None]:

example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

In [None]:


# Clean run: cache every activation for the prompt, then reconstruct the SAE's
# hook-point activation with the second autoencoder.
logits_original, cache_original = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder_200M(
    cache_original[sparse_autoencoder_200M.cfg.hook_point]
)


def reconstr_hook(mlp_out, hook, new_mlp_out):
    """Hook that replaces the incoming activation with `new_mlp_out`."""
    # print(mlp_out.shape, new_mlp_out.shape)
    return new_mlp_out

# Patched run: with the layer-10 resid_pre activation replaced by the SAE
# output, re-test the prompt and cache the resulting activations.
model.reset_hooks()
with model.hooks(
    fwd_hooks=[
        (
            utils.get_act_name("resid_pre", 10),
            partial(reconstr_hook, new_mlp_out=sae_out),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
    logits_reconstructed, cache_reconstructed_res_stream = model.run_with_cache(example_prompt, prepend_bos=True)

In [None]:

reconstructed_q = cache_reconstructed_res_stream["blocks.10.attn.hook_q"].detach()
reconstructed_k = cache_reconstructed_res_stream["blocks.10.attn.hook_k"].detach()


def reconstr_key_hook(key, hook, reconstructed_key):
    return reconstructed_key

def reconstr_query_hook(query, hook, reconstructed_query):
    """Swap the query activation for `reconstructed_query`.

    Both shapes are printed first. If they match, `reconstructed_query` is
    returned as is. Otherwise the first position (along dim 1) of the incoming
    `query` is kept and `reconstructed_query` is concatenated after it.
    """
    print("reconstr_query_hook", query.shape, reconstructed_query.shape)
    if query.shape != reconstructed_query.shape:
        leading_position = query[:, :1]
        return torch.cat((leading_position, reconstructed_query), dim=1)
    return reconstructed_query

model.reset_hooks()
with model.hooks(
    fwd_hooks=[
        (
            "blocks.10.attn.hook_q",
            partial(reconstr_query_hook, reconstructed_query=reconstructed_q),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
    _, cache_reconstructed_queries = model.run_with_cache(example_prompt, prepend_bos=True)

model.reset_hooks()

In [None]:
reconstructed_k = cache_reconstructed_res_stream["blocks.10.attn.hook_k"].detach()
    
def reconstr_key_hook(key, hook, reconstructed_key):
    """Swap the key activation for `reconstructed_key`.

    Both shapes are printed first. If they match, `reconstructed_key` is
    returned as is. Otherwise the first position (along dim 1) of the incoming
    `key` is kept and `reconstructed_key` is concatenated after it.
    """
    print("reconstr_key_hook", key.shape, reconstructed_key.shape)
    if key.shape != reconstructed_key.shape:
        leading_position = key[:, :1]
        return torch.cat((leading_position, reconstructed_key), dim=1)
    return reconstructed_key

model.reset_hooks()
with model.hooks(
    fwd_hooks=[
        (
            "blocks.10.attn.hook_k",
            partial(reconstr_key_hook, reconstructed_key=reconstructed_k),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
    _, cache_reconstructed_keys = model.run_with_cache(example_prompt, prepend_bos=True)

model.reset_hooks()

### Get the MSE Loss by Layer

In [None]:
def get_mse_loss_cache_df(original_cache, intervention_cache):
    """Tabulate the per-activation MSE between two caches, starting at the first difference.

    Args:
        original_cache: Mapping from activation name to tensor (the clean run).
            Its keys, in iteration order, define the rows.
        intervention_cache: Mapping with the same keys from the patched run.

    Returns:
        A DataFrame with columns "key" and "mse_loss". Leading rows whose MSE
        is not greater than zero are dropped; if no row has a positive MSE,
        an empty DataFrame with the same columns is returned.
    """
    # Iterate the cache passed in, not a notebook-level global.
    rows = [
        {
            "key": key,
            "mse_loss": (original_cache[key] - intervention_cache[key]).pow(2).mean().item(),
        }
        for key in original_cache.keys()
    ]
    df = pd.DataFrame(rows, columns=["key", "mse_loss"])

    changed = (df["mse_loss"] > 0).to_numpy()
    if not changed.any():
        # Nothing differs: return an empty frame instead of raising IndexError.
        return df.iloc[0:0]

    # Position of the first activation that differs; keep it and everything after.
    first_non_zero_idx = int(changed.argmax())
    return df.iloc[first_non_zero_idx:]

df = get_mse_loss_cache_df(cache_original, cache_reconstructed_res_stream)
px.line(df, x="key", y="mse_loss").show()

df = get_mse_loss_cache_df(cache_original, cache_reconstructed_queries)
px.line(df, x="key", y="mse_loss").show()

df = get_mse_loss_cache_df(cache_original, cache_reconstructed_keys)
px.line(df, x="key", y="mse_loss").show()

### Visualize attn patterns

In [None]:
from circuitsvis.attention import attention_patterns
patterns_original = cache_original["blocks.10.attn.hook_pattern"][0]
attention_patterns(tokens=model.to_str_tokens(example_prompt), attention=patterns_original)

In [None]:
# patterns_reconstructed = cache_reconstructed_res_stream["blocks.10.attn.hook_pattern"][0].detach().cpu()
# patterns_reconstructed = cache_reconstructed_keys["blocks.10.attn.hook_pattern"][0].detach().cpu()
patterns_reconstructed = cache_reconstructed_queries["blocks.10.attn.hook_pattern"][0].detach().cpu()
attention_patterns(tokens=model.to_str_tokens(example_prompt), attention=patterns_reconstructed)

### Visualize change in Attn Scores/Patterns

In [None]:
intervention_cache = cache_reconstructed_keys
scores_original = cache_original["blocks.10.attn.hook_attn_scores"][0].detach().cpu()
scores_reconstructed = intervention_cache["blocks.10.attn.hook_attn_scores"][0].detach().cpu()
patterns_original = cache_original["blocks.10.attn.hook_pattern"][0].detach().cpu()
patterns_reconstructed = intervention_cache["blocks.10.attn.hook_pattern"][0].detach().cpu()


import pandas as pd
import itertools
import numpy as np

def tensor_to_long_data_frame(tensor_result, dimension_names, value_name = "Score"):
    """Convert an N-d tensor into a long-format DataFrame with one row per element.

    Args:
        tensor_result: Tensor to unroll; elements are emitted in row-major order.
        dimension_names: One column name per tensor dimension, in order.
        value_name: Name of the column holding the tensor values.

    Returns:
        A DataFrame with one categorical column per dimension (holding the
        integer index along that dimension) followed by `value_name`.

    Raises:
        AssertionError: If the number of names differs from the tensor's rank.
    """
    assert len(tensor_result.shape) == len(
        dimension_names
    ), "The number of dimension names must match the number of dimensions in the tensor"

    flat_values = tensor_result.reshape(-1).detach().cpu().numpy()
    df = pd.DataFrame(flat_values, columns=[value_name])

    # The Cartesian product of per-dimension ranges enumerates indices in the
    # same row-major order as reshape(-1), without building one Python tuple
    # per element.
    df.index = pd.MultiIndex.from_product(
        [range(size) for size in tensor_result.shape],
        names=dimension_names,
    )
    df = df.reset_index()

    # set all dimensions except the value column to categorical
    for name in dimension_names:
        df[name] = df[name].astype("category")

    return df

# Long-format table with one row per (head, query position, key position).
scores_df = tensor_to_long_data_frame(scores_original, ["Head", "Query", "Key"], value_name = "Original")
scores_df["Reconstructed"] = scores_reconstructed.flatten().detach().cpu().numpy()
scores_df["Pattern Original"] = patterns_original.flatten().detach().cpu().numpy()
scores_df["Pattern Reconstructed"] = patterns_reconstructed.flatten().detach().cpu().numpy()
# Drop rows whose score is not finite in either run. The earlier `!= float("inf")`
# test only matched +inf, so -inf entries (presumably the causally masked
# positions -- confirm) were never removed.
finite_rows = np.isfinite(scores_df["Original"]) & np.isfinite(scores_df["Reconstructed"])
scores_df = scores_df[finite_rows]
scores_df.head()


In [None]:
# set pandas default width/height
fig = px.scatter(scores_df, x="Original", y="Reconstructed", color = "Head", hover_data=["Head", "Query", "Key"])
fig.update_layout(
    width=800,
    height=600,
)
fig.show()

In [None]:
# Assign the new figure; otherwise the layout update and `show()` below would
# act on the `fig` left over from the previous cell.
fig = px.scatter(scores_df, x="Pattern Original", y="Pattern Reconstructed", color = "Head", hover_data=["Head", "Query", "Key"], log_x=True, log_y=True)
fig.update_layout(
    width=800,
    height=600,
)
fig.show()

In [None]:
# Assign the new figure; otherwise the layout update and `show()` below would
# act on the `fig` left over from a previous cell.
fig = px.scatter(scores_df, x="Pattern Original", y="Pattern Reconstructed", color = "Key", hover_data=["Head", "Query", "Key"], log_x=True, log_y=True)
fig.update_layout(
    width=800,
    height=600,
)
fig.show()

In [None]:
import torch
import torch.nn.functional as F
from einops import einsum


def kl_divergence_attention(y_true, y_pred):
    """Return the element-wise KL terms `y_true * (log y_true - log y_pred)`.

    No reduction is applied; sum over the distribution axis to obtain the
    divergence of each row.

    Args:
        y_true: Reference probabilities.
        y_pred: Probabilities being compared, broadcastable with `y_true`.

    Returns:
        Tensor of per-element contributions. Entries where `y_true` is 0 are
        exactly 0 (the 0 * log 0 = 0 convention) rather than NaN.
    """
    # torch.xlogy(x, y) computes x * log(y) and returns 0 wherever x == 0, so
    # zero-probability entries no longer produce NaN.
    return torch.xlogy(y_true, y_true) - torch.xlogy(y_true, y_pred)


# Example usage
print(patterns_original.shape)
kl_result = kl_divergence_attention(patterns_original, patterns_reconstructed)
kl_result[kl_result.isnan()] = 0
fig = px.imshow(kl_result.sum(dim=-1).detach().cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu",
                labels = dict(x="Query", y="Head"), text_auto=".2f")
fig.layout.coloraxis.colorbar.title = "KL Divergence"
fig.update_layout(
    width=800,
    height=600,
)
fig.show()