# SynthSAEBench: Evaluating SAE Architectures on Synthetic Data

This tutorial walks through using SynthSAEBench to train and evaluate SAE architectures on large-scale synthetic data with known ground-truth features.

SynthSAEBench provides:

- A pretrained synthetic model (**SynthSAEBench-16k**) with 16,384 ground-truth features exhibiting realistic properties: Zipfian firing distributions, hierarchical features, correlated firings, and superposition
- A training runner (**SyntheticSAERunner**) for training SAEs at scale with wandb logging and periodic evaluation
- Ground-truth evaluation metrics (MCC, F1, precision, recall) that measure how well an SAE recovers the true underlying features

Unlike LLM benchmarks where the ground truth is unknown, SynthSAEBench lets you precisely diagnose *why* an SAE architecture succeeds or fails.

For background on the synthetic data primitives (FeatureDictionary, ActivationGenerator, etc.), see the [training on synthetic data tutorial](https://github.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb). This tutorial focuses on the large-scale benchmark workflow.

**NOTE:** This notebook requires a GPU to run in a reasonable time. Training on SynthSAEBench-16k takes ~15-20 minutes per SAE on an H100 but will be extremely slow on CPU.

## Setup

In [None]:
import warnings

import torch

try:
    import google.colab  # type: ignore

    COLAB = True
    %pip install sae-lens
except Exception:
    COLAB = False

device = "cuda"
if not torch.cuda.is_available():
    warnings.warn(
        "CUDA is not available. This notebook requires a GPU to run in a reasonable time. "
        "Training on SynthSAEBench-16k takes ~15-20 minutes on an H100 but will be "
        "extremely slow on CPU.",
        stacklevel=1,
    )
    device = "cpu"

## Loading the SynthSAEBench-16k Model

The SynthSAEBench-16k model is the standard benchmark. It has 16,384 ground-truth features in a 768-dimensional hidden space, with Zipfian firing probabilities, hierarchical structure, low-rank correlations, and superposition.

Let's load it and explore its properties.

In [None]:
from sae_lens.synthetic import SyntheticModel

model = SyntheticModel.from_pretrained(
    "decoderesearch/synth-sae-bench-16k-v1",
    device=device,
)

print(f"Number of ground-truth features: {model.cfg.num_features:,}")
print(f"Hidden dimension: {model.cfg.hidden_dim}")
print(f"Superposition ratio: {model.cfg.num_features / model.cfg.hidden_dim:.1f}x")
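
To make the superposition ratio concrete: 16,384 features cannot all be orthogonal in 768 dimensions, so feature directions must overlap. The sketch below is a dependency-free illustration, not SynthSAEBench's dictionary construction. It assumes purely random unit vectors (the real model also applies an orthogonalization step) and uses small hypothetical sizes to check that the typical overlap between two directions is close to 1/sqrt(d).

```python
import math
import random


def random_unit_vector(dim, rng):
    """Draw a direction uniformly at random on the unit sphere."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def rms_cosine_similarity(num_vectors, dim, seed=0):
    """Root-mean-square cosine similarity over all distinct pairs."""
    rng = random.Random(seed)
    vecs = [random_unit_vector(dim, rng) for _ in range(num_vectors)]
    squared_sum, num_pairs = 0.0, 0
    for i in range(num_vectors):
        for j in range(i + 1, num_vectors):
            cos = sum(a * b for a, b in zip(vecs[i], vecs[j]))
            squared_sum += cos * cos
            num_pairs += 1
    return math.sqrt(squared_sum / num_pairs)


# Toy sizes (hypothetical) so the pure-Python loop stays fast.
toy_dim = 64
rms = rms_cosine_similarity(num_vectors=200, dim=toy_dim)
print(f"Toy RMS overlap: {rms:.3f} vs 1/sqrt(d) = {1 / math.sqrt(toy_dim):.3f}")

# The same back-of-the-envelope estimate at the benchmark's dimensions.
ratio = 16_384 / 768
typical_overlap = 1 / math.sqrt(768)
print(f"Superposition ratio: {ratio:.1f}x, 1/sqrt(768) = {typical_overlap:.3f}")
```

Each pairwise overlap is small, but every active feature interferes with thousands of others, which is part of what makes recovery non-trivial for an SAE.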

### Exploring the model

We can sample activations and inspect their statistics.

In [None]:
# Sample hidden activations and ground-truth feature activations
hidden_acts, feature_acts = model.sample_with_features(batch_size=10_000)

print(f"Hidden activations shape: {hidden_acts.shape}")
print(f"Feature activations shape: {feature_acts.shape}")
print(
    f"Hidden activation L2 norm: {hidden_acts.norm(dim=1).mean():.1f}"
    f" (std: {hidden_acts.norm(dim=1).std():.1f})"
)

# L0: average number of active features per sample
l0 = (feature_acts > 0).float().sum(dim=1).mean()
print(f"Average L0 (active features per sample): {l0:.1f}")

### Firing probability distribution

The model uses Zipfian firing probabilities, where a few features fire frequently and most fire rarely.

In [None]:
import plotly.express as px

# Estimate firing frequencies from samples
firing_freqs = (feature_acts > 0).float().mean(dim=0).cpu()

fig = px.histogram(
    x=firing_freqs.numpy(),
    nbins=50,
    log_y=True,
    title="Feature firing frequency distribution (SynthSAEBench-16k)",
    labels={"x": "Firing frequency", "y": "Feature count"},
)
fig.show()
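
As a rough illustration of what a Zipfian firing distribution looks like, the sketch below assigns each feature a probability that decays as a power of its rank. The parametrization (`max_prob * rank ** -exponent`, floored at `min_prob`) and the helper name are assumptions made for illustration and may differ from how SAELens defines it. The parameter values are borrowed from the small custom model later in this tutorial.

```python
def zipfian_firing_probs(num_features, exponent, max_prob, min_prob):
    """Hypothetical Zipf-style schedule: power-law decay in rank, floored at min_prob."""
    probs = []
    for rank in range(1, num_features + 1):
        p = max_prob * rank ** (-exponent)
        probs.append(max(min_prob, p))
    return probs


probs = zipfian_firing_probs(num_features=1024, exponent=0.5, max_prob=0.4, min_prob=5e-4)

# By linearity of expectation, the expected number of active features per
# sample is the sum of the per-feature (marginal) firing probabilities.
expected_l0 = sum(probs)

print(f"Most frequent feature fires with p = {probs[0]:.3f}")
print(f"Rarest feature fires with p = {probs[-1]:.4f}")
print(f"Expected L0 under this schedule: {expected_l0:.1f}")
```

Under these assumptions a handful of head features dominate the firing counts, which is why the histogram above is plotted with a logarithmic y-axis.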

## Training a BatchTopK SAE

Now let's train a BatchTopK SAE on SynthSAEBench-16k using the `SyntheticSAERunner`. We recommend:

- **Width 4096**: In practice, SAEs are narrower than the true number of features, so we use a width well below the 16,384 ground-truth features
- **200M training samples** with batch size 1024 and LR 3e-4
- **k=25**: A reasonable sparsity target for this model

Training takes about 15-20 minutes on an H100 GPU.
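
For readers unfamiliar with the role of `k`, here is a minimal sketch of the BatchTopK idea as it is generally described: instead of keeping the top k activations per sample, keep the top k × batch_size activations across the whole batch, so that L0 averages k but can vary between samples. This is a pure-Python toy with a hypothetical `batch_topk` helper, not the SAELens implementation.

```python
def batch_topk(acts, k):
    """Zero everything except the k * batch_size largest activations in the batch.

    Assumes activations are non-negative (post-ReLU) and have no exact ties
    at the cutoff; a tie would let extra values through in this toy version.
    """
    batch_size = len(acts)
    n_keep = k * batch_size
    ranked = sorted((v for row in acts for v in row), reverse=True)
    threshold = ranked[n_keep - 1]
    return [[v if v >= threshold else 0.0 for v in row] for row in acts]


acts = [
    [0.9, 0.1, 0.8, 0.7],   # a "busy" sample
    [0.2, 0.3, 0.05, 0.6],  # a "quiet" sample
]
sparse = batch_topk(acts, k=2)
per_sample_l0 = [sum(v > 0 for v in row) for row in sparse]

print(sparse)
print(per_sample_l0, sum(per_sample_l0) / len(per_sample_l0))
```

In this toy batch the busy sample keeps three activations and the quiet one keeps one, while the batch average stays at k = 2.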

In [None]:
from sae_lens.synthetic import SyntheticSAERunner, SyntheticSAERunnerConfig
from sae_lens import BatchTopKTrainingSAEConfig, LoggingConfig

runner_cfg = SyntheticSAERunnerConfig(
    synthetic_model="decoderesearch/synth-sae-bench-16k-v1",
    sae=BatchTopKTrainingSAEConfig(
        d_in=768,
        d_sae=4096,
        k=25,
    ),
    training_samples=200_000_000,
    batch_size=1024,
    lr=3e-4,
    eval_frequency=1000,
    eval_samples=500_000,
    autocast_sae=True,
    autocast_data=True,
    logger=LoggingConfig(log_to_wandb=False),
    device=device,
)

runner = SyntheticSAERunner(runner_cfg)
btk_result = runner.run()
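
It can help to translate the configuration above into step counts. The arithmetic below assumes `eval_frequency` is counted in optimizer steps, which this tutorial does not state explicitly, and uses the 15-20 minute H100 estimate quoted earlier to back out an approximate throughput.

```python
training_samples = 200_000_000
batch_size = 1024
eval_frequency = 1000  # assumption: counted in optimizer steps

total_steps = training_samples // batch_size
num_periodic_evals = total_steps // eval_frequency

print(f"Optimizer steps: {total_steps:,}")
print(f"Periodic evaluations: {num_periodic_evals}")

# Approximate throughput implied by the quoted 15-20 minute training time.
for minutes in (15, 20):
    samples_per_second = training_samples / (minutes * 60)
    print(f"{minutes} min run -> ~{samples_per_second:,.0f} samples/s")
```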

## Evaluating the trained SAE

The runner automatically runs a final evaluation against the ground-truth features. Let's look at the results.

In [None]:
eval_result = btk_result.final_eval
assert eval_result is not None

print("BatchTopK SAE Results:")
print(f"  Explained variance (R²): {eval_result.explained_variance:.4f}")
print(f"  MCC:                     {eval_result.mcc:.4f}")
print(f"  Uniqueness:              {eval_result.uniqueness:.4f}")
print(f"  F1:                      {eval_result.classification.f1_score:.4f}")
print(f"  Precision:               {eval_result.classification.precision:.4f}")
print(f"  Recall:                  {eval_result.classification.recall:.4f}")
print(f"  SAE L0:                  {eval_result.sae_l0:.1f}")
print(f"  True L0:                 {eval_result.true_l0:.1f}")
print(f"  Dead latents:            {eval_result.dead_latents}")
print(f"  Shrinkage:               {eval_result.shrinkage:.4f}")

### Understanding the metrics

- **Explained variance (R²)**: How well the SAE reconstructs inputs. 1.0 = perfect reconstruction.
- **MCC**: Mean Correlation Coefficient — measures alignment between SAE decoder columns and ground-truth feature vectors via optimal bipartite matching. 1.0 = perfect feature recovery.
- **Uniqueness**: Fraction of SAE latents that map to distinct ground-truth features. Low uniqueness means multiple latents represent the same feature.
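- **F1 / Precision / Recall**: Each SAE latent is treated as a binary classifier for its best-matching ground-truth feature. Precision is the fraction of the latent's activations on which the feature is truly active (false positives lower it); recall is the fraction of the feature's activations on which the latent fires (false negatives lower it).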
- **L0**: Average number of active latents per sample. Compare the SAE's L0 to the true L0 (the average number of active ground-truth features per sample).
- **Dead latents**: Latents that never activate — wasted capacity.
- **Shrinkage**: Ratio of output to input norm. Values below 1.0 mean the SAE systematically reduces activation magnitudes.
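
The MCC description above can be made concrete with a toy example. The sketch below brute-forces the optimal one-to-one matching between latent directions and ground-truth feature directions and averages the matched cosine similarities. The use of absolute cosine similarity and the helper names are assumptions for illustration; the benchmark's own implementation may differ in such details.

```python
import itertools
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def mcc(latent_dirs, feature_dirs):
    """Mean matched |cosine| under the best one-to-one latent-to-feature assignment."""
    n_latents = len(latent_dirs)
    sims = [[abs(cosine(l, f)) for f in feature_dirs] for l in latent_dirs]
    best = 0.0
    # Brute force over all injective assignments; only viable for toy sizes.
    for assignment in itertools.permutations(range(len(feature_dirs)), n_latents):
        score = sum(sims[i][j] for i, j in enumerate(assignment)) / n_latents
        best = max(best, score)
    return best


features = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

# Two latents that recover two of the three features exactly (in a different order).
perfect = mcc([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], features)

# One latent is a blend of two features, so its best match is imperfect.
blended = mcc([[0.6, 0.8, 0.0], [0.0, 0.0, 1.0]], features)

print(f"Perfect recovery MCC: {perfect:.2f}")
print(f"Blended latent MCC:   {blended:.2f}")
```

The number of assignments grows factorially, so at benchmark scale the matching has to be solved with a proper assignment algorithm (for example the Hungarian method) rather than enumeration.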

## Comparing SAE Architectures

One of the key uses of SynthSAEBench is comparing different SAE architectures. Let's train a Matryoshka BatchTopK SAE and a Standard L1 SAE for comparison.

### Matryoshka BatchTopK SAE

Matryoshka SAEs use nested reconstruction losses at multiple widths to encourage better latent quality and reduce feature absorption.

In [None]:
from sae_lens import MatryoshkaBatchTopKTrainingSAEConfig

matryoshka_cfg = SyntheticSAERunnerConfig(
    synthetic_model="decoderesearch/synth-sae-bench-16k-v1",
    sae=MatryoshkaBatchTopKTrainingSAEConfig(
        d_in=768,
        d_sae=4096,
        k=25,
        matryoshka_widths=[128, 512, 2048],
    ),
    training_samples=200_000_000,
    batch_size=1024,
    lr=3e-4,
    eval_frequency=1000,
    eval_samples=500_000,
    autocast_sae=True,
    autocast_data=True,
    logger=LoggingConfig(log_to_wandb=False),
    device=device,
)

matryoshka_result = SyntheticSAERunner(matryoshka_cfg).run()
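
To illustrate what "nested reconstruction losses" means, the toy sketch below reconstructs an input from only the first `w` latents for each width `w` and adds up the resulting errors. It assumes the full dictionary width is included as the final width and that the per-width losses are simply summed; both are illustrative assumptions rather than a description of the SAELens implementation.

```python
def reconstruct(latent_acts, decoder, width):
    """Reconstruct the input using only the first `width` latents."""
    d_in = len(decoder[0])
    out = [0.0] * d_in
    for act, row in zip(latent_acts[:width], decoder[:width]):
        for j in range(d_in):
            out[j] += act * row[j]
    return out


def mse(x, x_hat):
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def matryoshka_loss(x, latent_acts, decoder, widths):
    """Sum of reconstruction errors over nested prefixes of the latents."""
    return sum(mse(x, reconstruct(latent_acts, decoder, w)) for w in widths)


# Toy setup: 2-dimensional input, 4 latents, nested widths 1, 2 and 4.
x = [1.0, 2.0]
decoder = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]
latent_acts = [1.0, 1.0, 0.5, 0.0]

per_width = {w: mse(x, reconstruct(latent_acts, decoder, w)) for w in (1, 2, 4)}
loss = matryoshka_loss(x, latent_acts, decoder, widths=(1, 2, 4))

print(per_width)
print(f"Nested loss: {loss}")
```

Because the earliest latents appear in every prefix, they are pushed to reconstruct well on their own, which is the intuition behind the improved latent quality mentioned above.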

### Standard L1 SAE

In [None]:
from sae_lens import StandardTrainingSAEConfig

standard_cfg = SyntheticSAERunnerConfig(
    synthetic_model="decoderesearch/synth-sae-bench-16k-v1",
    sae=StandardTrainingSAEConfig(
        d_in=768,
        d_sae=4096,
        l1_coefficient=2.0,  # should result in L0 around 20-25
        l1_warm_up_steps=10_000,
    ),
    training_samples=200_000_000,
    batch_size=1024,
    lr=3e-4,
    eval_frequency=1000,
    eval_samples=500_000,
    autocast_sae=True,
    autocast_data=True,
    logger=LoggingConfig(log_to_wandb=False),
    device=device,
)

standard_result = SyntheticSAERunner(standard_cfg).run()
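
The two L1 settings above control a sparsity penalty that is phased in over the first training steps. The sketch below shows one plausible reading: a linear ramp of the coefficient from 0 to its final value, applied to a reconstruction-plus-L1 objective. The linear schedule, the unweighted L1 term, and both helper names are assumptions for illustration; SAELens may, for example, weight activations by decoder norms.

```python
def l1_coefficient_at(step, final_coefficient=2.0, warm_up_steps=10_000):
    """Assumed linear warm-up: 0 at step 0, final value from warm_up_steps onward."""
    if warm_up_steps <= 0:
        return final_coefficient
    return final_coefficient * min(1.0, step / warm_up_steps)


def l1_sae_loss(x, x_hat, latent_acts, coefficient):
    """Squared reconstruction error plus an L1 penalty on the latent activations."""
    reconstruction = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    sparsity = sum(abs(a) for a in latent_acts)
    return reconstruction + coefficient * sparsity


for step in (0, 5_000, 10_000, 50_000):
    print(step, l1_coefficient_at(step))

toy_loss = l1_sae_loss(
    x=[1.0, 2.0],
    x_hat=[0.5, 2.0],
    latent_acts=[0.5, 0.0, 0.25],
    coefficient=1.0,
)
print(f"Toy loss: {toy_loss}")
```

Since the penalty grows with activation magnitude, the optimizer can lower the loss by scaling activations down, which is the source of the shrinkage discussed in the metrics section.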

### Comparing results

In [None]:
import pandas as pd

results = {
    "BatchTopK": btk_result.final_eval,
    "Matryoshka BTK": matryoshka_result.final_eval,
    "Standard L1": standard_result.final_eval,
}

rows = []
for name, ev in results.items():
    assert ev is not None
    rows.append(
        {
            "Architecture": name,
            "R²": f"{ev.explained_variance:.4f}",
            "MCC": f"{ev.mcc:.4f}",
            "Uniqueness": f"{ev.uniqueness:.4f}",
            "F1": f"{ev.classification.f1_score:.4f}",
            "Precision": f"{ev.classification.precision:.4f}",
            "Recall": f"{ev.classification.recall:.4f}",
            "SAE L0": f"{ev.sae_l0:.1f}",
            "Shrinkage": f"{ev.shrinkage:.4f}",
            "Dead": ev.dead_latents,
        }
    )

df = pd.DataFrame(rows)
df

Some patterns you may observe (consistent with the SynthSAEBench paper):

- **Matryoshka SAEs** tend to have the best MCC and F1 (latent quality) despite lower explained variance (reconstruction)
- **Standard L1 SAEs** suffer from shrinkage, where the SAE systematically reduces activation magnitudes
- No architecture achieves perfect F1, reproducing the known gap between SAE probing and supervised probing seen in LLM SAE evaluation
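
To see why shrinkage is reported separately from explained variance, consider a hypothetical SAE whose reconstructions point in exactly the right direction but are uniformly 20% too short. The sketch below computes both quantities for that case. The specific formulas (mean of per-sample norm ratios, and one minus residual over total variance) are standard definitions assumed here, not taken from the SAELens source.

```python
import math


def l2_norm(v):
    return math.sqrt(sum(a * a for a in v))


def shrinkage(inputs, outputs):
    """Mean ratio of reconstruction norm to input norm."""
    ratios = [l2_norm(y) / l2_norm(x) for x, y in zip(inputs, outputs)]
    return sum(ratios) / len(ratios)


def explained_variance(inputs, outputs):
    """1 - (residual sum of squares) / (total variance around the mean input)."""
    dim = len(inputs[0])
    mean = [sum(x[j] for x in inputs) / len(inputs) for j in range(dim)]
    residual = sum((a - b) ** 2 for x, y in zip(inputs, outputs) for a, b in zip(x, y))
    total = sum((a - m) ** 2 for x in inputs for a, m in zip(x, mean))
    return 1.0 - residual / total


inputs = [[2.0, 0.0], [0.0, 2.0], [-2.0, 0.0], [0.0, -2.0]]
outputs = [[0.8 * a for a in x] for x in inputs]  # right direction, 20% too short

s = shrinkage(inputs, outputs)
ev = explained_variance(inputs, outputs)
print(f"Shrinkage: {s:.2f}")
print(f"Explained variance: {ev:.2f}")
```

In this toy case a 20% loss of magnitude costs only a few points of explained variance, so a high R² alone does not rule out systematic shrinkage.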

## Standalone evaluation

You can also evaluate any SAE against the synthetic model outside of the training runner using `eval_sae_on_synthetic_data`.

In [None]:
from sae_lens.synthetic import eval_sae_on_synthetic_data

eval_result = eval_sae_on_synthetic_data(
    sae=btk_result.sae,
    feature_dict=model.feature_dict,
    activations_generator=model.activation_generator,
    num_samples=500_000,
    batch_size=1024,
)

print(f"MCC: {eval_result.mcc:.4f}")
print(f"F1:  {eval_result.classification.f1_score:.4f}")
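
The F1 score reported above combines precision and recall. As a minimal sketch of the idea for a single latent and its matched ground-truth feature, the code below treats "activation greater than zero" as the binary prediction and label. The data are made up, and how the benchmark aggregates these scores across latents is not shown here.

```python
def classification_metrics(latent_acts, feature_acts):
    """Precision, recall, and F1 of `latent fires` as a predictor of `feature fires`."""
    tp = fp = fn = 0
    for latent, feature in zip(latent_acts, feature_acts):
        predicted, actual = latent > 0, feature > 0
        if predicted and actual:
            tp += 1
        elif predicted and not actual:
            fp += 1
        elif actual and not predicted:
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Eight hypothetical samples: ground-truth feature vs. matched SAE latent.
feature_acts = [1.0, 0.0, 0.5, 0.0, 2.0, 0.0, 0.0, 1.5]
latent_acts = [0.9, 0.1, 0.0, 0.0, 1.8, 0.0, 0.3, 1.2]

precision, recall, f1 = classification_metrics(latent_acts, feature_acts)
print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}")
```

Here the latent fires twice when the feature is inactive (lowering precision) and misses one true activation (lowering recall).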

## Creating custom benchmark models

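You can create custom synthetic models for ablation studies. For example, to study the effect of superposition, vary the hidden dimension while keeping everything else fixed. The cell below defines a smaller, faster model and trains a BatchTopK SAE directly from its config.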

In [None]:
from sae_lens.synthetic import (
    SyntheticModelConfig,
    ZipfianFiringProbabilityConfig,
    OrthogonalizationConfig,
)

# A smaller, faster model for quick experiments
small_cfg = SyntheticModelConfig(
    num_features=1024,
    hidden_dim=256,
    firing_probability=ZipfianFiringProbabilityConfig(
        exponent=0.5,
        max_prob=0.4,
        min_prob=5e-4,
    ),
    orthogonalization=OrthogonalizationConfig(num_steps=100, lr=3e-4),
    seed=42,
)

# Train directly on a config (creates a temporary model)
small_runner_cfg = SyntheticSAERunnerConfig(
    synthetic_model=small_cfg,
    sae=BatchTopKTrainingSAEConfig(
        d_in=256,
        d_sae=512,
        k=10,
    ),
    training_samples=10_000_000,
    batch_size=1024,
    lr=3e-4,
    logger=LoggingConfig(log_to_wandb=False),
    device=device,
)

small_result = SyntheticSAERunner(small_runner_cfg).run()

assert small_result.final_eval is not None
print(f"MCC: {small_result.final_eval.mcc:.4f}")
print(f"F1:  {small_result.final_eval.classification.f1_score:.4f}")
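
A superposition ablation along these lines could start from a small grid of hidden dimensions. The sketch below only builds the grid as plain dictionaries (a hypothetical convenience, not an SAELens structure) so it runs without a GPU. Each entry would then be expanded into a `SyntheticModelConfig` as in the cell above, with the SAE's `d_in` set to the same `hidden_dim`.

```python
num_features = 1024
hidden_dims = [128, 256, 512, 1024]

# One set of keyword arguments per run; only hidden_dim changes.
sweep = [
    {"num_features": num_features, "hidden_dim": d, "seed": 42}
    for d in hidden_dims
]

ratios = []
for kwargs in sweep:
    ratio = kwargs["num_features"] / kwargs["hidden_dim"]
    ratios.append(ratio)
    print(f"hidden_dim={kwargs['hidden_dim']:>4} -> superposition ratio {ratio:.0f}x")
    # Next step (not run here): pass these values to SyntheticModelConfig together
    # with the firing_probability and orthogonalization settings shown above, and
    # train with an SAE config whose d_in equals kwargs["hidden_dim"].
```

Keeping the seed, feature count, and firing distribution fixed means that differences in MCC or F1 across runs can be attributed to the superposition ratio.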

## Logging to Weights & Biases

To enable wandb logging, pass a `LoggingConfig` with `log_to_wandb=True`. Training loss, evaluation metrics (MCC, F1, etc.), and other diagnostics will be logged automatically.

```python
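from sae_lens import LoggingConfig
from sae_lens.synthetic import SyntheticSAERunnerConfig

runner_cfg = SyntheticSAERunnerConfig(
    ...,  # placeholder for the remaining settings (synthetic_model, sae, etc.)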
    logger=LoggingConfig(
        log_to_wandb=True,
        wandb_project="synth-sae-bench",
        wandb_entity="my-team",  # Optional
        run_name="batchtopk-k25",  # Auto-generated if not set
        wandb_log_frequency=10,
    ),
)
```

## Summary

In this tutorial we covered:

1. **Loading SynthSAEBench-16k** and exploring its properties
2. **Training SAEs** using `SyntheticSAERunner` with BatchTopK, Matryoshka, and Standard L1 architectures
3. **Evaluating SAEs** with ground-truth metrics (MCC, F1, precision, recall, explained variance)
4. **Comparing architectures** to understand their trade-offs
5. **Creating custom models** for ablation studies

### Next steps

- Try other architectures: `JumpReLUTrainingSAEConfig`, `MatchingPursuitTrainingSAEConfig`
- Sweep L0 values to observe the precision-recall trade-off
- Create custom models with different hierarchy depths, correlation strengths, or superposition levels
- See the [synthetic data docs](https://decoderesearch.github.io/SAELens/synthetic_data/) for the full API reference
- See the [SynthSAEBench docs](https://decoderesearch.github.io/SAELens/synth_sae_bench/) for benchmark details and recommended settings
- See the [SynthSAEBench paper](https://arxiv.org/abs/2602.14687) for more details on the synthetic data primitives and the benchmark results