In [None]:
from IPython import get_ipython
# Load the autoreload extension and switch it to mode "2" so that edits to
# imported project modules are picked up without restarting the kernel.
ipython = get_ipython()
for magic_name, magic_arg in (("load_ext", "autoreload"), ("autoreload", "2")):
    ipython.run_line_magic(magic_name, magic_arg)

from transformer_lens import HookedTransformer, utils
import torch
from datasets import load_dataset
import time
import os
from typing import Optional, List, Dict, Callable, Tuple, Union
import tqdm.notebook as tqdm
from pathlib import Path
import pickle
import plotly.express as px
import importlib
import json

from sae_analysis.visualizer import data_fns, model_fns, html_fns
# from model_fns import AutoEncoderConfig, AutoEncoder
from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

# Pick the best available torch backend instead of hard-coding "mps", so the
# notebook also runs on CUDA or CPU-only machines. MPS is checked first so the
# value is unchanged on machines where the original hard-coded choice was valid.
# `torch.backends.mps` is looked up defensively because older torch builds do
# not ship it.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# This notebook only runs inference/analysis, so autograd is disabled globally.
torch.set_grad_enabled(False)

def imshow(x, **kwargs):
    """Display `x` as a Plotly image/heatmap figure.

    `x` is first converted with `utils.to_numpy`; any extra keyword arguments
    are forwarded unchanged to `px.imshow`. The figure is shown immediately
    and nothing is returned.
    """
    figure = px.imshow(utils.to_numpy(x), **kwargs)
    figure.show()

# Load our autoencoder





In [None]:
import wandb
# Weights & Biases artifact to fetch into the local artifact directory.
# NOTE(review): this downloads the `..._6144:v2` artifact, while the active
# `path` in the loading cell below points at a `..._24576:v36` checkpoint —
# confirm that the artifact you intend to load has actually been downloaded.
SAE_ARTIFACT_NAME = "jbloom/mats_sae_training_gpt2_small/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144:v2"

run = wandb.init()
artifact = run.use_artifact(SAE_ARTIFACT_NAME, type="model")
artifact_dir = artifact.download()

In [None]:
import torch
from functools import partial
from sae_training.utils import LMSparseAutoencoderSessionloader


# Checkpoint to analyse. Swap the active `path` line with the commented one
# below to load the other checkpoint instead.
path = "artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576:v36/final_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576.pt"
# path = "artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144:v2/1200001024_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144.pt"

# The session loader returns three objects from the checkpoint: the language
# model, the sparse autoencoder, and the loader used to draw token batches.
session = LMSparseAutoencoderSessionloader.load_session_from_pretrained(path)
model, sparse_autoencoder, activations_loader = session


# Test the Autoencoder

## L0 Test and Reconstruction Test

In [None]:

with torch.no_grad():
    # Activation the SAE reads from. The same name is used both to pull the
    # SAE input out of the cache and to patch the model below, so the test
    # stays correct if an SAE for a different hook point is loaded (the
    # original patched a hard-coded layer-10 `resid_pre`).
    hook_point = sparse_autoencoder.cfg.hook_point

    batch_tokens = activations_loader.get_batch_tokens()
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(
        cache[hook_point]
    )
    # The full activation cache is no longer needed; drop the reference.
    del cache

    # L0 is the *number* of non-zero features per position. Summing the raw
    # activations (as the original did) gives their total magnitude instead,
    # so count non-zero entries before reducing.
    l0_per_position = (feature_acts != 0).float().sum(-1)
    l0 = l0_per_position.mean().item()
    print(l0)
    px.histogram(l0_per_position.flatten().cpu().numpy(), log_y=False).show()

    def reconstr_hook(mlp_out, hook, new_mlp_out):
        """Forward hook that discards the activation and returns `new_mlp_out`.

        Despite the parameter names, the hook is attached at the SAE's hook
        point, so `mlp_out` is whatever activation lives there.
        """
        return new_mlp_out

    def zero_abl_hook(mlp_out, hook):
        """Forward hook that replaces the activation with zeros of the same shape."""
        return torch.zeros_like(mlp_out)

    # Compare next-token loss for: the untouched model, the model with the SAE
    # reconstruction spliced in, and the model with the activation zeroed.
    print("Orig", model(batch_tokens, return_type="loss").item())
    print(
        "reconstr",
        model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[(hook_point, partial(reconstr_hook, new_mlp_out=sae_out))],
            return_type="loss",
        ).item(),
    )
    print(
        "Zero",
        model.run_with_hooks(
            batch_tokens,
            return_type="loss",
            fwd_hooks=[(hook_point, zero_abl_hook)],
        ).item(),
    )

## Specific Capability Test

In [None]:
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"

# Baseline: how the unmodified model ranks the expected answer.
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

# Activation the SAE reads from. It is used both for the cache lookup and for
# patching, so the two cannot drift apart (the original patched a hard-coded
# layer-10 `resid_pre` regardless of the loaded SAE's configuration).
hook_point = sparse_autoencoder.cfg.hook_point

logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(
    cache[hook_point]
)


def reconstr_hook(mlp_out, hook, new_mlp_out):
    """Forward hook that discards the activation and returns `new_mlp_out`.

    Despite the parameter names, the hook is attached at the SAE's hook point,
    so `mlp_out` is whatever activation lives there.
    """
    return new_mlp_out


def zero_abl_hook(mlp_out, hook):
    """Forward hook that replaces the activation with zeros of the same shape."""
    return torch.zeros_like(mlp_out)


# Hook lists reused below: splice in the SAE reconstruction, or zero-ablate.
reconstr_fwd_hooks = [(hook_point, partial(reconstr_hook, new_mlp_out=sae_out))]
zero_abl_fwd_hooks = [(hook_point, zero_abl_hook)]

# Loss on the prompt for the untouched, SAE-patched, and zero-ablated model.
print("Orig", model(tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        tokens,
        fwd_hooks=reconstr_fwd_hooks,
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=zero_abl_fwd_hooks,
    ).item(),
)

# Repeat the prompt test with the SAE reconstruction patched in, to see
# whether the model still prefers the expected answer.
with model.hooks(fwd_hooks=reconstr_fwd_hooks):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

In [None]:
import torch
# Ten largest feature activations at index [0, -1] of `feature_acts`
# (presumably first sequence, last token position — confirm against the SAE's
# output layout). `vals` and `inds` stay tensors because later cells reuse `inds`.
vals, inds = torch.topk(feature_acts[0, -1].detach().cpu(), 10)

# Convert to plain Python numbers before plotting: iterating a tensor yields
# 0-d tensors whose str() is "tensor(123)", which made the x-axis labels
# unreadable in the original.
px.bar(x=[str(i) for i in inds.tolist()], y=vals.tolist()).show()

In [None]:

# Invert the tokenizer vocabulary into {token_id: display_string}. In each
# token string "Ġ" is replaced by a plain space and a literal newline by the
# two-character sequence "\n".
raw_vocab = model.tokenizer.vocab
vocab_dict = {}
for token_str, token_id in raw_vocab.items():
    vocab_dict[token_id] = token_str.replace("Ġ", " ").replace("\n", "\\n")

# Write the mapping next to the notebook, but only if the file is not already
# there — an existing vocab_dict.json is never overwritten.
vocab_dict_filepath = Path.cwd() / "vocab_dict.json"
if not vocab_dict_filepath.exists():
    with vocab_dict_filepath.open("w") as f:
        json.dump(vocab_dict, f)

In [None]:
# Turn off the tokenizers library's parallelism via its environment switch
# before any tokenization happens.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load the dataset, tokenize it into fixed-length (128) sequences with the
# model's tokenizer, and shuffle with a fixed seed so the order is repeatable.
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = utils.tokenize_and_concatenate(
    data, model.tokenizer, max_length=128
).shuffle(seed=42)
all_tokens = tokenized_data["tokens"]

In [None]:
# Re-import the visualizer modules so in-progress edits to them take effect,
# then re-bind the names imported from data_fns to the reloaded versions.
for _module in (data_fns, html_fns):
    importlib.reload(_module)
from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

# Performance note: there is probably not much more speed to squeeze out of
# get_feature_data right now. The most promising remaining saving would be to
# parallelize the sequence indexing, but that may not be worth the effort yet.

max_batch_size = 512
total_batch_size = 4096 * 5
# Features to analyse: the top-k feature indices found in the earlier cell.
feature_idx = list(inds.flatten().cpu().numpy())
# Alternative configuration (first 1000 features over a smaller token budget):
# max_batch_size = 512
# total_batch_size = 16384
# feature_idx = list(range(1000))

# Only the first `total_batch_size` token sequences are fed to the analysis.
tokens = all_tokens[:total_batch_size]

# Remaining options are forwarded unchanged to get_feature_data; their exact
# meaning is defined there (not visible from this notebook).
feature_data_options = dict(
    left_hand_k=3,
    buffer=(5, 5),
    n_groups=10,
    first_group_size=20,
    other_groups_size=5,
    verbose=True,
)

feature_data: Dict[int, FeatureData] = get_feature_data(
    encoder=sparse_autoencoder,
    # encoder_B=sparse_autoencoder,
    model=model,
    hook_point=sparse_autoencoder.cfg.hook_point,
    hook_point_layer=sparse_autoencoder.cfg.hook_point_layer,
    tokens=tokens,
    feature_idx=feature_idx,
    max_batch_size=max_batch_size,
    **feature_data_options,
)


# Write one HTML dashboard per analysed feature. Iterate over `feature_idx`
# (the exact indices handed to get_feature_data) rather than recomputing the
# list from `inds`, so this loop cannot request a feature that was never
# computed if `feature_idx` is changed above.
for test_idx in feature_idx:
    html_str = feature_data[test_idx].get_all_html()
    # Explicit UTF-8: the HTML can contain non-ASCII token text, and the
    # default text encoding of open() depends on the platform locale.
    with open(f"data_{test_idx:04}.html", "w", encoding="utf-8") as f:
        f.write(html_str)