Part of this code is borrowed from an [SAELens tutorial](https://github.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb), and some of it is borrowed from Connor Kissane's attention-output-saes [repository](https://github.com/ckkissane/attention-output-saes). Thank you to Connor and to the contributors of SAELens for their contributions and for making their code public!

### Setup

In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch

torch.set_grad_enabled(False);

In [2]:
# Functions and classes are mostly imported next to the place they are used,
# so that it is clear where each name comes from.

# Pick the compute device: MPS if available, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


In [3]:
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Load GPT-2 small onto the device selected above.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


### Loading and pre-processing the SAE weights

In the block_diag_sae [branch](https://github.com/shehper/SAELens/tree/block_diag_sae) of SAELens used to train these models, an SAE did not have `W_dec` or `W_enc` as its attributes. Instead it had `dec_blocks` and `enc_blocks`, which were the Linear layers that acted on individual heads.

The code in this section defines a new SAE with the standard architecture whose encoder and decoder weights are block-diagonal. In the process, we
- normalize the dictionary vectors to have unit norm.
- rescale weights by `norm_scaling_factor` as the SAEs were trained with normalized activations. (Specifically we set `cfg.normalize_activations="expected_average_only_in"`). SAELens has a `fold_activation_norm_scaling_factor` function to fold the overall scaling factor into the weights of a trained SAE. 

In [18]:
ckpt_dir = "checkpoints/d543gzxy/final_204800000/"
# W&B: https://wandb.ai/shehper/gpt2-small-attn-4-sae/runs/pumu7rz3?nw=nwusershehper

# ckpt_dir = "checkpoints/5eu9598y/final_409600000/"
# # W&B: https://wandb.ai/shehper/gpt2-small-attn-5-sae/runs/s4om7ilc?nw=nwusershehper

# Load the SAE onto the device that was detected earlier (the same one the model
# uses) instead of forcing "cuda", which fails on machines without a CUDA GPU.
sae = SAE.load_from_pretrained(path=ckpt_dir, device=device)
# Keep the architecture name as stored in the checkpoint config.
orig_architecure = sae.cfg.architecture

# Attach the full encoder/decoder weights and biases as plain attributes.
sae.W_dec = sae.get_W_dec()
sae.b_dec = sae.get_b_dec()
sae.W_enc = sae.get_W_enc()
sae.b_enc = sae.get_b_enc()

# Give every dictionary (decoder) vector unit norm and move the removed scale
# into the encoder weights and bias.
# NOTE(review): the broadcasting assumes W_dec is (n_features, d_in) and
# W_enc is (d_in, n_features) — TODO confirm against the SAE class.
dec_norms = sae.W_dec.norm(dim=-1)
sae.W_enc *= dec_norms
sae.b_enc *= dec_norms
sae.W_dec /= dec_norms[:, None]

In [5]:
from sae_lens import ActivationsStore

# Build an activation store from the model and the SAE, in streaming mode.
# It is used further down to sample token batches and to estimate the
# activation norm scaling factor.
activations_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=8,
    n_batches_in_buffer=8,
)





In [6]:
from tqdm import tqdm 

def get_tokens(
    activations_store: ActivationsStore,
    n_batches_to_sample_from: int = 4096 * 6,
    n_prompts_to_select: int = 4096 * 6,
):
    """Draw a random sample of prompts from the activation store.

    Pulls ``n_batches_to_sample_from`` token batches from
    ``activations_store``, shuffles the rows of each batch, concatenates all
    batches along dimension 0, shuffles the pooled rows once more and returns
    the first ``n_prompts_to_select`` of them.
    """
    shuffled_batches = []
    for _ in tqdm(range(n_batches_to_sample_from)):
        batch = activations_store.get_batch_tokens()
        row_order = torch.randperm(batch.shape[0])
        shuffled_batches.append(batch[row_order])

    pooled = torch.cat(shuffled_batches, dim=0)
    pooled = pooled[torch.randperm(pooled.shape[0])]
    return pooled[:n_prompts_to_select]

# A small sample is plenty for a demo: draw 128 batches and keep 128 prompts.
token_dataset = get_tokens(
    activations_store,
    n_batches_to_sample_from=128,
    n_prompts_to_select=128,
)

  0%|          | 0/128 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 128/128 [00:03<00:00, 38.75it/s]


In [7]:
# Some SAEs require that we estimate the activation norm and fold it into the weights.
# This is easy with SAE Lens.
# `normalize_activations` holds a mode name (these SAEs were trained with
# "expected_average_only_in"), so compare against that name explicitly: a plain
# truthiness test passes for every non-empty string, including a mode that means
# "no normalisation", and would fold the factor again if this cell were re-run
# after the config has been updated.
if sae.cfg.normalize_activations == "expected_average_only_in":
    norm_scaling_factor = activations_store.estimate_norm_scaling_factor(n_batches_for_norm_estimate=30)
    sae.fold_activation_norm_scaling_factor(norm_scaling_factor)

Estimating norm scaling factor: 100%|██████████| 30/30 [00:03<00:00,  8.53it/s]


In [9]:
from collections import OrderedDict

# Collect the dense parameters that were attached to the loaded SAE, in the
# order W_dec, b_dec, W_enc, b_enc.
sae_sd = OrderedDict(
    (param_name, getattr(sae, param_name))
    for param_name in ("W_dec", "b_dec", "W_enc", "b_enc")
)

# Reuse the loaded config object, relabelled as a "standard" architecture.
new_cfg = sae.cfg
new_cfg.architecture = "standard"

# Instantiate an SAE from that config, load the collected parameters into it,
# and make `sae` refer to the new object from here on.
new_sae = SAE(cfg=new_cfg)
new_sae.load_state_dict(state_dict=sae_sd)
del sae
sae = new_sae

### Load Text Dataset

The code for loading the text dataset is borrowed from Connor Kissane's [attention-output-saes](https://github.com/ckkissane/attention-output-saes) repository.

The second cell will take ~5 minutes to run.

In [14]:
import einops

def get_batch_tokens(dataset_iter, batch_size, model, seq_len=1024):
    """Tokenize rows from ``dataset_iter`` and pack them into fixed-length sequences.

    Rows are pulled from the iterator and tokenized with ``model.to_tokens``
    until at least ``batch_size * seq_len`` tokens have been collected or the
    iterator is exhausted. The tokens are concatenated into one stream, cut
    into rows of exactly ``seq_len`` tokens, and the first token of every row
    is overwritten with the tokenizer's BOS id.

    Args:
        dataset_iter: Iterator yielding mappings with a ``"text"`` entry.
        batch_size: Maximum number of sequences to return.
        model: Object providing ``to_tokens(text)`` and
            ``tokenizer.bos_token_id``.
        seq_len: Length of every returned sequence. Defaults to 1024, the
            value that used to be hard-coded.

    Returns:
        A tensor of shape ``(n, seq_len)`` with ``n <= batch_size``, or
        ``None`` when the iterator did not provide enough tokens for a single
        sequence. ``n < batch_size`` only happens when the iterator runs out
        early; the incomplete trailing sequence is dropped in that case.
    """
    target_tokens = batch_size * seq_len
    tokens = []
    total_tokens = 0
    while total_tokens < target_tokens:
        try:
            row = next(dataset_iter)["text"]
        except StopIteration:
            # The dataset ended before the requested amount was collected.
            break

        cur_toks = model.to_tokens(row)
        tokens.append(cur_toks)
        total_tokens += cur_toks.numel()

    if not tokens:
        return None

    flat_tokens = torch.cat(tokens, dim=-1).flatten()

    # Fix the row length to seq_len and derive the number of rows from the data.
    # Reshaping with a fixed row count instead would fail, or silently produce
    # shorter rows, whenever fewer than batch_size * seq_len tokens exist.
    n_seqs = min(batch_size, flat_tokens.numel() // seq_len)
    if n_seqs == 0:
        return None

    reshaped_tokens = flat_tokens[: n_seqs * seq_len].reshape(n_seqs, seq_len)
    # Every packed sequence starts with the BOS token.
    reshaped_tokens[:, 0] = model.tokenizer.bos_token_id
    return reshaped_tokens

# def shuffle_data(all_tokens):
#     print("Shuffled data")
#     return all_tokens[torch.randperm(all_tokens.shape[0])]

In [15]:
dataset = load_dataset(
    path="Skylion007/openwebtext",
    split="train",
    streaming=True,
)


data_dir = "/home/ubuntu/storage/data"
os.makedirs(data_dir, exist_ok=True)
tokens_path = f"{data_dir}/owt_tokens_reshaped.pt"

dataset_iter = iter(dataset)
num_tokens = 5e7
seq_len = 1024
all_tokens_batches = int(num_tokens) // seq_len

# Decide explicitly whether a cached token file exists rather than catching
# every exception: a real failure (e.g. an unreadable cache file) now surfaces
# instead of silently triggering a fresh tokenization run.
if os.path.isfile(tokens_path):
    print("Loading cached data from disk")
    all_tokens = torch.load(tokens_path)
    # all_tokens = shuffle_data(all_tokens)
    print(all_tokens.shape)
else:
    print("Data was not cached: Loading data first time")
    all_tokens = get_batch_tokens(dataset_iter, all_tokens_batches, model)
    if all_tokens is None:
        # Do not write an unusable cache file when nothing was tokenized.
        raise RuntimeError("The dataset did not yield any tokens; nothing to cache.")
    torch.save(all_tokens, tokens_path)
    print("all_tokens.shape", all_tokens.shape)

Loading cached data from disk
torch.Size([48828, 1024])


### Feature Dashboards

In [16]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

# test_feature_idx_gpt = [2048 * i + j for i in [1, 5, 6] for j in range(10)]
# Ten consecutive feature indices starting at each of the first 12 multiples
# of 2048 (120 indices in total).
# NOTE(review): presumably one group of 2048 features per attention head — TODO confirm.
test_feature_idx_gpt = [
    group_start + offset
    for group_start in range(0, 12 * 2048, 2048)
    for offset in range(10)
]

hook_name = sae.cfg.hook_name

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_name,
    features=test_feature_idx_gpt,
    batch_size=2048,
    minibatch_size_tokens=32,
    verbose=True,
)

# Compute the dashboard data for the selected features over all cached tokens.
sae_vis_data_gpt = SaeVisData.create(
    encoder=sae,
    model=model,  # type: ignore
    tokens=all_tokens,  # type: ignore
    cfg=feature_vis_config_gpt,
)

Forward passes to cache data for vis:   0%|          | 0/64 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/120 [00:00<?, ?it/s]

In [17]:
dir_name = "jb_features"
os.makedirs(dir_name, exist_ok=True)
orig_architecure = "block_diag"
# Only the first selected feature is used: a single dashboard file is written.
# NOTE(review): the progress bar of one save call counts all 120 features, so one
# file presumably already contains every feature — TODO confirm.
for feature in test_feature_idx_gpt[:1]:
    filename = f"{dir_name}/{feature}_jb_{sae.cfg.hook_layer}_{orig_architecure}_owt_n_seqs_{all_tokens.shape[0]}_new.html"
    sae_vis_data_gpt.save_feature_centric_vis(filename, feature)

Saving feature-centric vis:   0%|          | 0/120 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/120 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/120 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/120 [00:00<?, ?it/s]

KeyboardInterrupt: 

### DFA by source position

In [None]:
# Keep only the first four sequences to reduce memory use on a small laptop.
# Note that this rebinds `all_tokens`, so every later cell sees just these rows.
all_tokens = all_tokens[:4].clone()

In [None]:
# Run the model on the tokens and cache its activations.
_, cache = model.run_with_cache(all_tokens)
# Read the layer from the SAE config (the same field used for the dashboard
# file name) instead of hard-coding it, so the analysis follows whichever
# checkpoint was loaded.
layer = sae.cfg.hook_layer
v = cache["v", layer] # (B, T, nh, dh)

In [None]:
# Attention pattern cached for the same layer as `v`.
attn_weights = cache["pattern", layer] # (B, nh, T, T)

In [None]:
# Multiply the attention pattern with the value vectors without summing over
# the second T axis, so that the contribution of each position stays separate.
# Shapes are the ones annotated by the original author.
pre_sum_Av = attn_weights[..., None] * v.transpose(1, 2)[:, :, None] # (B, nh, T, T, dh)
# Move the head axis next to the per-head axis and merge the two into one.
pre_sum_Av_cat = pre_sum_Av.permute(0, 2, 3, 1, 4).flatten(start_dim=-2) # (B, T, T, C)

In [None]:
feature_id = 0
# Project the un-summed attention output onto the encoder column of the chosen feature.
dfa_src = torch.matmul(pre_sum_Av_cat, sae.W_enc[:, feature_id]) # (B, T, T)

In [None]:
# To do this for more than one feature at a time, let feature_id be a list of feature ids;
# dfa_src will then have shape (B, T, T, H), where H is the number of features.

# Since this computation can be done for a batch of features, we could perhaps do it inside the _get_feature_data function of sae_vis.