In [1]:
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils

import torch
from datasets import load_dataset

from IPython.display import HTML, display

# Inference-only notebook: turn autograd off globally so no graphs are recorded.
torch.set_grad_enabled(mode=False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fcb11bde920>

In [2]:
# Use the GPU when one is available. The SAE built later is moved to this same device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pass the device explicitly so the model's placement follows `device`
# instead of depending on from_pretrained's own default.
model = HookedTransformer.from_pretrained('gpt2-xl', device=device)



Loaded pretrained model gpt2-xl into HookedTransformer


In [3]:
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens.training.sparse_autoencoder import SparseAutoencoder
from sae_lens.training.config import LanguageModelSAERunnerConfig

import inspect

# Layer-dependent settings are derived from these names so that hook_point
# (reused by later cells) and the config dict cannot drift apart.
hook_layer = 20
hook_point = f"blocks.{hook_layer}.hook_resid_pre"
bs = 64  # activation-store batch size; also used as n_batches_in_buffer below

d_in = 1600  # residual-stream width fed to the SAE (gpt2-xl)
expansion_factor = 32

cf = {
  "model_name": "gpt2-xl",
  "hook_point": hook_point,
  "hook_point_layer": hook_layer,
  "hook_point_head_index": None,
  "dataset_path": "Skylion007/openwebtext",
  "is_dataset_tokenized": False,
  "context_size": 128,
  "use_cached_activations": False,
  # NOTE(review): this path refers to gpt2-small / blocks.1 and looks stale;
  # presumably ignored while use_cached_activations is False — confirm.
  "cached_activations_path": "activations/Skylion007_openwebtext/gpt2-small/blocks.1.hook_resid_pre",
  "d_in": d_in,
  "n_batches_in_buffer": bs,
  "total_training_tokens": 300000000,
  "store_batch_size": bs,
  "device": device,
  "seed": 42,
  "dtype": "torch.float16",
  "b_dec_init_method": "geometric_median",
  "expansion_factor": expansion_factor,
  "from_pretrained_path": None,
  "l1_coefficient": 0.00008,
  "lr": 0.0004,
  "lr_scheduler_name": None,
  "lr_warm_up_steps": 5000,
  "train_batch_size": 4096,
  "use_ghost_grads": False,
  "feature_sampling_window": 1000,
  "feature_sampling_method": None,
  "resample_batches": 1028,
  "feature_reinit_scale": 0.2,
  "dead_feature_window": 5000,
  "dead_feature_estimation_method": "no_fire",
  "dead_feature_threshold": 1e-8,
  "log_to_wandb": True,
  # NOTE(review): project name mentions gpt2_small although the model is gpt2-xl.
  "wandb_project": "mats_sae_training_gpt2_small_resid_pre_5",
  "wandb_entity": None,
  "wandb_log_frequency": 100,
  "n_checkpoints": 10,
  "checkpoint_path": "checkpoints/mm179kd2",
  "d_sae": d_in * expansion_factor,
  "tokens_per_buffer": 128,
  "run_name": "24576-L1-8e-05-LR-0.0004-Tokens-3.000e+08"
}
config = cf
# Keep only the keys that LanguageModelSAERunnerConfig.__init__ accepts.
# inspect.signature lists exactly the parameters; __code__.co_varnames would
# also include any local variables defined inside __init__.
var_names = tuple(inspect.signature(LanguageModelSAERunnerConfig.__init__).parameters)
dropped_keys = [k for k in config if k not in var_names]
if dropped_keys:
    # Report ignored keys instead of dropping them silently. The printed run
    # name reports 2e6 training tokens rather than the 3e8 requested above,
    # which suggests "total_training_tokens" is one of them.
    print(f"Ignoring keys not accepted by LanguageModelSAERunnerConfig: {dropped_keys}")
config = {k: v for k, v in config.items() if k in var_names}
cfg = LanguageModelSAERunnerConfig(
    **config
)
sparse_autoencoder = SparseAutoencoder(cfg)
sparse_autoencoder.to(device)
sparse_autoencoder.cfg.device = device

layer = cfg.hook_point_layer


Run name: 51200-L1-8e-05-LR-0.0004-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.002048
Total training steps: 488
Total wandb updates: 4
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 2621.44
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 4.10e+06


In [4]:
from safetensors import safe_open
# Load the SAE weights from a local safetensors checkpoint and copy them into
# the SparseAutoencoder built in the config cell.
tensors = {}
with safe_open("gpt2-20.safetensors", framework="pt") as f:
    for k in f.keys():
        # Follow `device` rather than hardcoding .cuda(), so CPU-only hosts work too.
        tensors[k] = f.get_tensor(k).to(device)
# Fold the decoder bias into the encoder bias: b_enc += b_dec @ W_enc.
# NOTE(review): this presumably converts a checkpoint whose encoder does not
# subtract b_dec into an encoder that computes (x - b_dec) @ W_enc + b_enc.
# Confirm that convention against the checkpoint's source.
tensors["b_enc"] += tensors["b_dec"] @ tensors["W_enc"]
load_result = sparse_autoencoder.load_state_dict(tensors, strict=False)
# strict=False tolerates name mismatches silently; report them so a
# partially-loaded SAE is noticed.
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"SAE state_dict mismatch - missing: {load_result.missing_keys}, "
        f"unexpected: {load_result.unexpected_keys}"
    )
sparse_autoencoder.to(device)
sparse_autoencoder.cfg.device = device

In [5]:
def top_acts_at_pos(text, pos=-1, silent=True, prepend_bos=True, n_top=10):
    """Return the ``n_top`` strongest SAE feature activations for ``text``.

    Runs ``model`` on ``text``, reads the residual stream at ``hook_point``
    and encodes it with ``sparse_autoencoder``.

    Args:
        text: prompt passed to ``model.run_with_cache``.
        pos: token position to inspect. ``None`` averages the feature
            activations over every position, including the BOS token when
            ``prepend_bos`` is True.
        silent: unused; kept so existing calls keep working.
        prepend_bos: forwarded to ``model.run_with_cache``.
        n_top: number of features to return.

    Returns:
        ``(values, indices)`` of the top ``n_top`` (mean) feature activations.
    """
    # Only the SAE's input hook is needed; caching every hook wastes GPU memory.
    _, cache = model.run_with_cache(
        text, prepend_bos=prepend_bos, names_filter=hook_point
    )
    hidden_state = cache[hook_point][0]  # first batch element: [seq, d_model]
    if pos is not None:
        # Keep a leading position axis so the mean below is a no-op.
        hidden_state = hidden_state[pos].unsqueeze(0)
    feature_acts = sparse_autoencoder(hidden_state).feature_acts
    top_v, top_i = torch.topk(feature_acts.mean(dim=0), n_top)
    return top_v, top_i

# Strongest features at the final token of a multi-token prompt.
top_acts_at_pos(text="Anger something something", pos=-1)


(tensor([50.7812, 49.5312, 26.0781, 20.3281, 20.2812, 19.3594, 18.5000, 16.7812,
         16.3906, 13.1094], device='cuda:0', dtype=torch.float16),
 tensor([49395,   126,   409, 39324, 44188, 14643, 22759,  1104,  5883,  5913],
        device='cuda:0'))

In [6]:
# Top feature ids for "Anger": 126, 20811, 409, 44188, 4524, ...
top_acts_at_pos(text="Anger", pos=-1)

(tensor([65.2500, 43.9062, 26.2188, 23.8750, 23.5938, 21.7500, 19.1719, 16.3906,
         15.8984, 13.6094], device='cuda:0', dtype=torch.float16),
 tensor([  126, 20811,   409, 44188,  4524,  4364, 12006, 33085, 22759, 25116],
        device='cuda:0'))

In [7]:
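# Alternative token source (currently unused): tokenize a small HF dataset
# directly instead of sampling prompts from the activation store in the next cell.
# # data = load_dataset("NeelNanda/c4-code-20k", split="train")
# data = load_dataset("NeelNanda/pile-10k", split="train")

# tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
# tokenized_data = tokenized_data.shuffle(42)
# all_tokens = tokenized_data["tokens"]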

In [9]:
from tqdm.auto import tqdm

from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
import transformer_lens.loading_from_pretrained as tllfp
loader = LMSparseAutoencoderSessionloader(sparse_autoencoder.cfg)
# Only the activation store (third element of the session tuple) is needed here.
_, _, activation_store = loader.load_sae_training_group_session()


def get_tokens(
    activation_store,
    n_batches_to_sample_from: int = 2**13,
    n_prompts_to_select: int = 4096 * 6,
    seed: "int | None" = None,
):
    """Sample a random subset of tokenized prompts from an activation store.

    Pools ``n_batches_to_sample_from`` batches from
    ``activation_store.get_batch_tokens()``, shuffles the pooled rows once
    and returns the first ``n_prompts_to_select`` of them.

    Args:
        activation_store: object exposing ``get_batch_tokens()``, which is
            assumed to return a 2-D token tensor (presumably
            [batch, context_size] — confirm).
        n_batches_to_sample_from: number of batches to pool before sampling.
        n_prompts_to_select: number of prompts (rows) to return.
        seed: optional seed for the shuffle; ``None`` uses torch's global RNG.

    Returns:
        Tensor of up to ``n_prompts_to_select`` rows. It has fewer rows only
        when the pool is smaller, and a warning is printed in that case.

    Raises:
        ValueError: if ``n_batches_to_sample_from`` is not positive (there
            would be nothing to concatenate).
    """
    if n_batches_to_sample_from <= 0:
        raise ValueError("n_batches_to_sample_from must be positive")

    all_tokens_list = [
        activation_store.get_batch_tokens()
        for _ in tqdm(range(n_batches_to_sample_from), desc="sampling token batches")
    ]
    pool = torch.cat(all_tokens_list, dim=0)
    if pool.shape[0] < n_prompts_to_select:
        print(
            f"warning: only {pool.shape[0]} prompts available, "
            f"fewer than the {n_prompts_to_select} requested"
        )

    # One shuffle of the pooled rows yields a uniform random sample; shuffling
    # each batch beforehand would add nothing. Index with just the needed
    # prefix of the permutation instead of permuting the whole pool.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    perm = torch.randperm(pool.shape[0], generator=generator)
    return pool[perm[:n_prompts_to_select]]


all_tokens = get_tokens(activation_store)  # should take a few minutes



Loaded pretrained model gpt2-xl into HookedTransformer
Moving model to device:  cuda


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Run name: 51200-L1-8e-05-LR-0.0004-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.002048
Total training steps: 488
Total wandb updates: 4
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 2621.44
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 4.10e+06
Run name: 51200-L1-8e-05-LR-0.0004-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.002048
Total training steps: 488
Total wandb updates: 4
n_tokens_per_feature_sampling_window (millions): 524.288
n_tokens_per_dead_feature_window (millions): 2621.44
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 4.10e+06


  0%|          | 0/8192 [00:00<?, ?it/s]

In [10]:
# (number of prompts, context length)
all_tokens.size()

torch.Size([24576, 128])

In [11]:
# Is 126 the "anger" feature? It ranks first for "Anger", but it may simply
# be active on most text, so measure its firing rate on sampled prompts.
selected_feature = 126
n_prompts = 10
# One batched forward pass over the first n_prompts prompts, caching only the
# SAE's input hook (instead of n_prompts passes that each cache every hook).
logits, cache = model.run_with_cache(all_tokens[:n_prompts], names_filter=hook_point)
# Merge batch and position axes: presumably [n_prompts * seq, d_model].
hidden_state = cache[hook_point].flatten(0, 1)
feature_acts = sparse_autoencoder(hidden_state).feature_acts  # [n_prompts * seq, n_features]
selected_acts = feature_acts[:, selected_feature]

activation_count = (selected_acts > 0).sum().item()
total = selected_acts.shape[0]

print(activation_count/total)
# ~99% of tokens activate feature 126: it is a generic "any text" feature,
# not an anger feature.

0.98984375


In [12]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

# Build sae_vis data for the three strongest "Anger" features found above.
test_feature_idx_gpt = [126, 20811, 409]

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_point,
    features=test_feature_idx_gpt,
    minibatch_size_tokens=128,
    batch_size=bs,
    verbose=True,
)

# inference_mode: skip autograd bookkeeping while caching activations.
with torch.inference_mode():
    sae_vis_data_gpt = SaeVisData.create(
        cfg=feature_vis_config_gpt,
        tokens=all_tokens,  # type: ignore
        model=model,
        encoder=sparse_autoencoder,
    )

Forward passes to cache data for vis:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/3 [00:00<?, ?it/s]

In [13]:
import os
from IPython.display import FileLink

vis_dir = "feature_vis"
# exist_ok=True replaces the exists()/makedirs() check-then-create pair.
os.makedirs(vis_dir, exist_ok=True)

for idx, feature in enumerate(test_feature_idx_gpt):
    # A max of 0 presumably means the feature never fired on all_tokens, so
    # there is nothing to visualise. Report it instead of skipping silently.
    if sae_vis_data_gpt.feature_stats.max[idx] == 0:
        print(f"feature {feature} never activated; skipping vis")
        continue
    filename = os.path.join(vis_dir, f"{feature}_feature_vis.html")
    sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    display(FileLink(filename))

Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]

In [14]:
# Unrelated text: 126 and 409 still appear near the top.
top_acts_at_pos(text=" any text you like!", pos=-1)

(tensor([60.2188, 43.5312, 25.0625, 16.5625, 16.2969, 12.7969, 12.0156, 10.2422,
          9.4375,  9.1172], device='cuda:0', dtype=torch.float16),
 tensor([  126, 41137,   409, 28732, 44188, 22759,  5883, 43778, 35851, 18118],
        device='cuda:0'))

In [15]:
# Another oddity of this SAE: at position 0 some features fire with magnitudes
# in the hundreds, far above the tens seen at later positions.
top_acts_at_pos(text=" hello", pos=0)

(tensor([652.5000, 517.0000, 501.7500, 490.7500, 485.2500, 473.5000, 466.2500,
         450.2500, 415.7500, 377.7500], device='cuda:0', dtype=torch.float16),
 tensor([30958, 47390, 41028,  1212, 40249,  3370,  9507, 10284, 30590, 17048],
        device='cuda:0'))

In [16]:
# Same outliers without BOS, where position 0 is the " hello" token itself.
top_acts_at_pos(text=" hello", pos=0, prepend_bos=False)

(tensor([616.5000, 490.0000, 475.0000, 460.2500, 446.5000, 445.5000, 437.5000,
         420.0000, 388.2500, 376.2500], device='cuda:0', dtype=torch.float16),
 tensor([30958, 47390,  1212, 41028, 40249,  9507,  3370, 10284, 30590, 42322],
        device='cuda:0'))

In [17]:
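# Takeaway: this SAE looks poorly behaved. Feature 126 fires on ~99% of tokens,
# and position 0 produces activations in the hundreds (vs. tens elsewhere),
# with or without a BOS token.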

In [18]:
# Check that a reasonable "anger" feature can still be found: take the
# features active on "Anger" and remove those also active on "Calm".
angry_vals, angry_ids = top_acts_at_pos("Anger", pos=-1, n_top=100)
# NOTE(review): pos=None averages over every position, including BOS
# (prepend_bos defaults to True). The cells above show position 0 produces
# activations ~10x larger, so BOS features may dominate this top-200 ranking.
# Consider pos=-1 to match the "Anger" query.
calm_vals, calm_ids = top_acts_at_pos("Calm", pos=None, n_top=200)

angry_vals = angry_vals.tolist()
angry_ids = angry_ids.tolist()
calm_vals = calm_vals.tolist()
calm_ids = calm_ids.tolist()

# topk always returns n_top entries, so when fewer features are active the
# tail holds zero-valued ids. Drop them from both lists: they should neither
# count as "angry" features nor be used as the "calm" filter.
angry = [(v, i) for v, i in zip(angry_vals, angry_ids) if v > 0]
calm = [(v, i) for v, i in zip(calm_vals, calm_ids) if v > 0]
calm_set = {i for _, i in calm}

# Remove calm ids from angry.
angry = [(v, i) for v, i in angry if i not in calm_set]

In [19]:
# Ten strongest "Anger" features that are not also active on "Calm".
angry[0:10]

[(43.90625, 20811),
 (21.75, 4364),
 (16.390625, 33085),
 (11.4921875, 25473),
 (10.84375, 34590),
 (10.1796875, 11977),
 (8.78125, 21255),
 (8.1171875, 12346),
 (6.46875, 21981),
 (5.80859375, 37635)]

In [20]:
# Visualise the top "Anger" features that survive the "Calm" filter.
test_feature_idx_gpt = [20811, 4364, 33085]

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_point,
    features=test_feature_idx_gpt,
    minibatch_size_tokens=128,
    batch_size=bs,
    verbose=False,
)

with torch.inference_mode():
    sae_vis_data_gpt = SaeVisData.create(
        cfg=feature_vis_config_gpt,
        tokens=all_tokens,  # type: ignore
        model=model,
        encoder=sparse_autoencoder,
    )

# Write one HTML page per feature. A max of 0 presumably means the feature
# never fired on all_tokens, so those features are skipped.
for idx, feature in enumerate(test_feature_idx_gpt):
    if sae_vis_data_gpt.feature_stats.max[idx] != 0:
        filename = os.path.join(vis_dir, f"{feature}_feature_vis.html")
        sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
        display(FileLink(filename))

In [21]:
# 20811: angry / sad / irritated feature
# 4364: activates on words ending in "-er" (e.g. "-ner"); presumably why it fires on "Anger"
# 33085: violence / aggression feature