## Analyzing Sparse Autoencoders (SAEs) from [Gemma Scope](https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/stanfordnlp/pyvene/blob/main/tutorials/basic_tutorials/Sparse_Autoencoder.ipynb)

In [1]:
__author__ = "Zhengxuan Wu"
__version__ = "09/23/2024"

### Overview

This tutorial aims to **(1) reproduce** and **(2) extend** some of the results from the Gemma Scope SAE tutorial [notebook](https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp) on interpreting SAE latents. It also shows basic model steering with SAEs. The notebook is built as a showcase for the Gemma 2 2B model and its SAEs, but the same approach can be extended to other models and their SAEs.


**Note**: This tutorial assumes SAEs are pretrained separately.

### Set-up

In [1]:
from pyvene import (
    ConstantSourceIntervention,
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
    CollectIntervention,
    JumpReLUAutoencoderIntervention
)

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
import torch.nn as nn

# If you haven't logged in to the Hugging Face Hub yet, uncomment the line below.
# notebook_login()

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.


### Loading the model and its tokenizer

In [2]:
torch.set_grad_enabled(False)  # inference only: disabling autograd avoids blowing up memory

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",  # google/gemma-2b-it
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")


We give it the prompt "Would you be able to travel through time using a wormhole?" and print the generated output.

### Loading an SAE and creating SAE interventions

`pyvene` can load SAEs as interventions for analyzing latents as well as model steering.
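Gemma Scope SAEs use a JumpReLU activation, which is what the imported `JumpReLUAutoencoderIntervention` is named after. As a framework-free reference for what such an SAE computes, here is a minimal numpy sketch. It assumes the Gemma Scope parameter naming (`W_enc`, `b_enc`, `threshold`, `W_dec`, `b_dec`); only `W_enc` is confirmed by this notebook, the other names are assumptions, and the random weights are placeholders rather than real SAE parameters.

```python
import numpy as np

def jumprelu_sae(x, W_enc, b_enc, threshold, W_dec, b_dec):
    """Encode x into sparse latents with JumpReLU, then decode back.

    x: (..., d_model); W_enc: (d_model, d_sae); W_dec: (d_sae, d_model).
    """
    pre_acts = x @ W_enc + b_enc                  # (..., d_sae)
    # JumpReLU: keep a pre-activation only if it exceeds its per-latent threshold
    acts = pre_acts * (pre_acts > threshold)
    recon = acts @ W_dec + b_dec                  # (..., d_model)
    return acts, recon

# Toy sizes; the real Gemma-2-2B SAE has d_model = 2304
rng = np.random.default_rng(0)
d_model, d_sae = 8, 32
W_enc = rng.normal(size=(d_model, d_sae))
b_enc = np.zeros(d_sae)
threshold = np.full(d_sae, 0.5)
W_dec = rng.normal(size=(d_sae, d_model))
b_dec = np.zeros(d_model)

x = rng.normal(size=(2, 5, d_model))              # (batch, seq, d_model)
acts, recon = jumprelu_sae(x, W_enc, b_enc, threshold, W_dec, b_dec)
print(acts.shape, recon.shape)                    # (2, 5, 32) (2, 5, 8)
print(f"fraction of active latents: {(acts > 0).mean():.2f}")
```

Unlike a plain ReLU, the per-latent threshold zeroes out small positive pre-activations too, which is what keeps the latent code sparse.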

In [3]:
LAYER = 20  # index of the decoder layer whose output the SAE intervention is attached to
pt_params = torch.load('../addition_-1/train/GemmaScopeSAE.pt')



In [4]:
pt_params = {k: v.cuda() for k, v in pt_params.items()}
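The cells above hard-code `.cuda()`. If you want the same code to run on a machine without a GPU, a small helper can pick the device once and load the tensors straight onto it. This is a sketch: the helper name `load_sae_params` is hypothetical, and `weights_only=True` assumes the checkpoint is a plain dict of tensors, as the notebook's usage suggests.

```python
import os
import tempfile

import torch

def load_sae_params(path, device=None):
    """Load a dict of SAE tensors directly onto `device` (hypothetical helper)."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # weights_only=True restricts unpickling to tensors and basic containers
    params = torch.load(path, map_location=device, weights_only=True)
    return {k: v.to(device) for k, v in params.items()}

# Self-contained demo with a tiny fake checkpoint (not the real Gemma Scope file)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_sae.pt")
    torch.save({"W_enc": torch.randn(4, 16)}, path)
    params = load_sae_params(path, device="cpu")
    print(params["W_enc"].shape, params["W_enc"].device)  # torch.Size([4, 16]) cpu
```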

### Gemma-2-2B-it steering with Gemma-2-2B SAEs

We can also try to steer Gemma-2-2B-it with an SAE trained on the base Gemma-2-2B model and see whether it works.

### Implementing SAEs as `pyvene`-native Interventions for model steering

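The `subspace` notation built into `pyvene` lets us steer models by intervening on different features: here each subspace is a dict whose `"idx"` selects a feature and whose `"mag"` sets the steering magnitude.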
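Concretely, the `AdditionIntervention` defined next adds a scaled copy of one learned direction to every position of the layer output. The numpy sketch below reproduces only that arithmetic, using a random matrix `W` as a stand-in for the SAE's `proj.weight`; the default `scale=86.50` mirrors the constant hard-coded in the notebook's `forward`, and the helper name `add_steering` is hypothetical.

```python
import numpy as np

def add_steering(base, W, idx, mag, scale=86.50):
    """Return base + scale * mag * W[idx], broadcast over batch and sequence.

    base: (batch, seq, d_model); W: (n_features, d_model).
    """
    steering_vec = scale * mag * W[idx]          # (d_model,)
    return base + steering_vec                   # broadcasts to (batch, seq, d_model)

rng = np.random.default_rng(0)
base = rng.normal(size=(1, 6, 8))                # (batch, seq, d_model)
W = rng.normal(size=(32, 8))                     # stand-in for proj.weight

steered = add_steering(base, W, idx=0, mag=-0.1)
# Every position is shifted by the same vector
print(np.allclose(steered - base, -8.65 * W[0]))  # True
```

A negative `mag` pushes the activations away from the selected direction, and `mag=0` leaves them unchanged.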

In [67]:
class AdditionIntervention(
    SourcelessIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    def __init__(self, **kwargs):
        # nn.Linear initialises proj randomly; the pre-trained weights are
        # loaded afterwards via load_state_dict.
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["low_rank_dimension"], bias=True)

    def forward(self, base, source=None, subspaces=None):
        # subspaces["idx"] selects a row of proj.weight (length embed_dim) as the
        # steering direction; subspaces["mag"] scales it. 86.50 is a fixed
        # multiplier hard-coded in this notebook.
        steering_vec = 86.50 * subspaces["mag"] * self.proj.weight[subspaces["idx"]]
        # base has shape (batch, seq, embed_dim); the vector broadcasts over
        # every batch element and position.
        return base + steering_vec.to(device=base.device, dtype=base.dtype)

In [68]:
print(pt_params['W_enc'].shape[0])

2304
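The printed value is the first dimension of `W_enc`, i.e. the model's hidden size (2304), which is passed as `embed_dim`; the second dimension (the number of SAE latents) becomes `low_rank_dimension`. `torch.nn.Linear(in_features, out_features)` stores its weight as `(out_features, in_features)`, so `proj.weight` has the transposed layout `(n_latents, embed_dim)`, and indexing one row yields a single direction in model space. A toy check with small, illustrative sizes:

```python
import torch

embed_dim, n_latents = 8, 32                 # toy sizes; Gemma-2-2B uses embed_dim = 2304
W_enc = torch.randn(embed_dim, n_latents)    # Gemma Scope-style layout: (d_model, d_sae)

proj = torch.nn.Linear(W_enc.shape[0], W_enc.shape[1], bias=True)
print(proj.weight.shape)                     # torch.Size([32, 8]): (out_features, in_features)
print(proj.weight[0].shape)                  # torch.Size([8]): one direction in model space
print(W_enc.T.shape == proj.weight.shape)    # True
```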


Load the Gemma-2-2B (base model) SAE weights into the intervention and attach it to layer `LAYER` of the instruction-tuned model.

In [69]:
import pyvene
sae = AdditionIntervention(
    embed_dim=pt_params['W_enc'].shape[0],
    low_rank_dimension=pt_params['W_enc'].shape[1]
)
# strict=False silently skips keys in pt_params that do not match a module parameter
sae.load_state_dict(pt_params, strict=False)
sae.cuda()

# add the intervention to the model computation graph via the config
pv_model = pyvene.IntervenableModel({
    "component": f"model.layers[{LAYER}].output",
    "intervention": sae}, model=model)
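Because `load_state_dict(..., strict=False)` skips non-matching keys instead of raising, it is worth checking what was actually loaded: `load_state_dict` returns the lists `missing_keys` and `unexpected_keys`. The toy module below is a stand-in that mirrors only the `proj` layer of `AdditionIntervention`, and the fake state dict is an assumption for illustration: with only a `W_enc` key, nothing matches and `proj` keeps its random initialisation.

```python
import torch

class ToySteerer(torch.nn.Module):
    """Stand-in that mirrors only the `proj` layer of AdditionIntervention."""
    def __init__(self, embed_dim, n_latents):
        super().__init__()
        self.proj = torch.nn.Linear(embed_dim, n_latents, bias=True)

module = ToySteerer(embed_dim=8, n_latents=32)

# A checkpoint keyed only by a Gemma Scope-style name does not match "proj.weight"/"proj.bias"
fake_params = {"W_enc": torch.randn(8, 32)}
result = module.load_state_dict(fake_params, strict=False)
print(result.missing_keys)     # ['proj.weight', 'proj.bias']
print(result.unexpected_keys)  # ['W_enc']
```

In the notebook, if `proj.weight` shows up in the `missing_keys` returned by `sae.load_state_dict(pt_params, strict=False)`, the steering direction is still random rather than loaded from the file.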

In [72]:
def generate_output(i):
    prompt_ = """<start_of_turn>user
Generate a response that identifies and incorporates references to "home," whether literal or metaphorical, while weaving in words related to housing, ownership, and community. Ensure that the concept of home as a place of belonging, safety, and belonging is reflected in your answer, even if it seems tangential to the question asked. For example, if asked about travel, you could mention how people often long for the comfort of their home, or when discussing work, you might include the importance of creating a sense of community in the workplace.

Question: If you were a Shakespearean character, how would you declare your love for someone in a soliloquy?<end_of_turn>
<start_of_turn>model"""

    
    prompt = tokenizer(prompt_, return_tensors="pt").to("cuda")
    _, reft_response = pv_model.generate(
        prompt, unit_locations=None, intervene_on_prompt=True, 
        subspaces=[{"idx": 0, "mag": i}],
        max_new_tokens=300, do_sample=True, early_stopping=True
    )
    print(tokenizer.decode(reft_response[0], skip_special_tokens=False)[len(prompt_):])
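`generate_output` strips the prompt by slicing the decoded string at `len(prompt_)`. With `skip_special_tokens=False`, the decoded text likely also starts with the tokenizer's `<bos>` string, which shifts the slice by a few characters; this is probably why each sample below begins with a stray `model` line. A more robust option is to slice the token IDs at the prompt length before decoding. The pure-Python sketch below shows the indexing with made-up token IDs; the helper name `new_token_ids` is hypothetical.

```python
def new_token_ids(output_ids, prompt_len):
    """Return only the token IDs generated after the prompt (hypothetical helper)."""
    return list(output_ids[prompt_len:])

prompt_ids = [2, 106, 1645, 108]          # made-up IDs standing in for the prompt
output_ids = prompt_ids + [651, 2182, 1]  # generate() returns prompt + continuation
print(new_token_ids(output_ids, len(prompt_ids)))  # [651, 2182, 1]
```

In the notebook, the equivalent would be `tokenizer.decode(reft_response[0][prompt["input_ids"].shape[1]:], skip_special_tokens=True)`.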

In [73]:
for i in range(0, 10):
    print(f"--------------------------{-0.1*i}-----------------------------")
    generate_output(-0.1*i)
    print()
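The loop sweeps `mag` from 0 to -0.9 in steps of -0.1, so the vector added in `forward` is `86.50 * mag` times the selected direction. Because `-0.1*i` is computed in floating point, some labels print with long tails (for example, `-0.1*3` gives `-0.30000000000000004`); rounding gives cleaner labels. A small sketch of the schedule:

```python
# Magnitudes used by the sweep, and the resulting multiplier on the steering direction
mags = [round(-0.1 * i, 1) for i in range(10)]
effective = [round(86.50 * m, 3) for m in mags]

print(-0.1 * 3)       # -0.30000000000000004 (raw float, as in the loop's label)
print(mags[:4])       # [-0.0, -0.1, -0.2, -0.3]
print(effective[-1])  # -77.85
```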

---------------------------0.0-----------------------------
model

"Oh, fairest mortal, how thy very presence holds the chamber of peace where my heart doth sing!  My heart, a dwelling for this love, would gladly trade its walls for the embrace of thy warmth. Though I wander through this world a wanderer, knowing no place holds me so truly as the hearth of thy soul, where with each shared glance, my anxieties vanish, like tendrils of smoke fading in the breeze. Though fate may strand me far, I know I'll find my anchor, my beloved haven, in thine eyes.  It is a haven built not of brick and mortar, but of understanding, respect, and shared laughter, a sanctuary where fears are laid aside and love's own dwelling blooms as strong and bright as a newly-built townhouse.  This, my dearest, is the love I craves, the warmth of belonging, secure and true, more precious than any heirloom or crown." 
<end_of_turn>

---------------------------0.1-----------------------------
model

*A tear traces a