# 03 – Train a Sparse Autoencoder

**Purpose:** Feature discovery. This is where SAEs stop being abstract.

We train a sparse autoencoder on the cached MLP activations to learn interpretable features.

**Sections:**
1. Load cached activations
2. Initialize SAE
3. Train for a short run
4. Plot: reconstruction loss vs step, sparsity vs step
5. Save model

---

## What is an SAE?

A **Sparse Autoencoder** learns to:
- **Encode**: Map a 768-dim activation → 4096-dim sparse representation
- **Decode**: Reconstruct the original activation from the sparse code

The L1 penalty on the encoded activations encourages sparsity: for any given input, most features should be zero or very close to it.

**Why 4096 features for 768 dims?**  
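An overcomplete dictionary (more features than input dimensions) gives the SAE room to assign separate directions to concepts that the 768-dim activation basis mixes together, so individual features can be finer-grained and easier to interpret.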
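
The objective optimized in section 3 can be written out explicitly. Here $x$ is an input activation, $a = \mathrm{ReLU}(W_{enc}\,x + b_{enc})$ is the sparse code, $\hat{x} = W_{dec}\,a$ is the reconstruction, and $\lambda$ is `L1_COEF`. Both terms are means over the batch and over their respective dimensions, matching the `torch.mean` calls in the training loop:

$$
\mathcal{L} \;=\; \underbrace{\operatorname{mean}\big((\hat{x} - x)^2\big)}_{\text{reconstruction}} \;+\; \lambda \, \underbrace{\operatorname{mean}\big(|a|\big)}_{\text{sparsity}}
$$

Because the sparsity term is averaged over all 4096 features rather than summed, the same $\lambda$ applies a penalty 4096 times weaker than it would on a summed per-example L1 norm. Keep this in mind when comparing `L1_COEF` against values quoted for other SAE setups.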

---

## 1. Load cached activations

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
import matplotlib.pyplot as plt

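# Run from the repository root so the relative paths below resolve.
# This path is machine-specific; change it to your own checkout.
os.chdir('/Users/poonam/projects/mechinterp-from-scratch')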

# Config
CACHE_PATH = "artifacts/cache/gpt2_l6_mlpout_fp16.mmap"
D_IN = 768
D_SAE = 4096
LR = 1e-3
L1_COEF = 1e-3

# Load memmap
data = np.memmap(CACHE_PATH, dtype=np.float16, mode="r")
n_tokens = data.shape[0] // D_IN
data = data.reshape(n_tokens, D_IN)

print(f"Loaded {n_tokens:,} activation vectors")
print(f"Shape: {data.shape}")
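
The loading step above assumes the cache file is a flat run of float16 values whose count is an exact multiple of `D_IN`. A minimal sketch of a stricter loader follows; `load_activations` is a hypothetical helper (not part of this repo), and the demo writes a small synthetic file rather than touching the real cache:

```python
import os
import tempfile

import numpy as np


def load_activations(path: str, d_in: int) -> np.memmap:
    """Open a flat float16 cache as a read-only (n_tokens, d_in) memmap.

    Hypothetical helper for illustration; raises if the file size does not
    correspond to a whole number of d_in-dimensional float16 rows.
    """
    itemsize = np.dtype(np.float16).itemsize
    n_values, remainder = divmod(os.path.getsize(path), itemsize)
    if remainder or n_values % d_in:
        raise ValueError(
            f"{path}: size is not a whole number of {d_in}-dim float16 rows"
        )
    return np.memmap(path, dtype=np.float16, mode="r",
                     shape=(n_values // d_in, d_in))


# Demo on a small synthetic cache (a stand-in for the real file).
demo_d_in = 8
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo_fp16.mmap")
    rng = np.random.default_rng(0)
    rng.standard_normal((100, demo_d_in)).astype(np.float16).tofile(demo_path)

    acts = load_activations(demo_path, demo_d_in)
    demo_shape = acts.shape

    # Spot-check a random sample of rows for NaN/inf before training on them.
    rows = rng.integers(0, acts.shape[0], size=32)
    sample = np.asarray(acts[rows], dtype=np.float32)
    all_finite = bool(np.isfinite(sample).all())

    del acts  # release the memmap before the temp directory is removed

print(demo_shape, all_finite)
```

Passing `shape=` to `np.memmap` avoids the separate reshape, and the finite-value spot check is cheap insurance against a corrupted or truncated cache.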

---

## 2. Initialize SAE

In [None]:
class SAE(nn.Module):
    """Simple Sparse Autoencoder.
    
    Architecture:
        encode: x → z = W_enc @ x + b_enc
        activate: a = ReLU(z)  (sparse activations)
        decode: x_hat = W_dec @ a
    """
    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.enc = nn.Linear(d_in, d_sae, bias=True)
        self.dec = nn.Linear(d_sae, d_in, bias=False)
        
    def forward(self, x):
        z = self.enc(x)
        a = torch.relu(z)  # sparse activations
        x_hat = self.dec(a)
        return x_hat, a

device = torch.device("cpu")  # CPU is fine for small SAE training
sae = SAE(D_IN, D_SAE).to(device)
optimizer = AdamW(sae.parameters(), lr=LR)

print(f"SAE: {D_IN} → {D_SAE} → {D_IN}")
print(f"Parameters: {sum(p.numel() for p in sae.parameters()):,}")
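
As a sanity check on the printed parameter count, the total can be worked out by hand from the layer shapes: an encoder weight and bias, plus a decoder weight with no bias. A small sketch (the function name `sae_param_count` is made up for this example):

```python
def sae_param_count(d_in: int, d_sae: int) -> int:
    """Count parameters for the SAE layout used in this notebook."""
    enc_weight = d_in * d_sae   # nn.Linear(d_in, d_sae) weight
    enc_bias = d_sae            # encoder bias
    dec_weight = d_sae * d_in   # nn.Linear(d_sae, d_in, bias=False) weight
    return enc_weight + enc_bias + dec_weight


expected = sae_param_count(768, 4096)
print(f"{expected:,}")  # 6,295,552
```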

---

## 3. Train

In [None]:
def sample_batch(batch_size: int) -> torch.Tensor:
    """Sample random batch of activations."""
    idx = np.random.randint(0, n_tokens, size=(batch_size,))
    x = torch.from_numpy(data[idx].astype(np.float32))
    return x.to(device)

In [None]:
# Training config
N_STEPS = 2000
BATCH_SIZE = 2048
LOG_EVERY = 100

# Tracking
history = {
    "step": [],
    "recon_loss": [],
    "sparsity": [],
    "total_loss": []
}

print(f"Training for {N_STEPS} steps, batch size {BATCH_SIZE}")
print(f"L1 coefficient: {L1_COEF}")
print("-" * 60)

for step in range(N_STEPS):
    x = sample_batch(BATCH_SIZE)
    x_hat, a = sae(x)
    
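    # Losses
    recon_loss = torch.mean((x_hat - x) ** 2)  # MSE over batch and input dims
    # L1 proxy: mean |a| over batch and features (per-example L1 norm / D_SAE)
    sparsity = torch.mean(torch.abs(a))
    total_loss = recon_loss + L1_COEF * sparsity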
    
    # Backward
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
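    # Log every LOG_EVERY steps, plus the last step so the saved "final" metrics are truly final
    if step % LOG_EVERY == 0 or step == N_STEPS - 1: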
        history["step"].append(step)
        history["recon_loss"].append(recon_loss.item())
        history["sparsity"].append(sparsity.item())
        history["total_loss"].append(total_loss.item())
        
        print(f"step {step:4d} | recon: {recon_loss.item():.4f} | sparsity: {sparsity.item():.4f} | total: {total_loss.item():.4f}")

print("-" * 60)
print("Training complete!")

---

## 4. Plot training curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Reconstruction loss
axes[0].plot(history["step"], history["recon_loss"], 'b-', linewidth=2)
axes[0].set_xlabel("Step")
axes[0].set_ylabel("Reconstruction Loss (MSE)")
axes[0].set_title("Reconstruction Loss vs Step")
axes[0].grid(True, alpha=0.3)

# Sparsity
axes[1].plot(history["step"], history["sparsity"], 'r-', linewidth=2)
axes[1].set_xlabel("Step")
axes[1].set_ylabel("Mean |activation|")
axes[1].set_title("Sparsity Proxy vs Step")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Check sparsity more directly: what fraction of activations are ~zero?
x_test = sample_batch(1000)
with torch.no_grad():  # evaluation only, no gradients needed
    _, a_test = sae(x_test)

# Count near-zero activations (threshold 0.01); ReLU outputs are non-negative
near_zero = (a_test.abs() < 0.01).float().mean().item()
avg_active = (a_test >= 0.01).float().sum(dim=1).mean().item()

print(f"Fraction of activations near zero: {near_zero:.1%}")
print(f"Average active features per input: {avg_active:.1f} / {D_SAE}")
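
The check above averages over all features at once. An optional complementary view is the per-feature firing rate, which shows whether some features never activate on the sampled inputs. The sketch below is an illustration only: `feature_firing_stats` is a hypothetical helper, and the small array is synthetic data, not output from the trained SAE.

```python
import numpy as np


def feature_firing_stats(acts: np.ndarray, threshold: float = 0.01):
    """Return per-feature firing rates and the indices of features that never fire.

    acts: (n_inputs, n_features) array of SAE activations.
    """
    active = acts > threshold                # boolean mask, same shape as acts
    firing_rate = active.mean(axis=0)        # fraction of inputs where each feature fires
    dead = np.flatnonzero(firing_rate == 0)  # features that never exceed the threshold
    return firing_rate, dead


# Synthetic activations: 4 inputs, 4 features.
demo = np.array([
    [0.0, 1.2, 0.0, 0.005],
    [0.3, 0.9, 0.0, 0.0],
    [0.0, 2.0, 0.0, 0.0],
    [0.0, 0.4, 0.0, 0.002],
], dtype=np.float32)

rates, dead = feature_firing_stats(demo)
print(rates.tolist())  # [0.25, 1.0, 0.0, 0.0]
print(dead.tolist())   # [2, 3]
```

To apply this to the notebook's own activations, something like `feature_firing_stats(a_test.detach().cpu().numpy())` should work. With only 1000 sampled inputs, a feature that appears dead may simply be rare.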

---

## 5. Save model

In [None]:
SAVE_PATH = "artifacts/sae/sae.pt"
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)

torch.save({
    "state_dict": sae.state_dict(),
    "d_in": D_IN,
    "d_sae": D_SAE,
    "l1_coef": L1_COEF,
    "n_steps": N_STEPS,
    "final_recon_loss": history["recon_loss"][-1],
    "final_sparsity": history["sparsity"][-1],
}, SAVE_PATH)

print(f"Saved SAE to {SAVE_PATH}")
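
Because the checkpoint stores `d_in` and `d_sae` alongside the weights, a later notebook can rebuild the model without hard-coding its shape. A minimal round-trip sketch follows, assuming the same `SAE` class as in section 2; `load_sae` is a hypothetical helper, and the demo uses a tiny model and a temporary file rather than the real checkpoint:

```python
import os
import tempfile

import torch
import torch.nn as nn


class SAE(nn.Module):
    """Same architecture as the SAE defined in section 2."""

    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.enc = nn.Linear(d_in, d_sae, bias=True)
        self.dec = nn.Linear(d_sae, d_in, bias=False)

    def forward(self, x):
        a = torch.relu(self.enc(x))
        return self.dec(a), a


def load_sae(path: str) -> SAE:
    """Hypothetical helper: rebuild an SAE from a checkpoint dict like the one saved above."""
    ckpt = torch.load(path, map_location="cpu")
    model = SAE(ckpt["d_in"], ckpt["d_sae"])
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    return model


# Round-trip demo with a tiny model (sizes chosen only for the example).
torch.manual_seed(0)
original = SAE(8, 32)
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "sae.pt")
    torch.save(
        {"state_dict": original.state_dict(), "d_in": 8, "d_sae": 32},
        demo_path,
    )
    restored = load_sae(demo_path)

x = torch.randn(4, 8)
with torch.no_grad():
    same_output = torch.allclose(original(x)[0], restored(x)[0])
print(same_output)
```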

---

## Note: Not all features will be interpretable

After training, you might expect every one of the 4096 features to correspond to a clean, interpretable concept. **This is not the case.**

**Why some features remain uninterpretable:**

1. **Polysemanticity residue**: Some features may still capture multiple unrelated concepts.

2. **Noise features**: Some features may primarily capture noise or high-frequency patterns without semantic meaning.

3. **Computational features**: The model may use some directions for intermediate computations that don't map to human concepts.

4. **Insufficient training**: This notebook trains for only 2,000 steps. With more data and longer training, more features may become interpretable.

5. **Wrong granularity**: Some concepts may require different scales (finer or coarser) than our SAE provides.

**This is expected and okay.** The goal is to find *some* interpretable features, not to make all features interpretable. Even a small set of interpretable features can provide valuable insights into model behavior.

**Next:** Use notebook 04 to browse features and look for interpretable ones.