In [None]:
from IPython import get_ipython
# Reload edited modules automatically before each cell executes (autoreload mode 2 = all modules).
ipython = get_ipython()
ipython.run_line_magic(magic_name="load_ext", line="autoreload")
ipython.run_line_magic(magic_name="autoreload", line="2")

from transformer_lens import HookedTransformer, utils
import torch
from datasets import load_dataset
import time
import os
from typing import Optional, List, Dict, Callable, Tuple, Union
import tqdm.notebook as tqdm
from pathlib import Path
import pickle
import plotly.express as px
import importlib
import json

from sae_analysis.visualizer import data_fns, model_fns, html_fns
# from model_fns import AutoEncoderConfig, AutoEncoder
from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

# Pick the best available accelerator instead of hard-coding Apple's MPS backend,
# so the notebook also runs on CUDA or CPU-only machines (still "mps" on Apple Silicon).
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Nothing below trains anything, so disable autograd globally to save memory.
torch.set_grad_enabled(False)

def imshow(x, **kwargs):
    """Render ``x`` as a plotly heatmap and show it.

    ``x`` is converted with ``utils.to_numpy`` (tensor or array-like); every
    keyword argument is forwarded unchanged to ``plotly.express.imshow``.
    Returns ``None`` -- the figure is displayed, not returned.
    """
    x_numpy = utils.to_numpy(x)
    px.imshow(x_numpy, **kwargs).show()

# The SAEs themselves are loaded in the next cell.





In [None]:

from transformer_lens import HookedTransformer, utils
import torch
import torch as t
import einops
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from functools import partial
from datasets import load_dataset
import numpy as np
from jaxtyping import Float
from transformer_lens import ActivationCache
from pathlib import Path
import torch.nn as nn
import pprint
import json 
import torch.nn.functional as F


DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
# NOTE(review): absolute path from the original training machine; only used by
# get_version/save/load below.
SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")
class AutoEncoder(nn.Module):
    """Sparse autoencoder: ReLU encoder into ``dict_size`` features, linear decoder back.

    ``cfg`` is a dict; this class reads the keys ``dict_size``, ``act_size``,
    ``l1_coeff``, ``enc_dtype`` (a key of ``DTYPES``), ``seed`` and ``device``.
    Extra keys are ignored but preserved. A copy of ``cfg`` is stored on
    ``self.cfg`` so that :meth:`save` writes the config this instance was built
    from (previously it dumped whatever global ``cfg`` existed at call time).
    """

    def __init__(self, cfg):
        super().__init__()
        # Own copy of the config; save() must not depend on a notebook-global `cfg`.
        self.cfg = dict(cfg)
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        # Seed before creating parameters so initialisation is reproducible.
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        # Start every decoder row (one per dictionary feature) at unit L2 norm.
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])

    def forward(self, x):
        """Encode and decode ``x``; presumably ``x`` is (batch, act_size) -- TODO confirm.

        Returns ``(loss, x_reconstruct, acts, l2_loss, l1_loss)`` where
        ``l2_loss`` is the squared error summed over the last dim and averaged
        over dim 0, and ``l1_loss`` is ``l1_coeff`` times the sum of ``|acts|``
        over *all* elements (not averaged over the batch).
        """
        # Subtract the decoder bias before encoding (it is added back after decoding).
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Remove the gradient component parallel to each decoder row, then renormalise the rows.

        Requires ``self.W_dec.grad`` to be populated (i.e. after a backward pass);
        presumably called right before the optimiser step -- confirm in the training loop.
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

    def get_version(self):
        """Return the next free integer checkpoint version in ``SAVE_DIR``.

        Only files named ``<int>.pt`` count as checkpoints. Returns 0 when there
        are none or when ``SAVE_DIR`` does not exist yet.
        """
        if not SAVE_DIR.exists():
            return 0
        version_list = [
            int(file.stem)
            for file in SAVE_DIR.iterdir()
            # Match on suffix, not substring: "pt" can occur anywhere in a path.
            if file.suffix == ".pt" and file.stem.isdigit()
        ]
        return 1 + max(version_list) if version_list else 0

    def save(self):
        """Save weights to ``SAVE_DIR/<version>.pt`` and this instance's config to ``<version>_cfg.json``."""
        SAVE_DIR.mkdir(parents=True, exist_ok=True)
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR / f"{version}.pt")
        with open(SAVE_DIR / f"{version}_cfg.json", "w") as f:
            json.dump(self.cfg, f)
        print("Saved as version", version)

    @classmethod
    def load(cls, version):
        """Rebuild an autoencoder from the checkpoint pair written by :meth:`save`."""
        with open(SAVE_DIR / f"{version}_cfg.json", "r") as f:
            cfg = json.load(f)
        pprint.pprint(cfg)
        autoencoder = cls(cfg=cfg)
        # map_location lets a checkpoint saved on another device load onto cfg["device"].
        state_dict = torch.load(SAVE_DIR / f"{version}.pt", map_location=cfg["device"])
        autoencoder.load_state_dict(state_dict)
        return autoencoder

    @classmethod
    def load_from_hf(cls, version, device_override=None):
        """
        Loads the saved autoencoder from HuggingFace.

        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        ``device_override``, if given, replaces the ``device`` stored in the downloaded config.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47

        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        if device_override is not None:
            cfg["device"] = device_override

        pprint.pprint(cfg)
        autoencoder = cls(cfg=cfg)
        autoencoder.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return autoencoder



# Load GPT-2 small onto the notebook's `device` (set in the setup cell) rather than a hard-coded backend.
model = HookedTransformer.from_pretrained("gpt2-small").to(device)

SAE_HF_REPO = "jacobcd52/gpt2-small-sparse-autoencoders"


def load_gpt2_small_sae(point: str, layer: int, sae_cfg: dict) -> Tuple[AutoEncoder, dict]:
    """Download the GPT-2-small SAE for ``point``/``layer`` and load it into an ``AutoEncoder``.

    The checkpoint filename embeds the dictionary size, so it is taken from
    ``sae_cfg["dict_size"]`` to keep the file and the architecture in sync.
    Returns ``(autoencoder, state_dict)``.
    """
    state_dict = utils.download_file_from_hf(
        SAE_HF_REPO,
        f"gpt2-small_{sae_cfg['dict_size']}_{point}_{layer}.pt",
        force_is_torch=True,
    )
    autoencoder = AutoEncoder(sae_cfg)
    autoencoder.load_state_dict(state_dict)
    return autoencoder, state_dict


# Shared architecture config for both residual-stream SAEs loaded below.
cfg = {
    "dict_size": 6144,
    "act_size": 768,
    "l1_coeff": 0.001,
    "enc_dtype": "fp32",
    "seed": 0,
    "device": device,
    "model_batch_size": 1028,
}

point, layer = "resid_pre", 10
encoder_resid_pre_10, dic = load_gpt2_small_sae(point, layer, cfg)

point, layer = "resid_pre", 11
encoder_resid_pre_11, dic = load_gpt2_small_sae(point, layer, cfg)

# NOTE(review): `data` / `tokenized_data` are overwritten with c4-code in a later cell,
# so this openwebtext load appears unused there.
data = load_dataset("stas/openwebtext-10k", split="train")
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(22)

# Sanity-check the model on an IOI-style prompt.
example_prompt = "After Jack and Mary went to the store, Jack gave a bottle of milk to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)


In [None]:
import os
# Map token id -> human-readable token text. GPT-2's byte-level BPE stores a leading
# space as "Ġ" and a newline as "Ċ", so the raw vocab never contains a literal "\n";
# "Ċ" must be mapped explicitly for newlines to show up as "\n".
vocab_dict = {
    token_id: token.replace("Ġ", " ").replace("Ċ", "\\n").replace("\n", "\\n")
    for token, token_id in model.tokenizer.vocab.items()
}

vocab_dict_filepath = Path(os.getcwd()) / "vocab_dict.json"
# Written only once: delete the file to regenerate it after changing the mapping above.
if not vocab_dict_filepath.exists():
    with open(vocab_dict_filepath, "w") as f:
        json.dump(vocab_dict, f)

In [None]:
import os 
# Turn off HF tokenizers parallelism (avoids fork-related warnings when datasets are processed).
os.environ.update(TOKENIZERS_PARALLELISM="false")

data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = (
    utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
    .shuffle(seed=42)
)
all_tokens = tokenized_data["tokens"]

In [None]:
# Inspect parameter names in `dic` -- when cells run top to bottom, this is the
# layer-11 resid_pre SAE state dict downloaded in the model-loading cell.
dic.keys()

In [None]:
from sae_training.sparse_autoencoder import SparseAutoencoder
from dataclasses import dataclass

@dataclass
class SparseAutoencoderConfig:
    """Config passed to ``sae_training``'s ``SparseAutoencoder``.

    ``hook_point`` / ``hook_point_layer`` are read by this notebook when
    collecting feature data; the other fields are presumably consumed by
    ``SparseAutoencoder`` itself -- confirm against ``sae_training``.
    """

    d_sae: int
    d_in: int
    l1_coefficient: float
    dtype: torch.dtype  # a torch dtype (torch.float32 is passed below), not a string
    seed: int
    device: str
    model_batch_size: int
    hook_point: str = "blocks.10.hook_resid_pre"
    hook_point_layer: int = 10

cfg = {
    "d_sae": 6144,
    "d_in": 768,
    "l1_coefficient": 0.001,
    "dtype": torch.float32,
    "seed": 0,
    "device": device,
    "model_batch_size": 1028,
}

# Which residual-stream checkpoint to load; the hook point is derived from it so the
# config and the downloaded weights cannot drift apart.
point, layer = "resid_pre", 10

sparse_autoencoder_cfg = SparseAutoencoderConfig(
    **cfg,
    hook_point=f"blocks.{layer}.hook_{point}",
    hook_point_layer=layer,
)
sparse_autoencoder = SparseAutoencoder(sparse_autoencoder_cfg)

dic = utils.download_file_from_hf(
    "jacobcd52/gpt2-small-sparse-autoencoders",
    f"gpt2-small_{cfg['d_sae']}_{point}_{layer}.pt",
    force_is_torch=True,
)
sparse_autoencoder.load_state_dict(dic)

In [None]:
# Feature indices of `sparse_autoencoder` to compute dashboards for and export as HTML below.
interesting_features = [
    1735, 1096, 1228, 5528, 1095,
    4406, 4408, 848, 175, 4287,
    176, 1266, 5337, 1761, 2500,
    2700, 4624, 1435, 4072, 872,
    3243, 1760, 4752, 3349, 5527,
    4795,
]


In [None]:

# Pick up edits to the visualizer modules, then re-bind the names this cell uses from them.
importlib.reload(data_fns)
importlib.reload(html_fns)
from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

# Performance note: little more time is likely to be saved here. The biggest remaining
# win would probably be parallelising the per-sequence indexing, which isn't worth it yet.

max_batch_size = 512
total_batch_size = 4096*6  # rows of `all_tokens` (128-token sequences) to scan
feature_idx = interesting_features
# Broader sweep used previously: max_batch_size = 512, total_batch_size = 16384,
# feature_idx = list(range(1000)).

tokens = all_tokens[:total_batch_size]

feature_data: Dict[int, FeatureData] = get_feature_data(
    encoder=sparse_autoencoder,
    # encoder_B=sparse_autoencoder,
    model=model,
    hook_point=sparse_autoencoder.cfg.hook_point,
    hook_point_layer=sparse_autoencoder.cfg.hook_point_layer,
    tokens=tokens,
    feature_idx=feature_idx,
    max_batch_size=max_batch_size,
    left_hand_k = 3,
    buffer = (5, 5),
    n_groups = 10,
    first_group_size = 20,
    other_groups_size = 5,
    verbose = True,
)


# Write one standalone HTML page per feature. The encoding is explicit because token
# text can be non-ASCII and the platform default encoding is not always UTF-8.
for test_idx in interesting_features:
    html_str = feature_data[test_idx].get_all_html()
    with open(f"data_{test_idx:04}.html", "w", encoding="utf-8") as f:
        f.write(html_str)

In [None]:

# Re-export the per-feature HTML pages from the already-computed `feature_data`
# (same export as at the end of the previous cell; useful after editing html_fns).
for test_idx in interesting_features:
    html_str = feature_data[test_idx].get_all_html()
    # Explicit UTF-8: token text may be non-ASCII and the platform default may not be UTF-8.
    with open(f"data_{test_idx:04}.html", "w", encoding="utf-8") as f:
        f.write(html_str)