# Setup

In [1]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

import os, sys
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"

if IN_COLAB:
    # Install packages
    %pip install jaxtyping
    %pip install transformer_lens sae_lens
    %pip install git+https://github.com/callummcdougall/eindex.git

    # Code to download the necessary files (e.g. solutions, test funcs)
    if not os.path.exists(f"/content/{chapter}"):
        !wget https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/main.zip
        !unzip /content/main.zip 'ARENA_3.0-main/chapter1_transformer_interp/exercises/*'
        sys.path.append(f"/content/{repo}-main/{chapter}/exercises")
        os.remove("/content/main.zip")
        os.rename(f"{repo}-main/{chapter}", chapter)
        os.rmdir(f"{repo}-main")
        os.chdir(f"{chapter}/exercises")
else:
    chapter_dir = r"./" if chapter in os.listdir() else os.getcwd().split(chapter)[0]
    sys.path.append(chapter_dir + f"{chapter}/exercises")

Collecting sae-lens
  Downloading sae_lens-3.12.0-py3-none-any.whl.metadata (4.9 kB)
Collecting automated-interpretability<0.0.4,>=0.0.3 (from sae-lens)
  Downloading automated_interpretability-0.0.3-py3-none-any.whl.metadata (817 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Using cached datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting plotly<6.0.0,>=5.19.0 (from sae-lens)
  Downloading plotly-5.22.0-py3-none-any.whl.metadata (7.1 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 (from sae-lens)
  Downloading pytest_profiling-1.7.0-py2.py3-none-any.whl.metadata (12 kB)
Collecting python-dotenv<2.0.0,>=1.0.1 (from sae-lens)
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Collecting pyzmq==26.0.0 (from sae-lens)
  

In [37]:
import torch

import transformer_lens
from transformer_lens import HookedTransformer

import sae_lens
from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner

import circuitsvis as cv

from datasets import load_dataset

In [1]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Model 1 (Baseline)

In [31]:
# Load a model (BASELINE)

model1 = transformer_lens.HookedTransformer.from_pretrained(
    "gpt2-small", device=device
)  # This will wrap huggingface models and has lots of nice utilities.

# gelu_1l = HookedTransformer.from_pretrained("gelu-1l", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


## Model 1 (Baseline) Basic Information Gathering

| Parameter | Value |
|-----------|-------|
| Model Name | gpt2-small |
| Number of Parameters | 85M |
| Number of Layers | 12 |
| Model Dimension (d_model) | 768 |
| Number of Attention Heads | 12 |
| Activation Function | GELU |
| Context Window Size | 1024 |
| Vocabulary Size | 50257 |
| Dimension per Head | 64 |
| MLP Dimension | 3072 |

We use `.cfg` to find the basic architectural info about the model:

In [32]:
display(model1.cfg)

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': 'mps',
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'num_experts': None,
 'original_architecture': 'GPT2LMHeadModel',
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'positional_embedding_

In [33]:
# Print model information
print(f"Model name: gpt2-small")
print(f"Number of parameters: {model1.cfg.n_params}")
print(f"Number of layers: {model1.cfg.n_layers}")
print(f"Model dimension: {model1.cfg.d_model}")
print(f"Number of heads: {model1.cfg.n_heads}")
print(f"Activation function: {model1.cfg.act_fn}")
print(f"Context window size: {model1.cfg.n_ctx}")
print(f"Vocabulary size: {model1.cfg.d_vocab}")
print(f"Dimension per head: {model1.cfg.d_head}")
print(f"MLP dimension: {model1.cfg.d_mlp}")

# Demonstrate next token prediction
prompt = "The quick brown fox"
input_ids = model1.to_tokens(prompt)
logits = model1(input_ids)
next_token_id = torch.argmax(logits[0, -1]).item()
next_token = model1.to_string(next_token_id)
print(f"\nPrompt: '{prompt}'")
print(f"Predicted next token: '{next_token}'")

# Calculate loss
full_prompt = "The quick brown fox jumps over the lazy dog"
tokens = model1.to_tokens(full_prompt)
logits = model1(tokens)
loss = model1.loss_fn(logits[:, :-1, :], tokens[:, 1:])
print(f"\nLoss: {loss.item()}")

Model name: gpt2-small
Number of parameters: 84934656
Number of layers: 12
Model dimension: 768
Number of heads: 12
Activation function: gelu_new
Context window size: 1024
Vocabulary size: 50257
Dimension per head: 64
MLP dimension: 3072



Prompt: 'The quick brown fox'
Predicted next token: ' is'

Loss: 8.791129112243652


In [35]:
# here we use generate to get 10 completions with temperature 1
for i in range(5):
    display(
        model1.generate(
            "The doctor was a",
            stop_at_eos=False,  # avoids a bug on MPS
            temperature=1,
            verbose=False,
            max_new_tokens=50,
        )
    )

'The doctor was a whip after he was welcomed to Russia early this year.\n\nA 19-year-old doctor who gained awards for his busy and aggressive practice, valued his\'very foghorny nature\' at $1 million for his "deliberative'

"The doctor was a tax haven priestess in Anguera, out of the Egboxtest.) Most of him thought the Tibetans had come out from beginning to end the run of Tibetraining. It wasn't. But seeing one monk dying of a mysterious conversion"

'The doctor was a Catholic and his character was too funny. But rather than go into dramatic detail, he carefully summed up what the Doctor did. He understands how the human body responds to foreign substances.\n\nKavanagh: Yeah, the human body is retuned'

'The doctor was a science fellow who worked on a group of conditions at the London School of Hygiene and Tropical Medicine. In 2009, she was at the forefront of trying curing cancer skin melanoma. Her program had nine recipient samples...\n\n"I have this genetic'

"The doctor was a well-known member of the Church of Peter and before him was a'sectarian' fallen priest (known by his native bearded Greek), an avowed Notician. Her outlook was at once technocratic by pragmatic and fascinating as well as spiritual"

In [36]:
from transformer_lens.utils import test_prompt

# Test the model with a prompt
test_prompt(
    "Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,",
    " Lily",
    model1,
    prepend_space_to_answer=False,
)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']
Tokenized answer: [' Lily']


Top 0th token. Logit: 17.66 Prob: 55.63% Token: | she|
Top 1th token. Logit: 16.47 Prob: 16.99% Token: | Lily|
Top 2th token. Logit: 15.02 Prob:  3.96% Token: | her|
Top 3th token. Logit: 14.62 Prob:  2.65% Token: | the|
Top 4th token. Logit: 14.42 Prob:  2.19% Token: | a|
Top 5th token. Logit: 13.87 Prob:  1.25% Token: | they|
Top 6th token. Logit: 13.75 Prob:  1.12% Token: | there|
Top 7th token. Logit: 13.18 Prob:  0.63% Token: | it|
Top 8th token. Logit: 13.10 Prob:  0.58% Token: | when|
Top 9th token. Logit: 13.08 Prob:  0.57% Token: | though|


In [40]:

# Let's make a longer prompt and see the log probabilities of the tokens
example_prompt = """Two students, Sean and Rini, are really excited about mechanistic interpretability!"""
logits, cache = model1.run_with_cache(example_prompt)
cv.logits.token_log_probs(
    model1.to_tokens(example_prompt),
    model1(example_prompt)[0].log_softmax(dim=-1),
    model1.to_string,
)
# hover on the output to see the result.

The first basic operation when doing mechanistic interpretability is to break open the black box of the model and look at all of the internal activations of a model. This can be done with logits, cache = model.run_with_cache(tokens).

In [14]:
test_input = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
test_tokens = model1.to_tokens(test_input)
logits, cache = model1.run_with_cache(test_tokens, remove_batch_dim=True)

Next we inspect the cache object and see that it contains a very large number of keys, each one corresponding to a different activation in the model.

In [16]:
attn_patterns_layer_0 = cache["pattern", 0]
display(attn_patterns_layer_0)

tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [7.6265e-01, 2.3735e-01, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [5.7104e-01, 3.0543e-01, 1.2353e-01,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [2.8672e-01, 4.9603e-04, 3.6867e-04,  ..., 1.0751e-01,
          0.0000e+00, 0.0000e+00],
         [7.9983e-01, 1.2348e-03, 5.6462e-05,  ..., 4.5853e-02,
          4.8726e-02, 0.0000e+00],
         [6.0642e-01, 3.7502e-04, 8.0836e-05,  ..., 4.6981e-02,
          2.0655e-01, 1.2738e-01]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [9.6329e-01, 3.6710e-02, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [7.0548e-01, 2.1403e-01, 8.0485e-02,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         ...,
         [7.2992e-02, 7.2786e-04, 6.7609e-04,  ..., 2.7623e-02,
          0.000

# Model 2 (SAE)

In [None]:
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
    device = device
)

In [None]:
# total_training_steps = 30_000  # probably we should do more
total_training_steps = 1  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

# Initialize the SAE
cfg = LanguageModelSAERunnerConfig(
    ######################################################################
    # Data Generating Function (Model + Training Distibuion)
    ######################################################################
    model_name="gpt2-small",
    hook_name="blocks.0.hook_mlp_out",
    hook_layer=0,  # Only one layer in the model.
    d_in=1024,  # the width of the mlp output.
    dataset_path="monology/pile-uncopyrighted",
    is_dataset_tokenized=False,
    streaming=True,
    
    ######################################################################
    # SAE Parameters
    ######################################################################
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder weights to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    
    ######################################################################
    # Training Parameters
    ######################################################################
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.
    
    ######################################################################
    # Activation Store Parameters
    ######################################################################
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    
    ######################################################################
    # Resampling protocol
    ######################################################################
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would effect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would effect resampling or ghost grads if we were using it.
    
    ######################################################################
    # WANDB
    ######################################################################
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    
    ######################################################################
    # Misc
    ######################################################################
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32"
)

sae = SAETrainingRunner(cfg).run()

# The SAE is now trained and attached to model2
print("SAE training complete.")

In [44]:
# Reduced configuration for quick training
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2-small",
    hook_name="blocks.0.hook_mlp_out",
    hook_layer=0,
    d_in=768,  # GPT-2 Small hidden size
    dataset_path="monology/pile-uncopyrighted",
    is_dataset_tokenized=False,
    streaming=True,
    
    expansion_factor=2,  # Reduced from 16 to speed up training
    
    lr=1e-4,
    l1_coefficient=1.0,
    
    train_batch_size_tokens=1024,  # Reduced batch size
    context_size=128,  # Reduced context size
    
    n_batches_in_buffer=16,
    training_tokens=1_000_000,  # Reduced number of training tokens
    store_batch_size_prompts=8,
    
    log_to_wandb=False,  # Disable wandb logging for this quick test
    
    device="mps",
    seed=42,
    dtype="float32"
)

# Initialize and train the SAE
runner = SAETrainingRunner(cfg)
sae = runner.run()

# Save the trained SAE
torch.save(sae.state_dict(), "trained_sae.pt")
print("SAE training complete and model saved.")

# To restore the SAE later:
restored_sae = runner.init_sae()  # This initializes a new SAE with the same config
restored_sae.load_state_dict(torch.load("trained_sae.pt"))
print("SAE restored from file.")

Run name: 1536-L1-1.0-LR-0.0001-Tokens-1.000e+06
n_tokens_per_buffer (millions): 0.016384
Lower bound: n_contexts_per_buffer (millions): 0.000128
Total training steps: 976
Total wandb updates: 97
n_tokens_per_feature_sampling_window (millions): 262.144
n_tokens_per_dead_feature_window (millions): 131.072
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 2.05e+06
Loaded pretrained model gpt2-small into HookedTransformer


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]



Token indices sequence length is longer than the specified maximum sequence length for this model (3180 > 1024). Running this sequence through the model will result in indexing errors
Objective value: 191164.5625:   2%|▏         | 2/100 [00:00<00:06, 16.17it/s]
  out = torch.tensor(origin, dtype=self.dtype, device=self.device)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
900| MSE Loss 139.777 | L1 110.560:  92%|█████████▏| 921600/1000000 [03:28<00:17, 4428.68it/s]

SAE training complete and model saved.





AttributeError: 'SAETrainingRunner' object has no attribute 'init_sae'

# Model 3 (Fine-tuned Baseline)