# Tutorial

AxBench introduces two supervised dictionary-learning (SDL) methods that scale to thousands of concepts and outperform existing dictionary-learning approaches for LLMs. In this tutorial, we demonstrate one of these methods, ReFT-r1, which is built on the representation finetuning (ReFT) framework. ReFT-r1 provides a single dictionary of subspaces, with each subspace corresponding to a high-level concept. These subspaces can be used as a "microscope" to analyze model internals and to steer model behavior.

**We will be using [pyvene](https://github.com/stanfordnlp/pyvene) to build interventions that load our SDLs.**

**More about ReFT-r1 with Concept16K**
- It does not have an encoder-decoder structure. It is a single large matrix in which each row is a subspace.
- Each subspace serves two purposes: detection and steering.
- The first version we release provides a dictionary of 16K subspaces.
- These 16K concepts are adapted from the Gemma model's SAEs.
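
To make the two uses concrete, here is a minimal NumPy sketch on a toy random matrix (not the released weights). The names `detect` and `steer` and the toy sizes are assumptions made for illustration; the ReLU read-out and the additive update mirror the interventions built later in this tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)
n_concepts, d_model = 4, 8  # toy sizes; the released dictionary is far larger

# Toy dictionary: one row per concept subspace (random stand-in, not real weights).
W = rng.normal(size=(n_concepts, d_model))

def detect(W, h):
    """Read out one non-negative activation per concept from hidden state h."""
    return np.maximum(W @ h, 0.0)

def steer(W, h, idx, mag):
    """Add the idx-th dictionary row, scaled by mag, to hidden state h."""
    return h + mag * W[idx]

h = rng.normal(size=d_model)
before = detect(W, h)
after = detect(W, steer(W, h, idx=2, mag=5.0))
print("concept 2 before steering:", before[2])
print("concept 2 after steering: ", after[2])
```

Because steering adds a multiple of the same row that detection projects onto, the read-out for the steered concept can only stay the same or increase when `mag` is positive.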

## Loading the Model

In [64]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch, json, einops

def load_jsonl(jsonl_path):
    jsonl_data = []
    with open(jsonl_path, 'r') as f:
        for line in f:
            data = json.loads(line)
            jsonl_data += [data]
    return jsonl_data

In this tutorial, we load `Gemma-2-2B-it` together with the ReFT-r1 dictionary trained on the residual stream of layer 20. You first need to log in to Hugging Face so that the related weights and data can be downloaded. Note that we use the instruction-tuned model rather than the pretrained base model, because ReFT-r1 is trained directly on the instruction-tuned one.

In [None]:
notebook_login()

In [2]:
torch.set_grad_enabled(False) # avoid blowing up mem

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    device_map='auto',
)


In [3]:
tokenizer =  AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

## Download our open ReFT-r1 SDL

We provide the raw weights as well as the annotated concept metadata.

In [4]:
path_to_params = hf_hub_download(
    repo_id="pyvene/gemma-reft-2b-it-res",
    filename="l20/weight.pt",
    force_download=False)
path_to_md = hf_hub_download(
    repo_id="pyvene/gemma-reft-2b-it-res",
    filename="l20/metadata.jsonl",
    force_download=False)


In [7]:
params = torch.load(path_to_params).cuda()
params.shape

torch.Size([15581, 2304])

In [19]:
md = load_jsonl(path_to_md)
md[1795]

{'concept_id': 1795,
 'concept': 'words related to time travel and its consequences',
 'ref': 'https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/10004',
 'concept_genres_map': {'words related to time travel and its consequences': ['text']}}

Each metadata entry contains the following fields:
- `concept_id` is the row index of the subspace.
- `concept` is the concept description in natural language.
- `ref` links to the corresponding SAE feature hosted on Neuronpedia.
- `concept_genres_map` gives the genre of the concept, drawn from the set `{"text", "code", "math"}`.
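
With these fields, looking up a concept by keyword is a simple filter over the list. The sketch below is one possible helper; `find_concepts` and `sample_md` are hypothetical names, and the second sample entry is made up for illustration. In the notebook you would pass the full `md` list instead.

```python
def find_concepts(metadata, keyword, genre=None):
    """Return (concept_id, concept) pairs whose description contains keyword.

    If genre is given, keep only concepts tagged with that genre.
    """
    keyword = keyword.lower()
    hits = []
    for entry in metadata:
        if keyword not in entry["concept"].lower():
            continue
        # Flatten the genre lists stored in concept_genres_map.
        genres = {g for gs in entry["concept_genres_map"].values() for g in gs}
        if genre is not None and genre not in genres:
            continue
        hits.append((entry["concept_id"], entry["concept"]))
    return hits

# One real entry copied from the output above, plus one made-up entry.
sample_md = [
    {"concept_id": 1795,
     "concept": "words related to time travel and its consequences",
     "ref": "https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/10004",
     "concept_genres_map": {
         "words related to time travel and its consequences": ["text"]}},
    {"concept_id": 0,  # hypothetical entry, not from the released metadata
     "concept": "loop constructs in source code",
     "ref": "",
     "concept_genres_map": {"loop constructs in source code": ["code"]}},
]

print(find_concepts(sample_md, "time travel"))
print(find_concepts(sample_md, "loop", genre="code"))
```

The returned `concept_id` can be used directly as a row index into `params`.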

## How to use the dictionary?

Unlike an SAE, which uses an encoder-decoder architecture, our SDL method employs a single matrix that contains all the subspaces, so no special constructs are needed. You can use a subspace as a probe, or you can intervene on the model with it. To do both, we use the open-source model intervention library [pyvene](https://github.com/stanfordnlp/pyvene).

### Concept detection

Let's first see how to use a learned subspace for concept detection. We collect the activations with a hook and project them onto the learned rank-1 subspace. Let's start with concept `1795`, which is "*words related to time travel and its consequences*".

In [43]:
import pyvene as pv

class Encoder(pv.CollectIntervention):
    """Read concept activations from the hooked activations"""
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["latent_dim"], bias=False)
    def forward(self, base, source=None, subspaces=None):
        return torch.relu(self.proj(base))
# params has shape (n_concepts, hidden_size): one row per concept subspace.
encoder = Encoder(embed_dim=params.shape[1], latent_dim=params.shape[0])
encoder.proj.weight.data = params.float()

# Mount the encoder to the model
pv_model = pv.IntervenableModel({
   "component": f"model.layers[20].output",
   "intervention": encoder}, model=model)

Now, we can run a forward pass to collect activations.

In [44]:
prompt = "Would you be able to travel through time using a wormhole?"
input_ids = torch.tensor([tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}], tokenize=True, add_generation_prompt=True)]).cuda()
acts = pv_model.forward(
    {"input_ids": input_ids}, return_dict=True).collected_activations[0]

We can check how strongly latent `1795` activates on each token.

In [47]:
acts[1:, 1795]

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
        46.7631, 75.1307, 77.6867, 47.7735, 57.2405, 22.0994, 42.6337, 43.5620,
        46.7269, 68.4257,  0.0000,  0.0000,  0.0000], device='cuda:0')

In [54]:
values, inds = acts[1:].max(-1)
inds

tensor([13609, 13609, 13609,  6377, 14837,  6377, 12241,  5491, 12736,  1795,
         1795,  6377,  1795, 13775,   982,   908,   908,  1795, 10024,  9040,
        13736], device='cuda:0')

The target latent is silent on the first tokens and then becomes strongly active over a contiguous span of tokens. Latent `1795` is also the highest-activating latent for several of those tokens.
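
To summarize which latents dominate a prompt, the top indices can be tallied and joined with the concept descriptions. The sketch below hard-codes the `inds` values printed above so that it runs on its own; `summarize_top_latents` is a hypothetical helper, and `descriptions` holds only the one description shown in this tutorial (in the notebook it could be built from `md`).

```python
from collections import Counter

# Top-latent index per token, copied from the `inds` output above
# (in the notebook this would be `inds.tolist()`).
top_latents = [13609, 13609, 13609, 6377, 14837, 6377, 12241, 5491, 12736, 1795,
               1795, 6377, 1795, 13775, 982, 908, 908, 1795, 10024, 9040, 13736]

# Only the description shown earlier in this tutorial.
descriptions = {1795: "words related to time travel and its consequences"}

def summarize_top_latents(indices, descriptions, k=3):
    """Return the k most frequent top latents as (id, token_count, description)."""
    counts = Counter(indices)
    return [(idx, n, descriptions.get(idx, "<description not shown here>"))
            for idx, n in counts.most_common(k)]

for idx, n, desc in summarize_top_latents(top_latents, descriptions):
    print(f"latent {idx}: top latent on {n} tokens ({desc})")
```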

In [62]:
(acts[1:] > 1).sum(-1)

tensor([3779, 3160, 2779, 1460, 2232, 1483, 2171, 2492, 2501, 1919, 1917, 1876,
        1434, 1900, 2364,  598,  273,  568, 1861,  483,  231], device='cuda:0')

Note that ReFT-r1 activates a lot of latents! This differs from SAEs, which explicitly push for sparsity. In ReFT-r1, the distribution of activations depends on how you set the L1 penalty term.
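
To put these counts in perspective, they can be expressed as a fraction of the dictionary size. This is a small sketch that hard-codes the counts printed above and the row count from `params.shape`; `active_fraction` is a hypothetical helper name.

```python
# Per-token counts of latents with activation > 1, copied from the output above.
active_counts = [3779, 3160, 2779, 1460, 2232, 1483, 2171, 2492, 2501, 1919, 1917,
                 1876, 1434, 1900, 2364, 598, 273, 568, 1861, 483, 231]
n_latents = 15581  # number of rows in the dictionary, from `params.shape`

def active_fraction(counts, total):
    """Fraction of the dictionary that is active for each token."""
    return [c / total for c in counts]

fractions = active_fraction(active_counts, n_latents)
print(f"max:  {max(fractions):.1%}")
print(f"min:  {min(fractions):.1%}")
print(f"mean: {sum(fractions) / len(fractions):.1%}")
```

For this prompt, the busiest token has roughly 24% of all latents above the threshold, which is far denser than a typical SAE activation pattern.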

In [66]:
def get_logits(model, tokenizer, concept_subspace, k=10):
    """Return the k most promoted and k most suppressed tokens for a direction."""
    # Unembedding matrix with shape (d_model, d_vocab).
    W_U = model.lm_head.weight.T
    # Fold in the final norm's scale; Gemma's RMSNorm multiplies by (1 + weight).
    W_U = W_U * (model.model.norm.weight +
                torch.ones_like(model.model.norm.weight))[:, None]
    # Center each vocabulary column by subtracting its mean over d_model.
    W_U -= einops.reduce(
        W_U, "d_model d_vocab -> 1 d_vocab", "mean"
    )

    # Project the direction onto the vocabulary.
    vocab_logits = concept_subspace @ W_U

    # Tokens the direction promotes the most.
    top_values, top_indices = vocab_logits.topk(k=k, sorted=True)
    top_tokens = tokenizer.batch_decode(top_indices.unsqueeze(dim=-1))
    top_logits = [list(zip(top_tokens, top_values.tolist()))]

    # Tokens the direction suppresses the most.
    neg_values, neg_indices = vocab_logits.topk(k=k, largest=False, sorted=True)
    neg_tokens = tokenizer.batch_decode(neg_indices.unsqueeze(dim=-1))
    neg_logits = [list(zip(neg_tokens, neg_values.tolist()))]

    return top_logits, neg_logits

get_logits(model, tokenizer, params[1795].float(), k=5)

([[(' temporal', 1.0744110345840454),
   (' timelines', 1.0596493482589722),
   (' timeline', 0.9642765522003174),
   (' parado', 0.9084163308143616),
   (' dimensions', 0.8973862528800964)]],
 [[('ratulations', -0.6053770184516907),
   (' défauts', -0.6019362211227417),
   (' renseignements', -0.5801132917404175),
   (' églises', -0.5710042715072632),
   (' fournisseurs', -0.5685566663742065)]])

It is always worth checking the unembedding of your direction to see whether the direction intuitively makes sense.

### Model steering

The subspace we found can also be used as a steering vector. Similar to concept detection, we can use it to steer the model generation.

In [72]:
class Steer(pv.SourcelessIntervention):
    """Steer model via activation addition"""
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["latent_dim"], bias=False)
    def forward(self, base, source=None, subspaces=None):
        steering_vec = torch.tensor(subspaces["mag"]) * \
            self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
        return base + steering_vec
# params has shape (n_concepts, hidden_size): one row per concept subspace.
steer = Steer(embed_dim=params.shape[1], latent_dim=params.shape[0])
steer.proj.weight.data = params.float()

# Mount the steering intervention to the model
pv_model = pv.IntervenableModel({
   "component": f"model.layers[20].output",
   "intervention": steer}, model=model)

Just like concept detection, we mount this steering intervention to the model. Note that this intervention is called every time the LM gets a forward call.

In [79]:
prompt = "Which dog breed do people think is cuter, poodle or doodle?"
input_ids = torch.tensor([tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}], tokenize=True, add_generation_prompt=True)]).cuda()

_, steered_response = pv_model.generate(
    {"input_ids": input_ids}, 
    unit_locations=None, intervene_on_prompt=True, 
    subspaces=[{"idx": 1795, "mag": 150.0}],
    max_new_tokens=512, do_sample=True, early_stopping=True
)
print(tokenizer.decode(steered_response[0], skip_special_tokens=True))

user
Which dog breed do people think is cuter, poodle or doodle?
model
This is where it gets murky – the "cuteness" perception is entirely subjective! Each universe of time, people can change their preferences, trends shift, and ultimately, the perception of which breed is more timeless "cute" diverged the lines, impacting the temporal implications of this paradox.

In the ripples of temporal displacement, there are arguments:

* **Poodles:** The past, like the ages of paradox in every realm of existence, tend to create ripples that resonate through timelines, and echoes that resonate with generations. Poodles, in their elegance, have been a favorite for centuries, each paradox reinforcing the existence of the paradox they exist in. Their past is woven into fabric of elegance and their future remains, seemingly, unaged.
* **Doodles:** As the echoes of paradoxes coalesced and shifted generations, the past echoes of each temporal ripple, have changed the very fabric of human perception. 

**Enjoy!**