# Setup

In [1]:
from google.colab import drive
import shutil

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
try:
    #import google.colab # type: ignore
    #from google.colab import output
    %pip install sae-lens transformer-lens
except:
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Collecting sae-lens
  Downloading sae_lens-3.12.1-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.8/79.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformer-lens
  Downloading transformer_lens-2.2.2-py3-none-any.whl (174 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.3/174.3 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting automated-interpretability<0.0.4,>=0.0.3 (from sae-lens)
  Downloading automated_interpretability-0.0.3-py3-none-any.whl (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.8/56.8 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl (6.9 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m50.9 MB/s[0m eta [3

In [3]:
import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


In [4]:
from torch import nn, Tensor
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple

# load model

In [5]:
from transformer_lens import HookedTransformer

In [6]:
model = HookedTransformer.from_pretrained("tiny-stories-1L-21M")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/269M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


# load sae using saelens

## modify params here

To get feature actvs, use sae_lens class method. To do this, you must load the sae as the sae class (wrapper over torch model).

In [None]:
sae_layer = "blocks.0.hook_mlp_out"

In [None]:
model_name = "tiny-stories-1L-21M"
layer_name = "blocks.0.hook_mlp_out"
hook_layer = 0
d_in = 1024
expa_fac = 8
wandb_project = "sae_" + model_name+"_MLP" + str(hook_layer) + "_df-" + str(d_in * expa_fac)

In [None]:
total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name=model_name,  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name=layer_name,  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=hook_layer,  # Only one layer in the model.
    d_in=d_in,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor= expa_fac,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder weights to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would effect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would effect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project=wandb_project,
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32"
)
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()

Run name: 8192-L1-5-LR-5e-05-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 2097.152
n_tokens_per_dead_feature_window (millions): 2097.152
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/269M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


tokenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Loaded pretrained model tiny-stories-1L-21M into HookedTransformer


Downloading readme:   0%|          | 0.00/415 [00:00<?, ?B/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Training SAE:   0%|          | 0/122880000 [00:00<?, ?it/s]
Estimating norm scaling factor:   0%|          | 0/1000 [00:00<?, ?it/s][A
Estimating norm scaling factor:   0%|          | 1/1000 [00:05<1:25:52,  5.16s/it][A
Estimating norm scaling factor:   1%|          | 9/1000 [00:05<07:03,  2.34it/s]  [A
Estimating norm scaling factor:   2%|▏         | 17/1000 [00:05<03:09,  5.19it/s][A
Estimating norm scaling factor:   2%|▏         | 23/1000 [00:05<02:13,  7.33it/s][A
Estimating norm scaling factor:   3%|▎         | 31/1000 [00:05<01:22, 11.76it/s][A
Estimating norm scaling factor:   4%|▍         | 39/1000 [00:05<00:56, 17.15it/s][A
Estimating norm scaling factor:   5%|▍         | 47/1000 [00:05<00:40, 23.38it/s][A
Estimating norm scaling factor:   5%|▌         | 54/1000 [00:06<00:32, 28.76it/s][A
Estimating norm scaling factor:   6%|▌         | 62/1000 [00:06<00:26, 35.80it/s][A
Estimating norm scaling factor:   7%|▋         | 69/1000 [00:08<01:49,  8.49it/s][A
Estimating n

VBox(children=(Label(value='64.071 MB of 64.086 MB uploaded\r'), FloatProgress(value=0.999773537403643, max=1.…

0,1
details/current_l1_coefficient,▁▅██████████████████████████████████████
details/current_learning_rate,████████████████████████████████▇▇▅▅▄▃▂▁
details/n_training_tokens,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/auxiliary_reconstruction_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,▅▅█▇▆▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,▁▅█▇▇▆▆▆▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
metrics/CE_loss_score,█▂▁▂▃▃▃▄▄▄▄▄▄▅▄▅▅▄▄▄▅▅▅▅▅▅▅▅▅▆▅▅▅▅▆▆▆▅▅▆
metrics/ce_loss_with_ablation,▃▂▃▁▅▄▃▄▂▄▂▂▄▅▂█▃▃▂▃▆▄▄▃▄▂▂▁▅▆▇▃▄▄▅▄▁▄▃▄

0,1
details/current_l1_coefficient,5.0
details/current_learning_rate,0.0
details/n_training_tokens,122880000.0
losses/auxiliary_reconstruction_loss,0.0
losses/ghost_grad_loss,0.0
losses/l1_loss,31.42966
losses/mse_loss,218.08939
losses/overall_loss,375.23767
metrics/CE_loss_score,0.88816
metrics/ce_loss_with_ablation,8.21341


In [None]:
# https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L13
from sae_lens import SAEConfig

cfg_dict = cfg.to_dict()
cfg_dict['finetuning_scaling_factor'] = 0

sae_cfg = SAEConfig.from_dict(cfg_dict)
# sae = cls(sae_cfg)

In [None]:
from sae_lens import SAE
sae = SAE(sae_cfg)
# sae.cfg = sae_cfg

In [None]:
state_dict = torch.load('/content/drive/MyDrive/tiny-stories-1L-21M_sae_v1.pth')
sae.load_state_dict(state_dict)

<All keys matched successfully>

# load dataset

Need load model tokenizer and sae params before obtain dataset

In [None]:
from datasets import load_dataset

In [None]:
# from transformer_lens.utils import tokenize_and_concatenate

# dataset = load_dataset("roneneldan/TinyStories", streaming=False)
# test_dataset = dataset['validation']

# token_dataset = tokenize_and_concatenate(
#     dataset = test_dataset,
#     tokenizer = model.tokenizer, # type: ignore
#     streaming=True,
#     max_length=sae.cfg.context_size,
#     # add_bos_token=sae.cfg.prepend_bos,
#     add_bos_token=False,
# )

Downloading readme:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data:   0%|          | 0.00/249M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

Map:   0%|          | 0/21990 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (10434 > 2048). Running this sequence through the model will result in indexing errors


In [None]:
# sae.cfg.prepend_bos

True

In [None]:
# batch_tokens = token_dataset[:100]["tokens"]
# batch_tokens.shape

torch.Size([100, 512])

## only get samples with specific concepts/words

In [None]:
# Load the dataset in streaming mode
# dataset = load_dataset("Skylion007/openwebtext", split="train", streaming=True)
dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)

# Define the maximum sequence length for the model
max_length = 128

# Function to check if text contains the words "love" or "hate"
def contains_love_or_hate(text):
    # return "love" in text.lower() or "hate" in text.lower()
    return "she" in text.lower() or "her" in text.lower()

# Define a function to get tokens in batches with truncation and padding
def get_batch_tokens(dataset, tokenizer, batch_size=32, max_length=128):
    sequences = []
    love_hate_sequences = []
    other_sequences = []
    iterator = iter(dataset)  # Create an iterator from the streamed dataset

    # Separate sequences into those containing "love" or "hate" and those that do not
    for _ in range(batch_size * 2):  # Load more to ensure we get enough samples
        try:
            # Get the next example from the dataset
            example = next(iterator)
            text = example['text']
            if contains_love_or_hate(text):
                tokens = tokenizer.encode(text, max_length=max_length, truncation=True, padding="max_length", return_tensors='pt')
                love_hate_sequences.append(tokens)
            else:
                tokens = tokenizer.encode(text, max_length=max_length, truncation=True, padding="max_length", return_tensors='pt')
                other_sequences.append(tokens)
        except StopIteration:
            # If the dataset ends before reaching the required amount
            break

    # Ensure we have enough samples of each type
    # min_length = min(len(love_hate_sequences), len(other_sequences))
    # love_hate_sequences = love_hate_sequences[:min_length]
    # other_sequences = other_sequences[:min_length]

    others_len = batch_size - len(love_hate_sequences)
    other_sequences = other_sequences[:others_len]

    # Combine sequences to form the batch
    # sequences = love_hate_sequences[:batch_size//2] + other_sequences[:batch_size//2]
    sequences = love_hate_sequences + other_sequences

    # pdb.set_trace()

    if sequences:
        batch_tokens = torch.cat(sequences, dim=0).squeeze(1)
        return batch_tokens
    else:
        return None

# Get a batch of tokens
batch_tokens = get_batch_tokens(dataset, model.tokenizer, batch_size=100, max_length=max_length)
if batch_tokens is not None:
    print(batch_tokens.shape)
else:
    print("No data to load.")

Downloading readme:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


torch.Size([195, 128])


# model 1- interpret features

## get LLM actvs

In [None]:
layer_name = sae_layer

In [None]:
h_store = torch.zeros((batch_tokens.shape[0], batch_tokens.shape[1], model.cfg.d_model), device=model.cfg.device)
h_store.shape

torch.Size([195, 128, 1024])

In [None]:
def store_h_hook(
    pattern: Float[Tensor, "batch seqlen d_model"],
    hook
):
    h_store[:] = pattern  # this works b/c changes values, not replaces entire thing

In [None]:
model.run_with_hooks(
    batch_tokens,
    return_type = None,
    fwd_hooks=[
        (layer_name, store_h_hook),
    ]
)

## get SAE actvs

In [None]:
sae.eval()  # prevents error if we're expecting a dead neuron mask for who grads
with torch.no_grad():
    feature_acts = sae.encode(h_store)

In [None]:
def count_nonzero_features(feature_acts):
    # Count the number of 0s in the tensor
    num_zeros = (feature_acts == 0).sum().item()

    # Count the number of nonzeroes in the tensor
    num_ones = (feature_acts > 0).sum().item()

    # Calculate the percentage of 1s over 0s
    if num_zeros > 0:
        perc_ones_over_total = (num_ones / (num_ones + num_zeros)) * 100
    else:
        perc_ones_over_total = float('inf')  # Handle division by zero
    return perc_ones_over_total
count_nonzero_features(feature_acts)

0.1506350590632512

## save actvs

Now you have to save actvs, bc saelens not compatible with umap OR cosine sim lib

In [None]:
import pickle

In [None]:
with open('fActs_ts_1L_21M_sheHer.pkl', 'wb') as f:
    pickle.dump(feature_acts, f)

In [None]:
!cp fActs_ts_1L_21M_sheHer.pkl /content/drive/MyDrive/

In [None]:
# check if saved
file_path = '/content/drive/MyDrive/fActs_ts_1L_21M_sheHer.pkl'
with open(file_path, 'rb') as f:
    feature_acts = pickle.load(f)

## get top features

In [None]:
# Get the top k largest activations for feature neurons, not batch seq. use , dim=-1
# if want to get highest batch, use dim=0
feat_k = 15
top_acts_values, top_acts_indices = feature_acts.topk(feat_k, dim=-1)

print(top_acts_indices.shape)
top_acts_values.shape

torch.Size([195, 128, 15])


torch.Size([195, 128, 15])

## interpret top features by dataset examples

In [None]:
def highest_activating_tokens(
    feature_acts,
    feature_idx: int,
    k: int = 10,  # num batch_seq samples
    batch_tokens=None
) -> Tuple[Int[Tensor, "k 2"], Float[Tensor, "k"]]:
    '''
    Returns the indices & values for the highest-activating tokens in the given batch of data.
    '''
    batch_size, seq_len = batch_tokens.shape

    # Get the top k largest activations for only targeted feature
    # need to flatten (batch,seq) into batch*seq first because it's ANY batch_seq, even if in same batch or same pos
    flattened_feature_acts = feature_acts[:, :, feature_idx].reshape(-1)

    top_acts_values, top_acts_indices = flattened_feature_acts.topk(k)
    # top_acts_values should be 1D
    # top_acts_indices should be also be 1D. Now, turn it back to 2D
    # Convert the indices into (batch, seq) indices
    top_acts_batch = top_acts_indices // seq_len
    top_acts_seq = top_acts_indices % seq_len

    return torch.stack([top_acts_batch, top_acts_seq], dim=-1), top_acts_values

In [None]:
from rich import print as rprint
def display_top_sequences(top_acts_indices, top_acts_values, batch_tokens):
    s = ""
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        # s += f'{batch_idx}\n'
        s += f'batchID: {batch_idx}, '
        # Get the sequence as a string (with some padding on either side of our sequence)
        seq_start = max(seq_idx - 5, 0)
        # seq_end = min(seq_idx + 5, all_tokens.shape[1])
        seq_end = min(seq_idx + 5, batch_tokens.shape[1])
        seq = ""
        # Loop over the sequence, adding each token to the string (highlighting the token with the large activations)
        for i in range(seq_start, seq_end):
            # new_str_token = model.to_single_str_token(tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            new_str_token = model.to_single_str_token(batch_tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            if i == seq_idx:
                new_str_token = f"[bold u dark_orange]{new_str_token}[/]"
            seq += new_str_token
        # Print the sequence, and the activation value
        s += f'Act = {value:.2f}, Seq = "{seq}"\n'

    rprint(s)

In [None]:
# get top samp_m tokens for all top feat_k feature neurons
samp_m = 5

# top features in matching pair with model B
# top_feats = [3383, 8341]

for feature_idx in top_acts_indices[0, -1, :]:
# for feature_idx in top_feats:
    feature_idx = feature_idx.item()
    print('Feature: ', feature_idx)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts, feature_idx, samp_m, batch_tokens=batch_tokens)
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=batch_tokens)

Feature:  16244


Feature:  16220


Feature:  2674


Feature:  9328


Feature:  8042


Feature:  9887


Feature:  2136


Feature:  10850


Feature:  1052


Feature:  4699


Feature:  5451


Feature:  5536


Feature:  1


Feature:  0


Feature:  2


# model 2- interpret features

## load model and sae

In [None]:
model_2 = HookedTransformer.from_pretrained("tiny-stories-2L-33M")



config.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/323M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [None]:
sae_layer = "blocks.0.hook_mlp_out"

model_name = "tiny-stories-2L-33M"
layer_name = "blocks.1.hook_mlp_out"
hook_layer = 1
d_in = 1024
wandb_project = model_name+"_MLP1_sae"

In [None]:
total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name=model_name,  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name=layer_name,  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=hook_layer,  # Only one layer in the model.
    d_in=d_in,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder weights to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would effect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would effect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project=wandb_project,
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32"
)

Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 30000
Total wandb updates: 1000
n_tokens_per_feature_sampling_window (millions): 2097.152
n_tokens_per_dead_feature_window (millions): 2097.152
We will reset the sparsity calculation 30 times.
Number tokens in sparsity calculation window: 4.10e+06


In [None]:
# https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L13
from sae_lens import SAEConfig

cfg_dict = cfg.to_dict()
cfg_dict['finetuning_scaling_factor'] = 0

sae_cfg = SAEConfig.from_dict(cfg_dict)

In [None]:
from sae_lens import SAE
sae_2 = SAE(sae_cfg)

In [None]:
state_dict = torch.load('/content/drive/MyDrive/tiny-stories-2L-33M_MLP0_sae.pth')
sae_2.load_state_dict(state_dict)

<All keys matched successfully>

## get LLM actvs

In [None]:
layer_name = sae_layer

In [None]:
h_store = torch.zeros((batch_tokens.shape[0], batch_tokens.shape[1], model.cfg.d_model), device=model.cfg.device)
h_store.shape

torch.Size([195, 128, 1024])

In [None]:
def store_h_hook(
    pattern: Float[Tensor, "batch seqlen d_model"],
    hook
):
    h_store[:] = pattern  # this works b/c changes values, not replaces entire thing

In [None]:
model_2.run_with_hooks(
    batch_tokens,
    return_type = None,
    fwd_hooks=[
        (layer_name, store_h_hook),
    ]
)

## get SAE actvs

In [None]:
sae_2.eval()  # prevents error if we're expecting a dead neuron mask for who grads
with torch.no_grad():
    feature_acts_2 = sae.encode(h_store)

In [None]:
def count_nonzero_features(feature_acts):
    # Count the number of 0s in the tensor
    num_zeros = (feature_acts == 0).sum().item()

    # Count the number of nonzeroes in the tensor
    num_ones = (feature_acts > 0).sum().item()

    # Calculate the percentage of 1s over 0s
    if num_zeros > 0:
        perc_ones_over_total = (num_ones / (num_ones + num_zeros)) * 100
    else:
        perc_ones_over_total = float('inf')  # Handle division by zero
    return perc_ones_over_total
count_nonzero_features(feature_acts_2)

35.92024705348871

## save actvs

Now you have to save actvs, bc saelens not compatible with umap OR cosine sim lib

In [None]:
with open('fActs_ts_2L_33M_sheHer.pkl', 'wb') as f:
    pickle.dump(feature_acts_2, f)

In [None]:
!cp fActs_ts_2L_33M_sheHer.pkl /content/drive/MyDrive/

In [None]:
# check if saved
file_path = '/content/drive/MyDrive/fActs_ts_2L_33M_sheHer.pkl'
with open(file_path, 'rb') as f:
    feature_acts_2 = pickle.load(f)

## get top features

In [None]:
# Get the top k largest activations for feature neurons, not batch seq. use , dim=-1
# if want to get highest batch, use dim=0
feat_k = 15
top_acts_values, top_acts_indices = feature_acts_2.topk(feat_k, dim=-1)

print(top_acts_indices.shape)
top_acts_values.shape

torch.Size([195, 128, 15])


torch.Size([195, 128, 15])

## interpret top features by dataset examples

In [None]:
def highest_activating_tokens(
    feature_acts,
    feature_idx: int,
    k: int = 10,  # num batch_seq samples
    batch_tokens=None
) -> Tuple[Int[Tensor, "k 2"], Float[Tensor, "k"]]:
    '''
    Returns the indices & values for the highest-activating tokens in the given batch of data.
    '''
    batch_size, seq_len = batch_tokens.shape

    # Get the top k largest activations for only targeted feature
    # need to flatten (batch,seq) into batch*seq first because it's ANY batch_seq, even if in same batch or same pos
    flattened_feature_acts = feature_acts[:, :, feature_idx].reshape(-1)

    top_acts_values, top_acts_indices = flattened_feature_acts.topk(k)
    # top_acts_values should be 1D
    # top_acts_indices should be also be 1D. Now, turn it back to 2D
    # Convert the indices into (batch, seq) indices
    top_acts_batch = top_acts_indices // seq_len
    top_acts_seq = top_acts_indices % seq_len

    return torch.stack([top_acts_batch, top_acts_seq], dim=-1), top_acts_values

In [None]:
from rich import print as rprint
def display_top_sequences(top_acts_indices, top_acts_values, batch_tokens):
    s = ""
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        # s += f'{batch_idx}\n'
        s += f'batchID: {batch_idx}, '
        # Get the sequence as a string (with some padding on either side of our sequence)
        seq_start = max(seq_idx - 5, 0)
        # seq_end = min(seq_idx + 5, all_tokens.shape[1])
        seq_end = min(seq_idx + 5, batch_tokens.shape[1])
        seq = ""
        # Loop over the sequence, adding each token to the string (highlighting the token with the large activations)
        for i in range(seq_start, seq_end):
            # new_str_token = model.to_single_str_token(tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            new_str_token = model.to_single_str_token(batch_tokens[batch_idx, i].item()).replace("\n", "\\n").replace("<|BOS|>", "|BOS|")
            if i == seq_idx:
                new_str_token = f"[bold u dark_orange]{new_str_token}[/]"
            seq += new_str_token
        # Print the sequence, and the activation value
        s += f'Act = {value:.2f}, Seq = "{seq}"\n'

    rprint(s)

In [None]:
# get top samp_m tokens for all top feat_k feature neurons
samp_m = 5

for feature_idx in top_acts_indices[0, -1, :]:
    feature_idx = feature_idx.item()
    print('Feature: ', feature_idx)
    ds_top_acts_indices, ds_top_acts_values = highest_activating_tokens(feature_acts_2, feature_idx, samp_m, batch_tokens=batch_tokens)
    display_top_sequences(ds_top_acts_indices, ds_top_acts_values, batch_tokens=batch_tokens)

Feature:  6469


Feature:  761


Feature:  7208


Feature:  15880


Feature:  9765


Feature:  707


Feature:  9139


Feature:  12314


Feature:  3103


Feature:  250


Feature:  10152


Feature:  11183


Feature:  5833


Feature:  5553


Feature:  12665
