# Training SAEs with SAELens on Gemma 3

### Import libraries and detect hardware

In [11]:
import torch
import os
import gc
from transformers import AutoTokenizer, Gemma3ForCausalLM
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


### Load Gemma model checkpoint from huggingface

In [9]:
ckpt = "google/gemma-3-1b-pt"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = Gemma3ForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

### Train the SAE

In [10]:
total_training_steps = 1000
batch_size = 1024
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5
l1_warm_up_steps = total_training_steps // 20

cfg = LanguageModelSAERunnerConfig(
    # SAE architecture and model
    architecture="topk",  # topk used since matryoshka batch topk converts to jumprelu
    activation_fn="topk",
    activation_fn_kwargs={"k": 40},  # important: this is where 'k' goes
    model_name="google/gemma-3-1b-pt", # lol it wasn't working for so long because it didn't have google in front of it 
    model_class_name="AutoModelForCausalLM",
    hook_name="model.layers.0",
    hook_layer=0,
    d_in=1152,
    d_sae=32768,

    # Dataset
    dataset_path="apollo-research/monology-pile-uncopyrighted-tokenizer-EleutherAI-gpt-neox-20b",
    dataset_trust_remote_code=True,
    streaming=True,
    context_size=2048,
    prepend_bos=True,

    # Training
    lr=3e-4,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="constant",
    lr_warm_up_steps=lr_warm_up_steps,
    lr_decay_steps=lr_decay_steps,
    l1_coefficient=0.001,
    l1_warm_up_steps=l1_warm_up_steps,
    lp_norm=1.0,
    train_batch_size_tokens=batch_size,
    training_tokens=total_training_tokens,
    n_batches_in_buffer=4,
    store_batch_size_prompts=2,

    # Init + Heuristics
    apply_b_dec_to_input=True,
    decoder_orthogonal_init=False,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_sae_decoder=False,
    normalize_activations="none",
    exclude_special_tokens=True,

    # Logging
    log_to_wandb=True,
    wandb_project="sae_gemma3_experiment",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,

    # Misc
    device="cuda",
    act_store_device="with_model",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    verbose=True,

    # Stability
    use_ghost_grads=False,
    feature_sampling_window=1000,
    dead_feature_window=1000,
    dead_feature_threshold=1e-4
)

sparse_autoencoder = SAETrainingRunner(cfg).run()


Downloading readme: 100%|██████████| 303/303 [00:00<00:00, 6.61kB/s]
Objective value: 2408587.0000:   6%|▌         | 6/100 [00:00<00:00, 652.67it/s]
  out = torch.tensor(origin, dtype=self.dtype, device=self.device)
[34m[1mwandb[0m: Currently logged in as: [33mhtsai[0m ([33mhtsai2025[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


1000| auxiliary_reconstruction_loss: 0.00000 | mse_loss: 6400.62744: 100%|██████████| 1024000/1024000 [03:22<00:00, 5046.94it/s]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
details/current_l1_coefficient,▁████████████████████████████████
details/current_learning_rate,███████████████████████████▇▅▄▃▂▁
details/n_training_tokens,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
losses/auxiliary_reconstruction_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,█▆▅▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▆▅▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/explained_variance,▁▃▄▆▆▇▇▇▇▇▇██████████████████████
metrics/explained_variance_legacy,▁▃▅▆▇▇▇▇▇▇▇▇█████████████████████
metrics/explained_variance_legacy_std,▃▆███▇▆▅▄▄▃▃▃▃▂▃▂▂▂▁▂▂▁▂▂▁▁▁▁▁▁▂▁
metrics/l0,▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
details/current_l1_coefficient,0.001
details/current_learning_rate,2e-05
details/n_training_tokens,1013760.0
losses/auxiliary_reconstruction_loss,0.0
losses/mse_loss,6181.89844
losses/overall_loss,6181.89844
metrics/explained_variance,0.9318
metrics/explained_variance_legacy,0.93539
metrics/explained_variance_legacy_std,0.0778
metrics/l0,40.0


### Clean up the GPU memory

In [12]:
del model  # or whatever large object you want to remove
gc.collect()
torch.cuda.empty_cache()