# Training SAEs with SAELens on Gemma 3

### Import libraries and detect hardware

In [1]:
import torch
import os
import gc
from transformers import AutoTokenizer, Gemma3ForCausalLM
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


### Load Gemma model checkpoint from huggingface

In [2]:
ckpt = "google/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = Gemma3ForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
from transformers import AutoModelForCausalLM

ckpt = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(ckpt)

for name, module in model.named_modules():
    print(name)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.12it/s]


vision_tower
vision_tower.vision_model
vision_tower.vision_model.embeddings
vision_tower.vision_model.embeddings.patch_embedding
vision_tower.vision_model.embeddings.position_embedding
vision_tower.vision_model.encoder
vision_tower.vision_model.encoder.layers
vision_tower.vision_model.encoder.layers.0
vision_tower.vision_model.encoder.layers.0.layer_norm1
vision_tower.vision_model.encoder.layers.0.self_attn
vision_tower.vision_model.encoder.layers.0.self_attn.k_proj
vision_tower.vision_model.encoder.layers.0.self_attn.v_proj
vision_tower.vision_model.encoder.layers.0.self_attn.q_proj
vision_tower.vision_model.encoder.layers.0.self_attn.out_proj
vision_tower.vision_model.encoder.layers.0.layer_norm2
vision_tower.vision_model.encoder.layers.0.mlp
vision_tower.vision_model.encoder.layers.0.mlp.activation_fn
vision_tower.vision_model.encoder.layers.0.mlp.fc1
vision_tower.vision_model.encoder.layers.0.mlp.fc2
vision_tower.vision_model.encoder.layers.1
vision_tower.vision_model.encoder.laye




### Train the SAE

In [None]:
total_training_steps = 100_000
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5
l1_warm_up_steps = total_training_steps // 20

cfg = LanguageModelSAERunnerConfig(
    # SAE architecture and model
    architecture="topk",  # topk used since matryoshka batch topk converts to jumprelu
    activation_fn="topk",
    activation_fn_kwargs={"k": 40},  # important: this is where 'k' goes
    model_name="google/gemma-3-1b-pt", # lol it wasn't working for so long because it didn't have google in front of it 
    model_class_name="AutoModelForCausalLM",
    hook_name="model.layers.0",
    hook_layer=0,
    d_in=1152,
    d_sae=32768,

    # Dataset
    dataset_path="apollo-research/monology-pile-uncopyrighted-tokenizer-EleutherAI-gpt-neox-20b",
    dataset_trust_remote_code=True,
    streaming=True,
    context_size=2048,
    prepend_bos=True,

    # Training
    lr=3e-4,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="constant",
    lr_warm_up_steps=lr_warm_up_steps,
    lr_decay_steps=lr_decay_steps,
    l1_coefficient=0.001,
    l1_warm_up_steps=l1_warm_up_steps,
    lp_norm=1.0,
    train_batch_size_tokens=batch_size,
    training_tokens=total_training_tokens,
    n_batches_in_buffer=4,
    store_batch_size_prompts=2,

    # Init + Heuristics
    apply_b_dec_to_input=True,
    decoder_orthogonal_init=False,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_sae_decoder=False,
    normalize_activations="none",
    exclude_special_tokens=True,

    # Logging
    log_to_wandb=True,
    wandb_project="sae_gemma3_experiment",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,

    # Misc
    device="cuda",
    act_store_device="with_model",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    verbose=True,

    # Stability
    use_ghost_grads=False,
    feature_sampling_window=1000,
    dead_feature_window=1000,
    dead_feature_threshold=1e-4
)

sparse_autoencoder = SAETrainingRunner(cfg).run()


Fetching 4 files: 100%|██████████| 4/4 [00:19<00:00,  4.88s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.70it/s]
Refilling buffer:   0%|          | 0/2 [00:00<?, ?it/s]

Layerwise activations cache size 1
Layerwise activations cache:


                                                       

model.layers.0: tensor([[[ 0.0188,  0.0005, -0.0202,  ..., -0.0012, -0.0192,  0.0191],
         [ 0.0258,  0.0164, -0.0135,  ..., -0.0204, -0.0486, -0.0008],
         [ 0.0117,  0.0189, -0.0265,  ..., -0.0045, -0.0053,  0.0067],
         ...,
         [ 0.0049,  0.0018, -0.0070,  ..., -0.0192, -0.0066,  0.0044],
         [-0.0144, -0.0222, -0.0076,  ..., -0.0401, -0.0093, -0.0116],
         [-0.0086, -0.0153, -0.0056,  ..., -0.0029,  0.0086,  0.0098]],

        [[-0.0018,  0.0136,  0.0078,  ..., -0.0133, -0.0328,  0.0077],
         [ 0.0058,  0.0288,  0.0064,  ..., -0.0064, -0.0071,  0.0117],
         [ 0.0225,  0.0172,  0.0118,  ...,  0.0279, -0.0157, -0.0129],
         ...,
         [-0.0065,  0.0066,  0.0042,  ..., -0.0024, -0.0104,  0.0143],
         [ 0.0024, -0.0042,  0.0105,  ..., -0.0060,  0.0004,  0.0030],
         [ 0.0028,  0.0106,  0.0083,  ...,  0.0153, -0.0151,  0.0037]]],
       device='cuda:0')




RuntimeError: The expanded size of the tensor (1152) must match the existing size (4096) at non-singleton dimension 2.  Target sizes: [2, 2048, 1152].  Tensor sizes: [2, 2048, 4096]

### [Deletes the model] Clean up the GPU memory

In [11]:
# del sparse_autoencoder  # Uncomment to delete the trained SAE
gc.collect() # Restart the kernel if there is still leftover memory
torch.cuda.empty_cache()