# SAE Intro Project

This project and notebook are a simple introduction to the Sparse Autoencoder (SAE). I do apologise for any rushed or 'malpractice' code, I have finals rn haha.
Anyway, the project is split into two parts:
1. A simple SAE implementation from scratch using PyTorch.
2. A new SAE variant.


In [None]:
# imports and globals setup
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset

# The LLM whose activations we shall train this SAE on
model_name = "EleutherAI/pythia-70m-deduped"

# The dataset we shall use to prompt the LLM while collecting SAE training data. Originally I wanted to use "monology/pile-uncopyrighted",
# but it is too huge and I don't have enough disk space for it, so I will use a smaller dataset, which makes everything faster.
# Plus this is a proof-of-concept that can easily be scaled up.
# dataset_name = "monology/pile-uncopyrighted"
dataset_name = "NeelNanda/pile-10k"

# Store activations here, from the layer which we are interested in
# Might be a better way to do this but this is the simplest
stored_activations = []

# This is the layer we are interested in, in the model.
# NB 0-based index, so 5 is the 6th layer. pythia-70m only has 6 transformer blocks,
# so index 5 is actually the last block; use 2 or 3 for a genuinely "middle" layer.
# Choice of layer is kinda arbitrary, however we want to choose something that is not too deep
# and not too shallow. Middle layers are a good trade-off for interpretability, plus the Anthropic paper
# suggests that these "middle" layers are a decent choice to start.
chosen_layer = 5

# The maximum number of tokens the input is truncated to.
# This is important for VRAM safety, especially on smaller GPUs
MAX_LEN_TRUNC = 128

# This is the hook function. PyTorch calls it during the forward pass,
# so we can "inspect" the activations of the layer
# by simply storing them in the global stored_activations list
def hook_fn(module, input, output):
    stored_activations.append(output.detach().cpu())
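
To make the hook mechanism concrete, here is a minimal sketch on a toy `nn.Linear` layer rather than the actual Pythia model. The names `captured`, `capture_hook` and `toy_layer` are made up for illustration; only `register_forward_hook` and the `(module, input, output)` signature come from PyTorch itself.

```python
import torch
import torch.nn as nn

captured = []  # hypothetical stand-in for stored_activations

def capture_hook(module, inputs, output):
    # Same signature as hook_fn: PyTorch passes the module, its inputs and its output
    captured.append(output.detach().cpu())

toy_layer = nn.Linear(4, 8)  # toy module, NOT a Pythia layer
handle = toy_layer.register_forward_hook(capture_hook)

with torch.no_grad():
    toy_layer(torch.randn(2, 3, 4))  # (batch, seq_len, features)

handle.remove()  # after this, forward passes no longer trigger the hook
with torch.no_grad():
    toy_layer(torch.randn(2, 3, 4))

print(len(captured), captured[0].shape)  # 1 torch.Size([2, 3, 8])
```

Only the first forward pass is captured, since the handle is removed before the second one. This is the same pattern the notebook uses with `hook.remove()` once collection is done.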

## Part 1: Simple SAE Implementation

The following chunk of code is just prerequisite code for setting up the SAE. This is step 1.

In [None]:
# Load the dataset and training prompts
print("Loading dataset:", dataset_name)
dataset = load_dataset(dataset_name, split="train", keep_in_memory=True)
print("Dataset loaded")
prompts = [example["text"] for example in dataset]
print("Loaded dataset with", len(prompts), "prompts.")

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
print("Loaded model:", model_name)

# Move model to GPU if available, using MPS cause I have a Mac
# Hopefully this should auto-detect which device to use
device = torch.device("mps"
                        if torch.backends.mps.is_available() else "cuda"
                        if torch.cuda.is_available() else "cpu")
print("Using device:", device) # print for sanity

# Move model to device and set to eval mode
model.to(device).eval()
print("Set model to eval mode.")

# Choose a layer: here, the MLP of block `chosen_layer` (index 5)
target_layer = model.gpt_neox.layers[chosen_layer].mlp
hook = target_layer.register_forward_hook(hook_fn)
print("Registered forward hook on layer:", chosen_layer)

Next, we want to tokenize the prompts and run them through the model. Specifically, the goal of the next step
is to run the model in inference mode and store the activations of the layer we are interested in, so that we can
later use these activations to train the SAE. For each prompt we keep only the activation at the middle token, and these are stored in the `all_activations` variable. We grab this data from the kinda arbitrary layer we chose earlier.
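
The collection loop in the next cell keeps a single vector per prompt: the activation at the middle token. Here is a small sketch of just that indexing step, using a made-up tensor (`layer_output`, with 7 tokens and a hidden size of 4) in place of a real hook output.

```python
import torch

# Hypothetical hook output for one prompt: (batch=1, seq_len=7, hidden_dim=4)
layer_output = torch.arange(28, dtype=torch.float32).reshape(1, 7, 4)

activ = layer_output[0]                     # first sequence: (seq_len, hidden_dim)
mid_token_index = activ.shape[0] // 2       # 7 // 2 == 3
middle_activation = activ[mid_token_index]  # (hidden_dim,)

print(mid_token_index, middle_activation)   # 3 tensor([12., 13., 14., 15.])
```

Taking one token per prompt keeps the training set small and avoids lots of highly correlated activations from the same sequence, at the cost of throwing away most of the tokens.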

In [None]:
# Collect one activation vector per prompt here
# (created up front, so re-running this cell starts from a clean list)
all_activations = []

# Loop through all the prompts, tokenize them, and save the required activations
for prompt in prompts:
    # IMPORTANT: TRUNCATE FOR VRAM SAFETY
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_LEN_TRUNC).to(device)

    # This is only a forward pass: we want the activations of the layer and are not doing backprop,
    # so we can disable gradient calculation to save memory
    with torch.no_grad():
        # Obliterate the current cache so we only keep the activations for the current prompt.
        stored_activations.clear()
        # Forward pass through the model (this triggers the forward hook)
        model(**inputs)

        # Check if we have activations stored by the forward hook during the forward pass.
        if stored_activations:
            # Grab middle token from first sequence
            activ = stored_activations[0][0]  # shape: (seq_len, hidden_dim)
            mid_token_index = activ.shape[0] // 2
            middle_activation = activ[mid_token_index]  # shape: (hidden_dim,)
            all_activations.append(middle_activation)

hook.remove()
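
At this point `all_activations` is a plain Python list of 1-D tensors. A likely next step before training is to stack it into one matrix. The sketch below uses random stand-in data (`all_activations_demo`) and assumes a hidden size of 512, which is an assumption here rather than something read from the loaded model.

```python
import torch

# Hypothetical stand-in for all_activations: 5 prompts, assumed hidden_dim of 512
all_activations_demo = [torch.randn(512) for _ in range(5)]

# Stack the per-prompt vectors into one (num_prompts, hidden_dim) matrix
activation_matrix = torch.stack(all_activations_demo)

print(activation_matrix.shape)  # torch.Size([5, 512])
```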