# Chapter 9: Sparse Autoencoders - Hands-On Notebook

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ttsugriy/mechinterp-first-principles/blob/main/notebooks/09-sparse-autoencoders.ipynb)

This notebook accompanies [Chapter 9: Sparse Autoencoders](https://ttsugriy.github.io/mechinterp-first-principles/chapters/09-sparse-autoencoders.html) from *First Principles of Mechanistic Interpretability*.

**What you'll do:**
1. Load a pre-trained SAE for GPT-2 Small
2. Extract features from model activations
3. Explore what features represent
4. Try steering the model with a feature

**Time:** ~30 minutes

**Prerequisites:** Basic Python, having read Chapter 9

## Setup

First, let's install the required libraries. This may take a few minutes.

In [None]:
# Step 1: Install libraries (run this cell, then restart runtime!)
!pip install -q transformer-lens sae-lens einops jaxtyping

print("Installation complete!")
print("Now restart the runtime: Runtime -> Restart runtime")
print("Then skip this cell and run the next one.")

In [None]:
# Step 2: Import libraries (run this after restarting runtime)
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Check for GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

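# For reproducibility
torch.manual_seed(42)

# This notebook only runs inference, so turn off autograd to save memory
torch.set_grad_enabled(False)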

## 1. Load the Model and SAE

We'll use GPT-2 Small and a pre-trained SAE from the SAE Lens library.

In [None]:
# Load GPT-2 Small
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
print(f"Model loaded: {model.cfg.n_layers} layers, {model.cfg.d_model} dimensions")

In [None]:
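# Load a pre-trained SAE for the layer 8 residual stream
# This SAE was trained by Joseph Bloom and is available through SAE Lens
# Note: the return value of from_pretrained has varied between sae-lens releases;
# if the tuple unpacking below fails, the installed version likely returns only the SAE.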
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device=device
)

print(f"SAE loaded!")
print(f"  Input dimension: {sae.cfg.d_in}")
print(f"  Number of features: {sae.cfg.d_sae}")
print(f"  Expansion factor: {sae.cfg.d_sae / sae.cfg.d_in:.1f}x")

## 2. Extract Features from Text

Let's see what features activate for a sample text.

In [None]:
# Sample text
text = "The Golden Gate Bridge is a famous landmark in San Francisco."

# Tokenize
tokens = model.to_tokens(text)
print(f"Tokens: {[model.to_string(t) for t in tokens[0]]}")

In [None]:
# Run the model and get activations at layer 8
_, cache = model.run_with_cache(tokens)
activations = cache["blocks.8.hook_resid_pre"]

print(f"Activation shape: {activations.shape}")
print(f"  (batch, sequence_length, d_model)")

In [None]:
# Pass activations through the SAE to get feature activations
feature_acts = sae.encode(activations)

print(f"Feature activation shape: {feature_acts.shape}")
print(f"  (batch, sequence_length, n_features)")
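
For intuition, the encoder of a standard ReLU sparse autoencoder is an affine map followed by a ReLU. The sketch below reproduces that computation in NumPy on small random weights. The names `W_enc`, `b_enc`, and `b_dec`, the toy sizes, and the decoder-bias subtraction are assumptions about one common architecture; the SAE loaded above may differ in details such as input normalization.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_sae = 4, 16  # toy sizes; the real SAE has d_model inputs and far more features

# Hypothetical toy parameters, randomly initialized for illustration only
W_enc = rng.normal(size=(d_in, d_sae))
b_enc = -1.0 * np.ones(d_sae)  # a negative bias pushes many pre-activations below zero
b_dec = np.zeros(d_in)

def toy_encode(x):
    """ReLU((x - b_dec) @ W_enc + b_enc): one common SAE encoder form."""
    pre_acts = (x - b_dec) @ W_enc + b_enc
    return np.maximum(pre_acts, 0.0)

x = rng.normal(size=(1, 3, d_in))      # (batch, sequence_length, d_in)
toy_feature_acts = toy_encode(x)
print(toy_feature_acts.shape)          # (1, 3, 16)
print((toy_feature_acts > 0).mean())   # fraction of entries that are active
```

The ReLU is what makes the output non-negative and lets many entries be exactly zero, which is why the next cell can count active features with `> 0`.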

In [None]:
# How sparse are the activations?
nonzero_per_position = (feature_acts[0] > 0).sum(dim=-1).float()
total_features = feature_acts.shape[-1]

print(f"\nSparsity analysis:")
print(f"  Total features: {total_features}")
print(f"  Avg active per position: {nonzero_per_position.mean():.1f}")
print(f"  Sparsity: {100 * (1 - nonzero_per_position.mean() / total_features):.2f}%")

## 3. Explore Individual Features

Let's look at which features activate most strongly for each token.

In [None]:
# For each token position, find the top activating features
token_strs = [model.to_string(t) for t in tokens[0]]

print("Top 3 features per token:\n")
for pos, token_str in enumerate(token_strs):
    acts = feature_acts[0, pos]
    top_features = torch.topk(acts, k=3)
    
    feature_info = [f"#{idx.item()}({val:.2f})" 
                    for idx, val in zip(top_features.indices, top_features.values)]
    print(f"{token_str:15} → {', '.join(feature_info)}")

In [None]:
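# Find the single most active feature across all positions
# Position 0 is the BOS token that model.to_tokens prepends; its activations are
# often atypically large, so we exclude it from the search
acts_no_bos = feature_acts[0, 1:]
max_act = acts_no_bos.max()
max_pos = (acts_no_bos == max_act).nonzero()[0]
max_feature_idx = max_pos[1].item()
max_token_pos = max_pos[0].item() + 1  # shift back to index the full token list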

print(f"Most active feature: #{max_feature_idx}")
print(f"  Activation: {max_act:.2f}")
print(f"  Token: '{token_strs[max_token_pos]}' (position {max_token_pos})")

## 4. Investigate a Specific Feature

Let's test what a specific feature responds to by measuring its activation on several related and unrelated texts.

In [None]:
def get_feature_activation(text, feature_idx):
    """Get the max activation of a feature over the tokens of a given text."""
    tokens = model.to_tokens(text)
    _, cache = model.run_with_cache(tokens)
    activations = cache["blocks.8.hook_resid_pre"]
    feature_acts = sae.encode(activations)
    # Skip position 0: it is the BOS token that to_tokens prepends, not part of the text
    return feature_acts[0, 1:, feature_idx].max().item()

# Test the most active feature on related and unrelated texts
feature_to_test = max_feature_idx

test_texts = [
    "The Golden Gate Bridge is beautiful.",
    "San Francisco is a city in California.",
    "The bridge spans the bay.",
    "I love eating pizza.",
    "Python is a programming language.",
    "The weather is nice today.",
]

print(f"Testing feature #{feature_to_test}:\n")
for text in test_texts:
    act = get_feature_activation(text, feature_to_test)
    bar = "█" * int(act * 5)
    print(f"{act:5.2f} {bar:20} {text[:50]}")
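
The helper above reports only the maximum activation per text. When interpreting a feature it also helps to know which token produced that maximum. The sketch below shows one way to rank tokens for a single feature; `top_tokens_for_feature` is a hypothetical helper and the activation values are made up so the example runs without the model.

```python
import numpy as np

def top_tokens_for_feature(acts, token_strs, feature_idx, k=3):
    """Return the k (token, activation) pairs where one feature fires most strongly.

    acts: array of shape (sequence_length, n_features)
    token_strs: list of sequence_length token strings
    """
    column = acts[:, feature_idx]
    order = np.argsort(-column)[:k]  # positions sorted by descending activation
    return [(token_strs[i], float(column[i])) for i in order]

# Toy activations for 4 tokens and 3 features (made-up numbers)
toy_tokens = ["The", " Golden", " Gate", " Bridge"]
toy_acts = np.array([
    [0.0, 0.1, 0.0],
    [0.0, 2.5, 0.0],
    [0.3, 4.0, 0.0],
    [0.0, 1.2, 0.7],
])

print(top_tokens_for_feature(toy_acts, toy_tokens, feature_idx=1, k=2))
# [(' Gate', 4.0), (' Golden', 2.5)]
```

With the real model, `acts` could be `feature_acts[0]` moved to the CPU and converted to NumPy, and `token_strs` the token list built in Section 3.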

## 5. Reconstruction Quality

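An SAE is trained to balance two competing goals: reconstructing the activations accurately and keeping only a few features active. Let's measure how well this one reconstructs.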

In [None]:
# Get original activations
text = "The capital of France is Paris."
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)
original_acts = cache["blocks.8.hook_resid_pre"]

# Encode then decode through SAE
feature_acts = sae.encode(original_acts)
reconstructed_acts = sae.decode(feature_acts)

# Calculate reconstruction error
mse = ((original_acts - reconstructed_acts) ** 2).mean().item()
original_variance = original_acts.var().item()
explained_variance = 1 - mse / original_variance

print(f"Reconstruction quality:")
print(f"  MSE: {mse:.4f}")
print(f"  Explained variance: {explained_variance:.1%}")
print(f"  (Higher is better, typically 85-95% for good SAEs)")
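
The cell above takes the variance over every element of the activation tensor at once. A common alternative subtracts each dimension's mean first, so that a large constant offset in some dimension does not inflate the score. The sketch below shows that variant on synthetic data; `fraction_variance_explained` is a hypothetical helper name, not part of SAE Lens.

```python
import numpy as np

def fraction_variance_explained(original, reconstructed):
    """1 - (sum of squared errors) / (sum of squared deviations from the per-dimension mean)."""
    flat_orig = original.reshape(-1, original.shape[-1])
    flat_recon = reconstructed.reshape(-1, reconstructed.shape[-1])
    residual = ((flat_orig - flat_recon) ** 2).sum()
    total = ((flat_orig - flat_orig.mean(axis=0)) ** 2).sum()
    return 1.0 - residual / total

rng = np.random.default_rng(0)
original = rng.normal(size=(1, 8, 4))                       # (batch, seq, d_in), synthetic
noisy = original + 0.1 * rng.normal(size=original.shape)    # stand-in for an imperfect reconstruction

print(fraction_variance_explained(original, original))  # 1.0 for a perfect reconstruction
print(fraction_variance_explained(original, noisy))     # slightly below 1.0
```

Under this definition, a reconstruction that always outputs the per-dimension mean scores 0, which gives the number a natural baseline.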

## 6. Steering with Features (Advanced)

Anthropic's "Golden Gate Claude" demonstration showed that amplifying a single SAE feature can steer a model's behavior. Let's try a much simpler version on GPT-2 Small.

In [None]:
def generate_with_steering(prompt, feature_idx, steering_strength=0.0, max_tokens=20):
    """Generate text with optional feature steering."""
    
    # Get the feature direction
    feature_direction = sae.W_dec[feature_idx]  # Shape: (d_model,)
    
    def steering_hook(activation, hook):
        # Add the feature direction scaled by steering strength
        activation[:, :, :] += steering_strength * feature_direction
        return activation
    
    # Tokenize up front so generate() returns a token tensor rather than a string
    prompt_tokens = model.to_tokens(prompt)
    
    # With strength 0 no hook is registered, so generation is unmodified
    hooks = []
    if steering_strength != 0:
        hooks.append(("blocks.8.hook_resid_pre", steering_hook))
    
    with model.hooks(fwd_hooks=hooks):
        output = model.generate(
            prompt_tokens,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True
        )
    
    # Drop the prepended BOS token before decoding
    return model.to_string(output[0, 1:])

In [None]:
# Compare unsteered and steered generations for the same prompt
# We reuse the most active feature found in Section 3 as the steering direction

prompt = "I went to the store to buy"

print("Generation WITHOUT steering:")
for i in range(3):
    result = generate_with_steering(prompt, feature_idx=0, steering_strength=0.0)
    print(f"  {result}")

print(f"\nGeneration WITH steering (feature #{max_feature_idx}, strength=5.0):")
for i in range(3):
    result = generate_with_steering(prompt, feature_idx=max_feature_idx, steering_strength=5.0)
    print(f"  {result}")
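
To see what the steering hook does geometrically: adding `steering_strength * feature_direction` to a residual vector shifts its dot product with that direction by `steering_strength * ||feature_direction||^2` and leaves components orthogonal to the direction unchanged. The sketch below checks this on hand-picked toy vectors, which are stand-ins and not real SAE decoder weights.

```python
import numpy as np

direction = np.array([3.0, 4.0, 0.0])    # toy stand-in for one decoder row (norm 5)
residual = np.array([[1.0, 0.0, 2.0],    # two toy token positions
                     [0.0, 1.0, -1.0]])
strength = 2.0

steered = residual + strength * direction  # broadcasts over positions, like the hook

before = residual @ direction
after = steered @ direction
print(after - before)                      # [50. 50.] = strength * ||direction||^2
print(steered[:, 2] - residual[:, 2])      # [0. 0.]: the orthogonal axis is unchanged
```

Because the shift scales with the squared norm of the direction, the same `steering_strength` can have very different effects for decoder rows of different length, which is worth keeping in mind when comparing features.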

## 7. Explore on Neuronpedia

For any feature you find interesting, you can look it up on Neuronpedia to see:
- Max activating examples
- Feature descriptions
- Related features

In [None]:
def neuronpedia_url(feature_idx, layer=8):
    """Generate a Neuronpedia URL for a feature."""
    return f"https://www.neuronpedia.org/gpt2-small/{layer}-res-jb/{feature_idx}"

print(f"Explore feature #{max_feature_idx} on Neuronpedia:")
print(neuronpedia_url(max_feature_idx))

## Exercises

Try these to deepen your understanding:

### Exercise 1: Find a "code" feature
Run the SAE on programming-related text and find a feature that activates strongly for code concepts.

### Exercise 2: Compare layers
Load SAEs for different layers (e.g., layer 0 vs layer 10). How do the features differ?

### Exercise 3: Measure sparsity vs reconstruction trade-off
Try different texts and plot the relationship between number of active features and reconstruction quality.

In [None]:
# Exercise 1: Your code here
code_text = "def hello_world():\n    print('Hello, world!')"

# Find top features for this text
# ...

## Summary

In this notebook, you:
1. Loaded a pre-trained SAE for GPT-2 Small
2. Extracted sparse feature activations from text
3. Explored what individual features respond to
4. Measured reconstruction quality
5. Experimented with feature steering

**Next steps:**
- Explore more features on [Neuronpedia](https://www.neuronpedia.org/gpt2-small)
- Try training your own SAE with [SAE Lens](https://github.com/jbloomAus/SAELens)
- Read [Chapter 10: Attribution](https://ttsugriy.github.io/mechinterp-first-principles/chapters/10-attribution.html) to learn how to trace feature contributions to outputs