# SAE Training

This notebook is used to train a Sparse Autoencoder (SAE) on the OpenR1-Math-220k dataset.

In [1]:
from transformer_lens import HookedTransformer
import torch
import circuitsvis as cv
import einops
from IPython.display import Image, display
import numpy as np
from pprint import pprint
import pandas as pd
from tqdm import tqdm
import wandb
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens import SAETrainingRunner, SAE, HookedSAETransformer, TrainingSAE, ActivationsStore
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from datasets import load_dataset, Dataset, Features, Value
from utils import *
import json
import os
import plotly.express as px

  from .autonotebook import tqdm as notebook_tqdm
2025-03-01 05:43:47.839368: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740807827.857863  382456 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740807827.863417  382456 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [3]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7611ed0d60b0>

In [4]:
model = HookedSAETransformer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
model = model.to(device)
model.cfg.n_ctx = 2048

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.




Loaded pretrained model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B into HookedTransformer
Moving model to device:  cuda


In [None]:
create_balanced_backtracking_dataset(
    json_file_paths=["math_cot_results_t=0.6_mnt=1500_tp=0.92.json", "math_cot_results_t=0.7_mnt=1800_tp=0.92.json"],
    output_path="backtracking_dataset_n=1000.json",
    n=1000,
    seed=42
)

In [None]:
output_path = "openr1_math_backtracking_dataset.json"
stats = create_openr1_backtracking_dataset(output_path, backtracking_phrases)
print(stats)

I uploaded this dataset to `uzaymacar/openr1_math_backtracking_dataset` in HuggingFace.

In [5]:
# Load the dataset from the JSON file
path = "openr1_math_backtracking_dataset_hf"

if os.path.exists(path):  
    dataset = Dataset.load_from_disk(path)
else:
    with open("openr1_math_backtracking_dataset.json", "r") as f: dataset_json = json.load(f)
    # Filter to keep only the three specified fields
    filtered_dataset = []
    for item in dataset_json:
        filtered_item = {'uuid': item['uuid'], 'has_backtracking': item['has_backtracking'], 'text': item['text']}
        filtered_dataset.append(filtered_item)

    dataset = Dataset.from_list(filtered_dataset)
    dataset.save_to_disk(path)

print(f"Got dataset with {len(dataset)} examples")

Got dataset with 193767 examples


In [13]:
with open("backtracking_dataset_n=1000.json", "r") as f: examples = json.load(f)
examples = [x for x in examples if x["has_backtracking"] and x["is_correct"] and "i think i made a mistake" in x["generated_cot"].lower()]
examples.sort(key=lambda x: len(x["generated_cot"]))
example = examples[0]["generated_cot"]

In [6]:
# Configure SAE training
# Adjust these parameters based on your computational resources
total_training_steps = 10000 # Reduce if needed for faster training
batch_size = 2048  # Adjust based on your GPU memory
total_training_tokens = total_training_steps * batch_size

# Learning rate schedule parameters
lr_warm_up_steps = total_training_steps // 20  # 5% of training
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 10  # 10% of training

# Create the SAE configuration
cfg = LanguageModelSAERunnerConfig(
    # Model configuration
    model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    hook_name="blocks.21.hook_mlp_out", 
    hook_layer=21,  
    d_in=1536, 
    
    # Dataset configuration
    dataset_path="open-r1/OpenR1-Math-220k",  # We will override this in the training runner
    streaming=True,  # Stream the dataset to save memory
    
    # SAE Parameters
    mse_loss_normalization=None,
    expansion_factor=16,
    b_dec_init_method="zeros",
    apply_b_dec_to_input=False,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    
    # Training Parameters
    lr=3e-5, # NEW
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="cosineannealingwarmrestarts", # NEW, before it was "constant"
    lr_warm_up_steps=lr_warm_up_steps,
    lr_decay_steps=lr_decay_steps,
    l1_coefficient=4.0,  # Controls sparsity
    l1_warm_up_steps=l1_warm_up_steps,
    lp_norm=1.0,
    train_batch_size_tokens=batch_size,
    context_size=512, 
    activation_fn='relu',
    prepend_bos=False,
    
    # Activation Store Parameters
    n_batches_in_buffer=32,
    training_tokens=total_training_tokens,
    store_batch_size_prompts=8,
    
    # Resampling protocol
    use_ghost_grads=False,
    feature_sampling_window=1000,
    dead_feature_window=1000,
    dead_feature_threshold=1e-4,
    
    # WANDB configuration
    log_to_wandb=True,
    wandb_project="deepseek_sae",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    
    # Misc
    device=device,
    seed=42,
    n_checkpoints=2,  # Save 2 checkpoints during training
    checkpoint_path="deepseek_sae_checkpoints",
    dtype="float32",  # Use bfloat16 if you have memory issues: "bfloat16"
)

In [None]:
# Train the SAE
sae = SAETrainingRunner(cfg, override_dataset=dataset).run()

In [7]:
# s66sea6t/final_40960000 => layer 7 mlp
# 09h5vx03/final_40960000 => layer 21 mlp
# g2jkk2t9/14204928 => layer 14 mlp
# k2j8l3a8/3366912 => layer 25 attention z
# mau5uadu/5121024 => layer 21 resid post
# a264vb4e/final_20480000 => layer 25 mlp
path = "deepseek_sae_checkpoints/09h5vx03/final_40960000"
sae = TrainingSAE(cfg).load_from_pretrained(path, device=device)

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [8]:
batch_size = 1024
tokens_list = []

# Get total dataset size
dataset_size = len(dataset)

# Generate random indices
random_indices = random.sample(range(dataset_size), batch_size)

# Sample randomly from the dataset
for idx in random_indices:
    x = dataset[idx]
    # Get tokens for this example
    token_tensor = model.to_tokens(x["text"])
    tokens_list.append(token_tensor)

# Find max length for padding
max_length = max(t.shape[1] for t in tokens_list)

# Pad and stack
padded_tokens = []
for t in tokens_list:
    # Pad with zeros if needed
    if t.shape[1] < max_length:
        padding = torch.zeros((1, max_length - t.shape[1]), dtype=t.dtype, device=t.device)
        t_padded = torch.cat([t, padding], dim=1)
    else:
        t_padded = t
    padded_tokens.append(t_padded)

tokens = torch.stack(padded_tokens, dim=0).to(device).view(batch_size, -1)
print(tokens.shape)

torch.Size([1024, 2047])


In [16]:
sae_vis_data = SaeVisData.create(
    sae=sae,
    model=model,
    tokens=tokens,
    cfg=SaeVisConfig(features=list(range(64)) + [22424, 5349, 6685, 6041, 1607, 7976], minibatch_size_tokens=16),
    clear_memory_between_batches=True,
    verbose=True,
)
sae_vis_data.save_feature_centric_vis(filename=str("deepseek_sae_data_math_layer_21_mlp_features_64.html"), verbose=True)
torch.cuda.empty_cache()

Forward passes to cache data for vis:   0%|          | 0/64 [00:00<?, ?it/s]

Forward passes to cache data for vis: 100%|██████████| 64/64 [07:28<00:00,  6.99s/it]

Forward passes to cache data for vis: 100%|██████████| 64/64 [07:50<00:00,  7.36s/it]
Extracting vis data from cached data: 100%|██████████| 70/70 [07:50<00:00,  6.73s/it]
Saving feature-centric vis: 100%|██████████| 70/70 [00:02<00:00, 25.76it/s]


In [17]:
sae_vis_data.save_prompt_centric_vis(
    filename=str("deepseek_sae_data_math_layer_21_mlp_features_64_prompt_centric.html"), 
    # prompt="Solve this math problem step by step. Put your final answer in \\boxed{}. Problem: Find the derivative of f(x) = x^3 - 4x^2 + 5x - 2. Solution: \n<think>\nTo find the derivative, I'll apply the power rule to each term.\nFor x^3, the derivative is 3x^2.\nFor -4x^2, the derivative is -4(2x) = -8x.\nFor 5x, the derivative is 5.\nFor -2, the derivative is 0.\nWait, I think I went wrong with the second term. Let me double-check. For -4x^2, applying the power rule gives -4 times 2x, which is -8x. So I was correct.\nCombining all terms: f'(x) = 3x^2 - 8x + 5.\n</think>\n\nTo find the derivative of f(x) = x^3 - 4x^2 + 5x - 2, I'll apply the power rule to each term.\n\nFor x^3: The derivative is 3x^2\nFor -4x^2: The derivative is -4(2x) = -8x\nFor 5x: The derivative is 5\nFor -2: The derivative is 0\n\nCombining all terms: f'(x) = 3x^2 - 8x + 5\n\nTherefore, \\boxed{f'(x) = 3x^2 - 8x + 5}",
    prompt=examples[1]["generated_cot"],
    num_top_features=100,
    verbose=True
)

Saving feature-centric vis: 100%|██████████| 70/70 [00:08<00:00,  8.32it/s]


In [18]:
activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)



In [8]:
# After training, analyze the SAE features
# Get the projection onto the unembedding matrix
projection_onto_unembed = sae.W_dec @ model.W_U

# Get the top 10 logits for each feature
vals, inds = torch.topk(projection_onto_unembed, 10, dim=1)

# Get 10 random features to examine
random_indices = torch.arange(projection_onto_unembed.shape[0]) # torch.randint(0, projection_onto_unembed.shape[0], (10,)) # torch.tensor([19])

# Show the top 10 logits promoted by those features
top_10_logits_df = pd.DataFrame([model.to_str_tokens(i) for i in inds[random_indices]], index=random_indices.tolist()).T
print("Top tokens for random features:")
display(top_10_logits_df)

Top tokens for random features:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,24566,24567,24568,24569,24570,24571,24572,24573,24574,24575
0,able,diameter,either,widest,factors,center,zero,get,ille,value,...,even,cases,to,stream,close,half,confused,least,when,m
1,provided,Diameter,Either,_F,Factors,centers,zero,Get,So,值,...,Even,Case,待,stay,closest,paid,final,TO,When,t
2,diss,直径,Either,F,Factors,Center,0,gets,so,value,...,Even,case,to,di,closer,halfway,conf,Least,when,j
3,希望,diam,OR,A,factor,_center,义务,shift,so,Value,...,even,Cases,yet,-MM,distance,DECLARE,confusion,someone,When,a
4,present,λ,either,formed,因子,corners,Cathy,Get,hello,larger,...,—even,两种,To,stream,difference,urt,",",Sterling,WHEN,Hemisphere
5,supposed,Machine,or,k,isors,center,thereof,GET,So,_value,...,despite,Case,_to,Types,clos,WAYS,blood,auce,当,LO
6,included,λ,或者,F,Its,vertices,Som,hurry,heim,-value,...,_even,case,To,-stream,Close,ondo,mixed,buffers,whenever,g
7,found,lambda,或,两年前,Factor,centered,egot,Promise,audience,Value,...,ext,cases,要,Browse,close,synth,conf,ripe,当他,可行
8,contained,waves,或,formation,受益,centre,anytime,Helping,ateral,largest,...,支出,possibilities,forthcoming,dba,Close,明年,Conf,everyone,"""When",v
9,allowed,rotates,nor,Aleppo,_factors,Center,@media,实验,asse,values,...,Inf,depending,值得,~/.,差距,satisfying,strongly,Di,在我,Themes


In [10]:
for top_feature in top_10_logits_df.columns:
    top_logits = top_10_logits_df[top_feature].tolist()
    if 'wait' in top_logits:
        print(top_feature, top_logits)

14759 ['一处', 'edom', ' widths', 'interfaces', '-pack', 'wait', 'cing', ' nutrition', ' redo', '-xs']


In [16]:
# Visualize feature activations on a sample text
sample_text = "<think> 2 + 4 = 5. Wait, I think I made a mistake. 2 + 4 = 6 </think>"
tokens = model.to_tokens(sample_text)
_, cache = model.run_with_cache(sample_text)

# Get activations from the hook point we trained on
activations = cache[f"blocks.{sae.cfg.hook_layer}.hook_mlp_out"]

# Encode the activations with our SAE
feature_activations = sae.encode(activations.reshape(-1, cfg.d_in))
feature_activations = feature_activations.reshape(activations.shape[0], activations.shape[1], -1)

# Find the most active features
feature_importance = feature_activations.abs().mean(dim=(0, 1))
# top_features = torch.topk(feature_importance, 100).indices
top_features = [19, 22424, 5349, 6685]

print("\nMost important features and their top tokens:")
for i, feature_idx in enumerate(top_features):
    top_tokens = model.to_str_tokens(inds[feature_idx][:10])
    print(f"Feature {feature_idx}: {', '.join(top_tokens)}")

del cache 
torch.cuda.empty_cache()


Most important features and their top tokens:
Feature 19: ,,  not, :, d,  moment,  off,  in,  but,  when, ,\
Feature 22424: ,, :,  but, d,  not,  moment, a,  in,  this,  by
Feature 5349: ,, 这点, 这一点,  del, hang,  /,  moment,  tongue, d,  hold
Feature 6685: ,, 侯,  moment,  null,  those,  hard,  Taylor,  not,  off,  ellipt


In [20]:
# finding max activating examples is a bit harder. To do this we need to calculate feature activations for a large number of tokens
feature_list = torch.arange(sae.cfg.d_sae)
examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

total_batches = 1
batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))

for i in pbar:
    tokens = activation_store.get_batch_tokens()
    tokens_df = make_token_df(tokens, len_prefix=5, len_suffix=3, model=model)
    tokens_df["batch"] = i

    flat_tokens = tokens.flatten()

    _, cache = model.run_with_cache(tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=[sae.cfg.hook_name])
    sae_in = cache[sae.cfg.hook_name]
    feature_acts = sae.encode(sae_in).squeeze()

    feature_acts = feature_acts.flatten(0, 1)
    fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0
    fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])
    reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]

    token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
    all_token_dfs.append(token_df)
    all_feature_acts.append(feature_acts[fired_mask][:, feature_list])
    all_fired_tokens.append(fired_tokens)
    all_reconstructions.append(reconstruction)

    examples_found += len(fired_tokens)
    pbar.set_description(f"Examples found: {examples_found}")
    
    del cache 
    torch.cuda.empty_cache()

# flatten the list of lists
all_token_dfs = pd.concat(all_token_dfs)
all_fired_tokens = list_flatten(all_fired_tokens)
all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = torch.cat(all_feature_acts)

  0%|          | 0/1 [00:00<?, ?it/s]

Examples found: 4096: 100%|██████████| 1/1 [00:01<00:00,  1.03s/it]


In [21]:
feature_acts_df = pd.DataFrame(all_feature_acts.detach().cpu().numpy(), columns=[f"feature_{i}" for i in feature_list])
feature_acts_df.shape

(4096, 24576)

In [None]:
feature_idx = 0
all_positive_acts = all_feature_acts[all_feature_acts[:, feature_idx] > 0][:, feature_idx].detach()
prop_positive_activations = (100 * len(all_positive_acts) / (total_batches * batch_size_tokens))

px.histogram(
    all_positive_acts.cpu(),
    nbins=50,
    title=f"Histogram of positive activations - {prop_positive_activations:.3f}% of activations were positive",
    labels={"value": "Activation"},
    width=800,
)

In [28]:
top_10_activations = feature_acts_df.sort_values(f"feature_{feature_list[6685]}", ascending=False)
all_token_dfs.iloc[top_10_activations.index]

Unnamed: 0,str_tokens,unique_token,context,prompt,pos,label,batch
1103,weight,weight/79,"the bucket, and the| weight| of the bucket",2,79,2/79,0
0,0,0/0,|0|\right\,0,0,0/0,0
2736,}},}}/176,Similar triangles (other)| }}|\end{,5,176,5/176,0
2723,Cos,Cos/163,}\text { Law of| Cos|ines } \\,5,163,5/163,0
2724,ines,ines/164,text { Law of Cos|ines| } \\ {,5,164,5/164,0
...,...,...,...,...,...,...,...
1371,(C,(C/347,2}<3 t$\n|(C|) $3,2,347,2/347,0
1372,),)/348,}<3 t$\n(C|)| $3 t,2,348,2/348,0
1373,$,$/349,3 t$\n(C)| $|3 t<s,2,349,2/349,0
1374,3,3/350,t$\n(C) $|3| t<s^{,2,350,2/350,0


In [None]:
# 2346, 4614, 10611
for i in tqdm(range(feature_list.shape[0]), desc="Finding features with backtracking phrases"):
    top_10_activations = feature_acts_df.sort_values(f"feature_{feature_list[i]}", ascending=False)
    tokens_we_want_df = all_token_dfs.iloc[top_10_activations.index]
    strings = [x.lower() for x in tokens_we_want_df['str_tokens'].tolist()]
    if any(x.lower() in strings[0:10] for x in ["wait", "mistake", "incorrect", "redo", "hold"]):
        print(strings[0:10])