# SAE Intro Project

The following project and notebook are a simple introduction to the Sparse Autoencoder (SAE). I do apologise for any rushed or 'malpractice' code; I have finals rn, haha.
Anyway, the project is split into two parts:
1. A simple SAE implementation from scratch using PyTorch.
2. A new SAE variant.


## Part 1: Simple SAE Implementation

The following chunk of code is just prerequisite setup for the SAE. This is step 1.

In [None]:
# imports and globals setup
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from datasets import load_dataset

# The LLM from which we shall build this SAE
model_name = "EleutherAI/pythia-70m-deduped"

# The dataset whose prompts we shall run through the model to collect the SAE's training activations. Originally I wanted to use "monology/pile-uncopyrighted"
# But this is too huge and I don't have enough space for this, so I will use a smaller dataset, which makes everything faster,
# plus this is a proof-of-concept and this can be easily scaled.
# dataset_name = "monology/pile-uncopyrighted"
dataset_name = "NeelNanda/pile-10k"

# Store activations here, from the layer which we are interested in
# Might be a better way to do this but this is the simplest
stored_activations = []

# This is the layer of the model we are interested in.
# NB 0-based index, so 5 is the 6th block. Pythia-70m only has 6 blocks (indices 0-5),
# so index 5 is actually the final block rather than a middle one; 2 or 3 would be
# closer to the "middle" of the network.
# Choice of layer is kinda arbitrary, however we want to choose something that is neither too deep
# nor too shallow. Middle layers are a good trade-off for interpretability, plus the Anthropic paper
# suggests that these "middle" layers are a decent choice to start.
chosen_layer = 5

# The maximum number of tokens the input is truncated to.
# This is important for VRAM safety, especially on smaller GPUs
MAX_LEN_TRUNC = 128

# This is the coefficient for the sparsity loss term. This is a pseudo-hyperparameter that we can tune.
# It is used to encourage the SAE to learn a sparse representation of the data, as always, this is kinda
# arbitrary to begin with, however we can change this later. This value seems to work well for now.
ENCOURAGEMENT_COEFF = 1e-3

# This is the number of epochs we shall train the SAE for.
# Training is full-batch, so each epoch is a single optimizer step.
SAE_TRAIN_EPOCHS = 1000

# History trackers
loss_history = []
recon_history = []
sparsity_history = []

# This is the hook function; PyTorch calls it right after the hooked module's forward pass,
# so we can "inspect" the layer's output activations
# by simply storing them in the global activations list
def hook_fn(module, input, output):
    stored_activations.append(output.detach().cpu())
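PyTorch's `register_forward_hook` calls the hook as `hook(module, input, output)` after the module's `forward` runs and returns a handle whose `remove()` detaches it. The toy classes below (`ToyModule`, `ToyHandle` and `toy_hook_fn` are hypothetical, pure Python, no PyTorch needed) only mimic that capture-then-remove pattern; they are a sketch of the idea, not how PyTorch implements hooks internally. The names deliberately differ from the notebook's globals so running it won't clobber `stored_activations` or `hook_fn`.

```python
# Toy stand-in for the forward-hook pattern (hypothetical classes, NOT PyTorch).
class ToyHandle:
    """Mimics the handle returned by register_forward_hook: remove() detaches the hook."""
    def __init__(self, hooks, hook_id):
        self._hooks = hooks
        self._hook_id = hook_id

    def remove(self):
        self._hooks.pop(self._hook_id, None)


class ToyModule:
    """A 'layer' that doubles its input, then runs any registered forward hooks."""
    def __init__(self):
        self._hooks = {}
        self._next_id = 0

    def register_forward_hook(self, hook):
        hook_id = self._next_id
        self._next_id += 1
        self._hooks[hook_id] = hook
        return ToyHandle(self._hooks, hook_id)

    def __call__(self, x):
        output = [2 * v for v in x]
        # Like PyTorch, hooks receive (module, positional-args tuple, output) after forward.
        for hook in list(self._hooks.values()):
            hook(self, (x,), output)
        return output


toy_captured = []

def toy_hook_fn(module, inputs, output):
    toy_captured.append(output)

toy_layer = ToyModule()
toy_handle = toy_layer.register_forward_hook(toy_hook_fn)
toy_layer([1, 2, 3])   # hook fires, output captured
toy_handle.remove()
toy_layer([4, 5, 6])   # hook removed, nothing captured
print(toy_captured)    # [[2, 4, 6]]
```

This is also why the notebook calls `hook.remove()` once collection is done: otherwise every later forward pass would keep appending to the global list.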

In [None]:
# Load the dataset and training prompts
print("Loading dataset:", dataset_name)
dataset = load_dataset(dataset_name, split="train", keep_in_memory=True)
print("Dataset loaded")
prompts = [example["text"] for example in dataset]
print("Loaded dataset with", len(prompts), "prompts.")

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
print("Loaded model:", model_name)

# Move model to GPU if available, using MPS cause I have a Mac
# Preference order: MPS (Apple Silicon), then CUDA, then CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device) # print for sanity

# Move model to device and set to eval mode
model.to(device).eval()
print("Set model to eval mode.")

# Choose a layer: here, the MLP output of block `chosen_layer`
target_layer = model.gpt_neox.layers[chosen_layer].mlp
hook = target_layer.register_forward_hook(hook_fn)
print("Registered forward hook on layer:", chosen_layer)

Next, we tokenize the prompts and run them through the model. Specifically, the goal of this step is to run the model in inference mode (no gradients) and, via the forward hook, capture the output of the MLP in the layer we chose earlier. From each prompt we keep a single vector, the activation at the middle token position, and collect these in the `all_activations` list. These vectors are the training data for the SAE.
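The collection loop keeps one vector per prompt: the row at index `seq_len // 2` of the (truncated) `(seq_len, hidden_dim)` activation tensor. A minimal pure-Python sketch of that selection rule, with a hypothetical helper name and nested lists standing in for the tensor:

```python
def middle_token_activation(activ):
    """Return the row at index seq_len // 2 of a (seq_len, hidden_dim) nested list.

    Hypothetical helper mirroring `activ[activ.shape[0] // 2]` in the collection loop.
    """
    if not activ:
        raise ValueError("empty sequence: no tokens to select from")
    return activ[len(activ) // 2]

# Odd length: the exact middle (index 2 of 5).
five_tokens = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1]]
print(middle_token_activation(five_tokens))  # [2.0, 2.1]

# Even length: integer division picks the upper-middle token (index 2 of 4).
four_tokens = five_tokens[:4]
print(middle_token_activation(four_tokens))  # [2.0, 2.1]

# With MAX_LEN_TRUNC = 128, any prompt that hits the truncation limit yields index 64.
print(128 // 2)  # 64
```

Since only one token per prompt is kept, pile-10k yields about 10k training vectors; collecting every token position would give far more data at the cost of memory.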

In [None]:
# Loop through all the prompts and tokenize them, saving the required activations
# Collector for one activation vector per prompt; initialised here so re-running the cell starts fresh
all_activations = []
print("Tokenizing prompts and running through model...")
for prompt in tqdm(prompts, desc="Tokenizing and running prompts"):
    # IMPORTANT TRUNCATE FOR VRAM SAFETY
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_LEN_TRUNC).to(device)

    # This is only a forward pass: we want the layer's activations and are not doing backprop,
    # so we can disable gradient calculation to save memory
    with torch.no_grad():
        # Obliterate the current cache so we only keep the activations from this prompt.
        stored_activations.clear()
        # Forward pass through the model
        model(**inputs)

        # Check if we have activations stored by the forward hook during the forward pass.
        if stored_activations:
            # Grab the middle token from the first (and only) sequence in the batch
            activ = stored_activations[0][0]  # shape: (seq_len, hidden_dim)
            mid_token_index = activ.shape[0] // 2
            middle_activation = activ[mid_token_index]  # shape: (hidden_dim,)
            all_activations.append(middle_activation)
print("Finished running prompts through model. Removing hook")
hook.remove()

Now that we have the activations we are interested in, we can create and start to train the SAE. A relatively simple PyTorch class is enough: an autoencoder with a linear encoder and a linear decoder. The encoder maps each activation vector to a code and applies a ReLU. Here the code is higher-dimensional than the input (1024 units versus the 512-dimensional MLP output of Pythia-70m), i.e. the SAE is overcomplete, so it is the sparsity penalty, not a bottleneck, that forces it to find a compact description. The decoder maps this code back to a reconstruction of the original activations.
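To make the shapes concrete, here is a pure-Python sketch of the same forward pass, `z = ReLU(W_enc x + b_enc)` and `x_hat = W_dec z + b_dec`, for a toy 2-dimensional input and a 4-unit code. The weights are hand-picked for illustration (hypothetical, not trained), and the helper names are made up; in the notebook the input is 512-dimensional and the code has `hidden_dim=1024` units.

```python
def affine(weight, bias, x):
    """y = W x + b, with weight stored as (out_dim, in_dim) like nn.Linear.weight."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
            for row, b_i in zip(weight, bias)]

def relu(v):
    return [max(0.0, a) for a in v]

def toy_sae_forward(x, w_enc, b_enc, w_dec, b_dec):
    """Toy SAE forward pass; returns (x_hat, z) in the same order as SAE.forward."""
    z = relu(affine(w_enc, b_enc, x))   # overcomplete code: len(z) > len(x)
    x_hat = affine(w_dec, b_dec, z)     # reconstruction back in input space
    return x_hat, z

# Hand-picked toy weights: 2-dim input, 4-unit code (2x overcomplete, like 512 -> 1024).
toy_w_enc = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]
toy_b_enc = [0.0, 0.0, 0.0, 0.0]
toy_w_dec = [[1.0, -1.0, 0.0, 0.0], [0.0, 0.0, 1.0, -1.0]]
toy_b_dec = [0.0, 0.0]

toy_x = [3.0, -2.0]
toy_x_hat, toy_z = toy_sae_forward(toy_x, toy_w_enc, toy_b_enc, toy_w_dec, toy_b_dec)
print(toy_z)      # [3.0, 0.0, 0.0, 2.0]
print(toy_x_hat)  # [3.0, -2.0]
```

Only 2 of the 4 code units are non-zero for this input, yet the reconstruction is exact; training aims for the same pattern on real activations, with most units at zero for any given input.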

In [None]:
# Fairly simplistic SAE implementation by deriving from base nn.Module
class SAE(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Affine (nn.Linear) encoder and decoder, as in the standard SAE setup
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # ReLU is the standard activation; variants swap it out (e.g. TopK or JumpReLU SAEs)
        z = torch.relu(self.encoder(x))

        # By convention, the reconstruction is called x_hat
        x_hat = self.decoder(z)
        return x_hat, z

In [None]:
# Stack activations into a single (num_prompts, hidden_dim) tensor, then set up the SAE, optimizer and loss
data = torch.stack(all_activations).to(device)
sae = SAE(input_dim=data.shape[1], hidden_dim=1024).to(device)
optimizer = optim.Adam(sae.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

Now that we have a configured SAE, we want to train it. The training process is relatively simple: we use the MSE loss to measure the reconstruction error between the input and output of the SAE, and add a sparsity term, the mean absolute value of the encoder's output codes (an L1 penalty), scaled by `ENCOURAGEMENT_COEFF`. The L1 term pushes most code units towards zero, which gives the sparse representation we want. Note that the whole dataset is passed as a single batch, so each "epoch" in the loop is one gradient step.
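Written out, the objective in the training cell is `loss = MSE(x, x_hat) + λ * mean(|z|)`, where `nn.MSELoss()` averages the squared error over every element of the batch, `codes.abs().mean()` averages |z| over every code entry, and λ is `ENCOURAGEMENT_COEFF`. A pure-Python sketch of the same arithmetic on a tiny hand-made batch (the helper name and numbers are hypothetical):

```python
def sae_loss(batch, recon, codes, coeff):
    """Return (total, recon_loss, sparsity_loss) using the notebook's reductions:
    nn.MSELoss() -> mean over all elements; codes.abs().mean() -> mean over all entries."""
    sq_errors = [(a - b) ** 2 for row_x, row_r in zip(batch, recon)
                 for a, b in zip(row_x, row_r)]
    recon_loss = sum(sq_errors) / len(sq_errors)
    abs_codes = [abs(c) for row in codes for c in row]
    sparsity_loss = sum(abs_codes) / len(abs_codes)
    return recon_loss + coeff * sparsity_loss, recon_loss, sparsity_loss

# Toy batch of 2 samples, 2 input dims, 4 code units (hypothetical values).
toy_batch = [[1.0, 2.0], [0.0, -1.0]]
toy_recon = [[1.0, 1.0], [0.0, 1.0]]
toy_codes = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 4.0, 2.0]]

toy_total, toy_recon_loss, toy_sparsity_loss = sae_loss(toy_batch, toy_recon, toy_codes, coeff=1e-3)
print(toy_recon_loss)     # 1.25
print(toy_sparsity_loss)  # 1.0
```

With λ = 1e-3 the sparsity term adds only 0.001 to the total here, which is worth keeping in mind when reading the loss plot.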

In [None]:
# Start training the SAE
for epoch in tqdm(range(SAE_TRAIN_EPOCHS), desc="Training model"):
    sae.train()
    optimizer.zero_grad()

    recon, codes = sae(data)
    recon_loss = loss_fn(recon, data)
    sparsity_loss = codes.abs().mean()
    loss = recon_loss + ENCOURAGEMENT_COEFF * sparsity_loss

    loss.backward()
    optimizer.step()

    # Track histories for very nice graph at the end
    loss_history.append(loss.item())
    recon_history.append(recon_loss.item())
    sparsity_history.append(sparsity_loss.item())

    # This is a massive slowdown but lowkey useful for debugging
    # print(f"Epoch {epoch}: Total={loss.item():.4f}, Recon={recon_loss.item():.4f}, Sparsity={sparsity_loss.item():.4f}")

## Part 1 Results

__NB. The following plots are fairly standard matplotlib boilerplate code so I won't explain the actual code too much, just the results.__

### Activations Histogram
This is the histogram of the code activations after training.
The goal for the SAE is a sparse distribution: only a few units should be active for any given input, since the SAE should "disentangle" semantics and be monosemantic in nature, with each unit ideally representing a single concept. Logically, then, the ideal result is for units to be off most of the time, so the distribution should be massively skewed towards 0.
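A log-scale histogram can hide how many codes are exactly zero (ReLU outputs exact zeros), and `range=(0, 1)` silently drops any activation above 1. A simple numeric companion is the fraction of code entries that are exactly zero. The sketch below uses a hypothetical helper on nested lists; on the real array the NumPy equivalent would be `(final_codes == 0).mean()`.

```python
def zero_fraction(codes):
    """Fraction of entries in a (num_samples, num_units) nested list that are exactly 0."""
    flat = [c for row in codes for c in row]
    return sum(1 for c in flat if c == 0.0) / len(flat)

# Hypothetical codes; note 1.3 would fall outside a histogram with range=(0, 1).
toy_final_codes = [
    [0.0, 0.7, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.3],
    [0.2, 0.0, 0.0, 0.0],
]
print(zero_fraction(toy_final_codes))  # 0.75
```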

In [None]:
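sae.eval()
# No gradients needed for inspection
with torch.no_grad():
    _, final_codes = sae(data)
final_codes = final_codes.detach().cpu().numpy()

plt.figure()
# NB: range=(0, 1) drops any activation above 1 from the plot
plt.hist(final_codes.flatten(), bins=100, range=(0, 1), log=True)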
plt.xlabel("Activation Value")
plt.ylabel("Frequency (log scale)")
plt.title("Histogram of SAE Code Activations")
plt.grid(True)
plt.show()

### Sparsity Distribution Per Sample
This is the number of active units per sample. It should look fairly distinctive: one massive spike at a low count, with significantly lower frequencies around it. Only a small number of features should be active at any given time, which indicates sparsity and shows that the SAE is learning a sparse representation of the data.
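Two common summaries go with this plot: the mean L0 (average number of active units per sample) and the dead units (units that never fire on any sample). Below is a pure-Python sketch with hypothetical helper names. The threshold defaults to 0, i.e. a strict non-zero count, whereas the notebook counts values above 1e-3, and the toy data shows how that choice changes the numbers.

```python
def l0_per_sample(codes, threshold=0.0):
    """Number of units strictly above `threshold` in each row of a (num_samples, num_units) list."""
    return [sum(1 for c in row if c > threshold) for row in codes]

def dead_units(codes, threshold=0.0):
    """Indices of units that never exceed `threshold` on any sample."""
    num_units = len(codes[0])
    return [j for j in range(num_units)
            if all(row[j] <= threshold for row in codes)]

# Hypothetical codes: 3 samples, 5 units.
toy_sample_codes = [
    [0.0, 0.7, 0.0, 0.0, 0.0005],
    [0.0, 0.0, 0.0, 1.3, 0.0],
    [0.2, 0.9, 0.0, 0.0, 0.0],
]
toy_l0 = l0_per_sample(toy_sample_codes)
print(toy_l0)                                 # [2, 1, 2]
print(sum(toy_l0) / len(toy_l0))              # mean L0
print(l0_per_sample(toy_sample_codes, 1e-3))  # [1, 1, 2]
print(dead_units(toy_sample_codes))           # [2]
```

With the 1e-3 threshold, unit 4 would also count as dead, so it is worth stating which threshold a reported L0 uses.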

In [None]:
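# Count a unit as "active" if its code exceeds a small threshold.
# NB this threshold is unrelated to ENCOURAGEMENT_COEFF (it just shares its value); use 0 for a strict L0 count
ACTIVE_THRESHOLD = 1e-3
binary_codes = (final_codes > ACTIVE_THRESHOLD).astype(int)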
sparsity_per_sample = binary_codes.sum(axis=1)

plt.figure()
plt.hist(sparsity_per_sample, bins=30)
plt.xlabel("Number of Active Units")
plt.ylabel("Number of Samples")
plt.title("Sparsity Distribution Per Sample")
plt.grid(True)
plt.show()

### Loss History
This is the loss history for the training process (one point per full-batch step). The total and reconstruction losses should decrease over time. Note that the plotted sparsity loss is the unweighted mean |z|: it only enters the total after being multiplied by `ENCOURAGEMENT_COEFF`, so its raw curve can sit above the reconstruction loss without the sparsity term dominating training. To judge whether it dominates, compare the reconstruction loss with `ENCOURAGEMENT_COEFF` times the sparsity loss rather than with the raw curve. As above, the plotting code is standard matplotlib boilerplate.
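A minimal sketch of that weighted comparison, with a hypothetical helper and made-up history values (not real training output): it returns, per step, the fraction of the total loss contributed by the weighted sparsity term.

```python
def loss_shares(recon_history, sparsity_history, coeff):
    """Per-step fraction of the total loss contributed by coeff * sparsity.

    total = recon + coeff * sparsity, matching the notebook's training objective.
    """
    shares = []
    for recon, sparsity in zip(recon_history, sparsity_history):
        weighted = coeff * sparsity
        shares.append(weighted / (recon + weighted))
    return shares

# Made-up histories (NOT real training output): raw sparsity is above the recon loss...
toy_recon_hist = [0.5, 0.2, 0.1]
toy_sparsity_hist = [2.0, 1.0, 1.0]
toy_shares = loss_shares(toy_recon_hist, toy_sparsity_hist, coeff=1e-3)
# ...yet the weighted sparsity term is only roughly 0.4% to 1% of the total.
print([round(s, 4) for s in toy_shares])
```

In the notebook, `loss_shares(recon_history, sparsity_history, ENCOURAGEMENT_COEFF)[-1]` would give the share at the final step.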

In [None]:
plt.figure()
plt.plot(loss_history, label="Total Loss")
plt.plot(recon_history, label="Reconstruction Loss")
plt.plot(sparsity_history, label="Sparsity Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("SAE Training Losses")
plt.legend()
plt.grid(True)
plt.show()

It took me probably like 10-11 hours to get here, 80% of which was reading research and the Anthropic articles. Next, I will build an SAE variant on top of the previous logic/implementation.

## Part 2: SAE Variant

This is my SAE variant. I will be basing it on a XXXXXXXXXXX variant, and the code follows. Very little needs to change from the globals and data above: only a new class has to be created and the training loop changed. The rest is the same.