In [103]:
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils

import torch
from datasets import load_dataset

from IPython.display import HTML, display

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HookedTransformer.from_pretrained('gpt2-xl')



Loaded pretrained model gpt2-xl into HookedTransformer


In [72]:
from sae_lens.training.config import LanguageModelSAERunnerConfig

hook_point = "blocks.20.hook_resid_pre"
bs = 64

cf = {
  "model_name": "gpt2-xl",
  "hook_point": "blocks.20.hook_resid_pre",
  "hook_point_layer": 20,
  "hook_point_head_index": None,
  "dataset_path": "Skylion007/openwebtext",
  "is_dataset_tokenized": False,
  "context_size": 128,
  "use_cached_activations": False,
  "cached_activations_path": "activations/Skylion007_openwebtext/gpt2-small/blocks.1.hook_resid_pre",
  "d_in": 1600,
  "n_batches_in_buffer": bs,
  "total_training_tokens": 300000000,
  "store_batch_size": bs,
  "device": device,
  "seed": 42,
  "dtype": "torch.float16",
  "b_dec_init_method": "geometric_median",
  "expansion_factor": 32,
  "from_pretrained_path": None,
  "l1_coefficient": 0.00008,
  "lr": 0.0004,
  "lr_scheduler_name": None,
  "lr_warm_up_steps": 5000,
  "train_batch_size": 4096,
  "use_ghost_grads": False,
  "feature_sampling_window": 1000,
  "feature_sampling_method": None,
  "resample_batches": 1028,
  "feature_reinit_scale": 0.2,
  "dead_feature_window": 5000,
  "dead_feature_estimation_method": "no_fire",
  "dead_feature_threshold": 1e-8,
  "log_to_wandb": True,
  "wandb_project": "mats_sae_training_gpt2_small_resid_pre_5",
  "wandb_entity": None,
  "wandb_log_frequency": 100,
  "n_checkpoints": 10,
  "checkpoint_path": "checkpoints/mm179kd2",
  "d_sae": 1600*32,
  "tokens_per_buffer": 128,
  "run_name": "24576-L1-8e-05-LR-0.0004-Tokens-3.000e+08"
}
config = cf
var_names = LanguageModelSAERunnerConfig.__init__.__code__.co_varnames
config = {k: v for k, v in config.items() if k in var_names}
cfg = LanguageModelSAERunnerConfig(
    **config
)
sparse_autoencoder = SparseAutoencoder(cfg)
sparse_autoencoder.to(device)
sparse_autoencoder.cfg.device = device

layer = cfg.hook_point_layer


Run name: 51200-L1-8e-05-LR-0.0004-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.002048
Total training steps: 488
Total wandb updates: 4
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 2621.44
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 4.10e+06


In [73]:
from safetensors import safe_open
tensors = {}
with safe_open("gpt2-20.safetensors", framework="pt") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k).cuda()
        # if k == "b_enc":
        #     tensors[k] -= 0.18
tensors["b_enc"] += tensors["b_dec"] @ tensors["W_enc"]
sparse_autoencoder.load_state_dict(tensors, strict=False)
sparse_autoencoder.to(device)
sparse_autoencoder.cfg.device = device

In [101]:
def top_acts_at_pos(text, pos=-1, silent=True, prepend_bos=True):
    logits, cache = model.run_with_cache(text, prepend_bos=prepend_bos)
    if pos is None:
        hidden_state = cache[hook_point][0, :, :]
    else:
        hidden_state = cache[hook_point][0, pos, :].unsqueeze(0)
    feature_acts = sparse_autoencoder(hidden_state).feature_acts
    print(feature_acts.shape)
    feature_acts = feature_acts.mean(dim=0)
    top_v, top_i = torch.topk(feature_acts, 10)
    return top_v, top_i

top_acts_at_pos("Anger something something", pos=-1)


torch.Size([1, 51200])


(tensor([55.3438, 49.4688, 25.2812, 22.7031, 22.5625, 21.6250, 20.6094, 19.8750,
         16.4375, 15.4297], device='cuda:0', dtype=torch.float16,
        grad_fn=<TopkBackward0>),
 tensor([49395,   126,   409, 14643, 44188, 39324,  1104, 22759,  5883,  5913],
        device='cuda:0'))

In [106]:
top_acts_at_pos("Anger", pos=-1) # [126, 20811, 4524 ...]

torch.Size([1, 51200])


(tensor([65.1875, 47.1250, 26.8281, 26.1562, 25.6562, 25.4219, 22.2188, 19.3125,
         18.8438, 17.2656], device='cuda:0', dtype=torch.float16,
        grad_fn=<TopkBackward0>),
 tensor([  126, 20811,  4524, 44188,  4364,   409, 12006, 33085, 25116, 22759],
        device='cuda:0'))

In [108]:
data = load_dataset("NeelNanda/c4-code-20k", split="train")

tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]

In [109]:
all_tokens.shape

torch.Size([325017, 128])

In [112]:
# okay so 126 is the anger feature right? Wrong!
# 126 activates on most text
selected_feature = 126
activation_count = 0
total = 0
for i in range(10):
    logits, cache = model.run_with_cache(all_tokens[i])
    hidden_state = cache[hook_point][0]
    feature_acts = sparse_autoencoder(hidden_state).feature_acts # shape [128, n_features]
    selected_acts = feature_acts[:, selected_feature]

    activation_count += (selected_acts > 0).sum().item()
    total += selected_acts.shape[0]

print(activation_count/total)
# this feature activates on 99% of all tokens!!!

0.9921875


In [82]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

test_feature_idx_gpt = [126, 20811, 409]

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_point,
    features=test_feature_idx_gpt,
    batch_size=bs,
    minibatch_size_tokens=128,
    verbose=True,
)

with torch.inference_mode():
    sae_vis_data_gpt = SaeVisData.create(
        encoder=sparse_autoencoder,
        model=model,
        tokens=all_tokens,  # type: ignore
        cfg=feature_vis_config_gpt,
    )

Forward passes to cache data for vis:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/3 [00:00<?, ?it/s]

In [105]:
import os
from IPython.display import FileLink

vis_dir = "feature_vis"
if not os.path.exists(vis_dir):
    os.makedirs(vis_dir)

for idx, feature in enumerate(test_feature_idx_gpt):
    if sae_vis_data_gpt.feature_stats.max[idx] == 0:
        continue
    filename = os.path.join(vis_dir, f"{feature}_feature_vis.html")
    sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    display(FileLink(filename))

Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]

In [87]:
top_acts_at_pos(" any text you like!", pos=-1)

(tensor([60.1562, 48.0938, 24.2656, 20.2188, 18.5781, 14.2656, 14.1641, 12.8672,
         12.0781, 11.9766], device='cuda:0', dtype=torch.float16,
        grad_fn=<TopkBackward0>),
 tensor([  126, 41137,   409, 28732, 44188, 35851, 22759, 43778,  5883, 18118],
        device='cuda:0'))

In [113]:
# also, another crazy thing about this SAE:
# the 0th token causes some features to activate with crazy high magnitude
top_acts_at_pos(" hello", pos=0)

torch.Size([1, 51200])


(tensor([656.0000, 518.5000, 505.7500, 494.5000, 487.0000, 478.0000, 468.7500,
         453.5000, 419.0000, 385.0000], device='cuda:0', dtype=torch.float16,
        grad_fn=<TopkBackward0>),
 tensor([30958, 47390, 41028,  1212, 40249,  3370,  9507, 10284, 30590, 42322],
        device='cuda:0'))

In [114]:
# same thing with and without BOS
top_acts_at_pos(" hello", pos=0, prepend_bos=False)

torch.Size([1, 51200])


(tensor([620.0000, 491.2500, 478.7500, 464.2500, 448.2500, 448.0000, 442.0000,
         423.2500, 391.5000, 383.7500], device='cuda:0', dtype=torch.float16,
        grad_fn=<TopkBackward0>),
 tensor([30958, 47390,  1212, 41028, 40249,  9507,  3370, 10284, 30590, 42322],
        device='cuda:0'))

In [None]:
# my takeaway is this is a bad SAE.