# Gemma Scope Tutorial

This is a barebones tutorial on how to use [Gemma Scope](https://huggingface.co/google/gemma-scope), Google DeepMind's suite of Sparse Autoencoders (SAEs) on every layer and sublayer of Gemma 2 2B and 9B. Sparse Autoencoders are an interpretability tool that act like a "microscope" on language model activations. They let us zoom in on dense, compressed activations, and expand them to a larger but sparser and seemingly more interpretable form, which can be a very useful tool when doing interpretability research!

**Learn more:**
* If you want to learn about Gemma Scope without writing any code, check out [this interactive demo](https://neuronpedia.org/gemma-scope) courtesy of [Neuronpedia](https://neuronpedia.org).
* For an overview of Gemma Scope check out [the blog post](https://deepmind.google/discover/blog/gemma-scope-helping-the-safety-community-shed-light-on-the-inner-workings-of-language-models).
* See [the technical report](https://storage.googleapis.com/gemma-scope/gemma-scope-report.pdf) for the technical details.



For illustrative purposes, we begin with a lightweight tutorial that uses as few libraries as possible to outline how Gemma Scope works, and what Sparse Autoencoders are doing. This is deliberately a fairly minimalist tutorial, designed to make clear what is actually going on, but does not model research best practices.

For any serious research with Gemma Scope, **we recommend using the [SAELens](https://jbloomaus.github.io/SAELens/) and [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) libraries**. See [this tutorial](https://colab.research.google.com/github/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb) for how to use [SAELens](https://jbloomaus.github.io/SAELens/) in practice.


## Loading the Model

First, let's load the model. For simplicity we do this straight from [HuggingFace transformers](https://huggingface.co/docs/transformers/en/index), rather than using an interpretability-focused library like [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) or [nnsight](https://nnsight.net/).

In [1]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch

In [4]:
notebook_login()

We load Gemma 2 2B, the smallest model that Gemma Scope supports. We load the base model rather than the chat model, since the base model is what our SAEs were trained on (though the SAEs seem to transfer OK to the chat model). Downloading the model weights requires authenticating with Hugging Face, which is what the `notebook_login()` call above does.

In [5]:
torch.set_grad_enabled(False) # avoid blowing up mem

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map='auto',
)

In [8]:
tokenizer =  AutoTokenizer.from_pretrained("google/gemma-2-2b")

Now that we've loaded the model, let's try running it! We give it the prompt "What is the Capital of India? How many states and what are the total languages are spoken in India?" and print the generated output.

In [9]:
# The input text
prompt = "What is the Capital of India? How many states and what are the total languages are spoken in India?"

# Use the tokenizer to convert it to tokens. Note that this implicitly adds a special "Beginning of Sequence" or <bos> token to the start
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to("cuda")
print(inputs)

# Pass it in to the model and generate text
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

## Loading a Sparse Autoencoder

OK, so we have got Gemma 2 loaded and can sample sensible text from it. Now, let's load one of our SAEs.

Gemma Scope actually contains over four hundred SAEs, but for now we'll just load one, on the residual stream at the end of layer 20 (of 26; layers are indexed from 0, so this is the 21st layer). This is a fairly late layer, so the model should have had time to form more abstract concepts!

In [10]:
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
    force_download=False,
)

In [11]:
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}

In [12]:
{k:v.shape for k, v in pt_params.items()}

{'W_dec': torch.Size([16384, 2304]),
 'W_enc': torch.Size([2304, 16384]),
 'b_dec': torch.Size([2304]),
 'b_enc': torch.Size([16384]),
 'threshold': torch.Size([16384])}

In [13]:
pt_params["W_enc"].norm(dim=0)

tensor([1.2101, 1.1695, 0.9836,  ..., 1.0630, 0.9997, 1.1070], device='cuda:0')

### Implementing the SAE

We now define the forward pass of the SAE for pedagogical purposes (in practice, we recommend using the implementation in SAELens).

Gemma Scope is a collection of [JumpReLU SAEs](https://arxiv.org/abs/2407.14435). A JumpReLU SAE is like a standard two-layer (one hidden layer) neural network, except that the activation function is a **JumpReLU**: a ReLU with a discontinuous jump.
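To make the "discontinuous jump" concrete, here is a minimal scalar sketch in plain Python, independent of the SAE class in the next cell. The threshold value and the toy pre-activations are made up for illustration, and the sketch assumes a non-negative threshold (the tutorial's `encode` additionally applies a ReLU, which gives the same result in that case).

```python
def relu(x: float) -> float:
    """Standard ReLU: clip negatives to zero."""
    return max(x, 0.0)


def jump_relu(x: float, threshold: float) -> float:
    """Pass x through unchanged if it exceeds the threshold, else output 0.

    Assumes threshold >= 0 (an assumption of this sketch).
    """
    return x if x > threshold else 0.0


threshold = 1.0  # hypothetical value; the real SAE stores one threshold per feature
pre_acts = [-0.5, 0.3, 0.99, 1.01, 2.0]

relu_out = [relu(x) for x in pre_acts]
jump_out = [jump_relu(x, threshold) for x in pre_acts]

print(relu_out)  # [0.0, 0.3, 0.99, 1.01, 2.0]
print(jump_out)  # [0.0, 0.0, 0.0, 1.01, 2.0]
```

Note how the output jumps from 0 straight to 1.01 as the input crosses the threshold, whereas the plain ReLU lets small positive values through.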

In [14]:
import torch.nn as nn
class JumpReLUSAE(nn.Module):
  def __init__(self, d_model, d_sae):
    # Note that we initialise these to zeros because we're loading in pre-trained weights.
    # If you want to train your own SAEs, use a dedicated library (this tutorial recommends SAELens for research).
    super().__init__()
    self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
    self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
    self.threshold = nn.Parameter(torch.zeros(d_sae))
    self.b_enc = nn.Parameter(torch.zeros(d_sae))
    self.b_dec = nn.Parameter(torch.zeros(d_model))

  def encode(self, input_acts):
    pre_acts = input_acts @ self.W_enc + self.b_enc
    mask = (pre_acts > self.threshold)
    acts = mask * torch.nn.functional.relu(pre_acts)
    return acts

  def decode(self, acts):
    return acts @ self.W_dec + self.b_dec

  def forward(self, acts):
    acts = self.encode(acts)
    recon = self.decode(acts)
    return recon

In [15]:
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)

<All keys matched successfully>

### Running the SAE on model activations


Let's first extract some activations from the model at the SAE target site. We'll demonstrate how to do this 'manually' first, using PyTorch hooks. Note that this is not particularly good practice, and it's probably more practical to use a library like TransformerLens to handle hooking the SAE into a model forward pass. But for illustrative purposes, it's useful to see how it's done.

We can gather activations at a site by registering a hook. To keep this local, we wrap it in a function that registers a hook, runs the model while saving the intermediate activation, and then removes the hook. (This is basically what TransformerLens is doing under the hood.)

In [16]:
def gather_residual_activations(model, target_layer, inputs):
  target_act = None
  def gather_target_act_hook(mod, inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    target_act = outputs[0]
    return outputs
  handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
  _ = model.forward(inputs)
  handle.remove()
  return target_act
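
The same register, run, remove pattern can be tried on a toy module without loading Gemma. The sketch below is a variant, not the tutorial's function: `capture_output` and `toy_model` are hypothetical names, and the `try`/`finally` is an addition that removes the hook even if the forward pass raises. Unlike the decoder layer hooked above, whose output is indexed with `[0]`, this toy linear layer returns a plain tensor.

```python
import torch
import torch.nn as nn


def capture_output(module, run_forward):
    """Run `run_forward()` while recording `module`'s output, then remove the hook."""
    captured = {}

    def hook(mod, args, output):
        # Returning None leaves the module's output unchanged.
        captured["output"] = output

    handle = module.register_forward_hook(hook)
    try:
        run_forward()
    finally:
        handle.remove()  # runs even if the forward pass raises
    return captured["output"]


# Toy stand-in for a model: two linear layers on CPU.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)

with torch.no_grad():
    hidden = capture_output(toy_model[0], lambda: toy_model(x))

print(hidden.shape)  # torch.Size([3, 8])
```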

In [17]:
target_act = gather_residual_activations(model, 20, inputs)

In [18]:
sae.cuda()

JumpReLUSAE()

In [19]:
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

In [20]:
1 - torch.mean((recon[:, 1:] - target_act[:, 1:].to(torch.float32)) **2) / (target_act[:, 1:].to(torch.float32).var())

tensor(0.8665, device='cuda:0')

In [21]:
(sae_acts > 1).sum(-1)

tensor([[7017,   41,   64,  154,   67,   43,   66,   65,   64,   92,  101,   88,
           85,   80,   73,   60,  103,   97,   85,   71,   74,   77]],
       device='cuda:0')

It's always worth running sanity checks like these when you do this by hand, to make sure that you haven't hooked the wrong site or missed a scaling factor. Here, our results all look as expected.
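
The reconstruction check in cell [20] computes a fraction-of-variance-explained score: one minus the mean squared error divided by the variance of the original activations, with position 0 dropped via `[:, 1:]`. Below is a minimal plain-Python sketch of the same formula on made-up numbers; the function name is hypothetical, and it uses the sample variance to match the default behaviour of `Tensor.var()`.

```python
from statistics import variance


def fraction_variance_explained(original, reconstruction):
    """1 - MSE(original, reconstruction) / Var(original), over flat lists of floats."""
    mse = sum((o - r) ** 2 for o, r in zip(original, reconstruction)) / len(original)
    # statistics.variance is the sample (N - 1) variance.
    return 1.0 - mse / variance(original)


original = [1.0, 2.0, 3.0, 4.0]  # toy "activations"
perfect = list(original)         # exact reconstruction
noisy = [1.5, 2.0, 3.0, 3.5]     # reconstruction with some error

fve_perfect = fraction_variance_explained(original, perfect)
fve_noisy = fraction_variance_explained(original, noisy)

print(fve_perfect)  # 1.0
print(fve_noisy)    # below 1.0; lower means a worse reconstruction
```

A score near 1 means the SAE reconstructs the activations almost perfectly. A score that is unexpectedly low is a hint that the hook site or a scaling factor is wrong.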

Note that there's a bit of a gotcha here: our SAEs are *NOT* trained on the BOS token, because we found that it tended to be a large outlier and to mess up training. So they tend to give nonsense when applied to it, and we need to be careful not to do this accidentally! We can see this above: the BOS token (position 0) is a total outlier in terms of L0, with 7017 active features versus roughly 40 to 150 at the other positions.
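
One way to avoid the BOS pitfall is to drop position 0 before computing any per-token statistic. The sketch below does this for L0 on a made-up activation table in plain Python. The helper name and the data are hypothetical, and it counts activations greater than 0 (the usual definition of L0), whereas the cell above counts activations greater than 1.

```python
def l0_per_position(acts, skip_bos=True):
    """Count the active (non-zero) features at each token position.

    acts: list of per-position activation lists, shape [seq_len][d_sae].
    """
    start = 1 if skip_bos else 0
    return [sum(1 for a in position if a > 0) for position in acts[start:]]


toy_acts = [
    [3.0, 2.5, 4.1, 1.2],  # position 0: BOS, a dense outlier
    [0.0, 1.7, 0.0, 0.0],
    [0.0, 0.0, 0.0, 2.2],
    [0.9, 0.0, 0.0, 3.0],
]

print(l0_per_position(toy_acts))                  # [1, 1, 2]
print(l0_per_position(toy_acts, skip_bos=False))  # [4, 1, 1, 2]
```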

Let's look at the highest activating features on this input text, on each token position:

In [22]:
values, inds = sae_acts.max(-1)

inds

tensor([[ 6631, 11527,   994,  8684,  6470,   247, 12935,  8284,  2538,  8284,
          4223,  8284,  8112, 15074,  8284, 13131,  3514, 12564, 12935,  8284,
         12935,  4223]], device='cuda:0')

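Each entry is the index of the most strongly activating SAE feature at that token position; feature 8284, for instance, is the top feature at five positions. Any feature index can be looked up on Neuronpedia. As an illustration, [SAE feature 10004](https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/10004) fires on concepts related to time travel (note that it is not among the top features for this prompt). We can visualise it below by embedding the Neuronpedia dashboard in the notebook cell: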


In [25]:
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=10004)
IFrame(html, width=1200, height=600)
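
Taking only the single largest feature per position, as the `max` call above does, hides other strongly active features. A minimal plain-Python sketch of a top-k lookup on one made-up position is shown below; `top_k_features` is a hypothetical helper. On the real tensor, `sae_acts.topk(k, dim=-1)` plays the same role.

```python
import heapq


def top_k_features(position_acts, k=3):
    """Return (feature_index, activation) pairs for the k largest activations."""
    return heapq.nlargest(k, enumerate(position_acts), key=lambda pair: pair[1])


toy_position = [0.0, 4.2, 0.0, 1.5, 0.0, 7.3, 0.3]  # made-up activations for one token
top3 = top_k_features(toy_position, k=3)

print(top3)  # [(5, 7.3), (1, 4.2), (3, 1.5)]
```

Each returned index could then be passed as `feature_idx` to `get_dashboard_html` to inspect that feature.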

### SAELens

We recommend using [SAELens](https://github.com/jbloomAus/SAELens) for research on SAEs. To load this SAE in SAELens, run the following:

In [24]:
!pip install sae-lens

from sae_lens import SAE  # pip install sae-lens

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res-canonical",
    sae_id = "layer_20/width_16k/canonical",
)


Full SAELens documentation, tutorials, etc. can be found at https://github.com/jbloomAus/SAELens