**A brief tour of some of the core methods around feature steering**

# Setup

In [1]:
# %pip installs into the running kernel's environment (!pip may target a different Python)
%pip install -q git+https://github.com/jbloomAus/SAELens
%pip install -q transformer_lens datasets tqdm huggingface_hub plotly


In [2]:
import os
import sys
sys.path.append(os.path.abspath('..'))

from typing import Callable, Optional

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SAE

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

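# Gemma is a gated model, so you need a Hugging Face access token (at least on Colab).
# Don't hardcode tokens in notebooks: set HF_TOKEN in the environment, or leave it unset to log in interactively.
from huggingface_hub import login
login(token=os.environ.get("HF_TOKEN"))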

torch.set_grad_enabled(False)


In [None]:
# Load in the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

In [None]:
@torch.no_grad()
def normalise_decoder(sae, scale_input=False):
    """
    Normalises the decoder weights of the SAE to have unit norm, and rescales the
    encoder weights and bias so that the SAE's output is unchanged (this relies on
    the ReLU encoder).

    Use this when loading the Gemma-2B SAEs.

    Args:
        sae (SAE): The sparse autoencoder.
        scale_input (bool): Intended for the layer 12 SAE; currently unused.
    """
    norms = torch.norm(sae.W_dec, dim=1)
    sae.W_dec /= norms[:, None]
    sae.W_enc *= norms[None, :]
    sae.b_enc *= norms
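The rescaling is safe because the encoder ends in a ReLU: multiplying a feature's encoder column and bias by a positive norm multiplies that feature's activation by the same factor, which dividing its decoder row then cancels. The sketch below checks this on a small random stand-in so it runs without downloading Gemma. `ToySAE` and `toy_normalise_decoder` are hypothetical names; the toy only mimics the `W_enc`, `b_enc`, `W_dec`, `b_dec` parameters used above and assumes a plain ReLU encoder.

```python
import torch

# Toy dimensions; the real Gemma-2B SAEs are much larger.
D_MODEL, D_SAE = 8, 32
gen = torch.Generator().manual_seed(0)


class ToySAE:
    """Hypothetical stand-in with the same parameter names and shapes as the SAE above."""

    def __init__(self):
        self.W_enc = torch.randn(D_MODEL, D_SAE, generator=gen)
        self.b_enc = torch.randn(D_SAE, generator=gen)
        # Decoder rows deliberately far from unit norm.
        self.W_dec = 3.0 * torch.randn(D_SAE, D_MODEL, generator=gen)
        self.b_dec = torch.randn(D_MODEL, generator=gen)

    def encode(self, x):
        # Assumption: a plain ReLU encoder.
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, feats):
        return feats @ self.W_dec + self.b_dec


def toy_normalise_decoder(sae):
    # Same rescaling as normalise_decoder above.
    norms = torch.norm(sae.W_dec, dim=1)
    sae.W_dec /= norms[:, None]
    sae.W_enc *= norms[None, :]
    sae.b_enc *= norms


toy_sae = ToySAE()
x = torch.randn(5, D_MODEL, generator=gen)
recon_before = toy_sae.decode(toy_sae.encode(x))
toy_normalise_decoder(toy_sae)
recon_after = toy_sae.decode(toy_sae.encode(x))

# Reconstructions should match, and every decoder row should now have unit norm.
print(torch.allclose(recon_before, recon_after, rtol=1e-4, atol=1e-4))
print(torch.norm(toy_sae.W_dec, dim=1)[:4])
```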

Note: We plan to work with SAEs from multiple layers of Gemma 2B soon. Google DeepMind is expected to release a full set by approximately July 14, 2024.

In [None]:
# Load the layer 6 SAE (there is also a layer 11 SAE, and soon SAEs for all layers)
hp6 = "blocks.6.hook_resid_post"

sae6, _, _ = SAE.from_pretrained(
    release="gemma-2b-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id=hp6,  # the SAE ID is the hook point here, but won't always be
    device="cpu",
)

sae6 = sae6.to(device)
normalise_decoder(sae6)

# Getting activations

In [None]:
@torch.no_grad()
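def text_to_sae_feats(
        model: HookedTransformer,
        sae: SAE,
        hook_point: str,
        text: str,
        return_logits: bool = False,
    ):
    """
    Converts text to SAE feature activations.

    Args:
        model: The model to run.
        sae: The SAE trained on `hook_point`.
        hook_point: Name of the hook whose activations the SAE encodes.
        text: Input text (a BOS token is prepended).
        return_logits: If True, also return the model's logits.

    Returns:
        torch.Tensor: SAE activations. Shape: [batch_size, sequence_len, d_sae]
        (followed by the logits, shape [batch_size, sequence_len, d_vocab], if return_logits=True)
    """
    logits, cache = model.run_with_cache(text, names_filter=hook_point, prepend_bos=True)
    acts = cache[hook_point]  # [batch_size, sequence_len, d_model]

    # Encode one sequence at a time, then stack back into a batch
    sae_acts = torch.stack([sae.encode(batch) for batch in acts], dim=0)

    if return_logits:
        return sae_acts, logits
    return sae_acts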


@torch.no_grad()
def top_activations(activations: torch.Tensor, top_k: int=10):
    """
    Get the top_k activations for each position in the sequence.

    Returns:
        top_v (torch.Tensor): Top k activations. Shape: [batch_size, sequence_len, top_k]
        top_i (torch.Tensor): Top k indices. Shape: [batch_size, sequence_len, top_k]
    """
    top_v, top_i = torch.topk(activations, top_k, dim=-1)

    return top_v, top_i

Example usage:

`text_to_sae_feats` gets the raw feature activations at every token position.

`top_activations` gets, for each position, the strongest features sorted by activation.
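To read the output of `top_activations` token by token, it helps to pair each position with its firing features. `describe_top_features` below is a hypothetical convenience, not part of SAE Lens; in the notebook you would pass token strings from something like TransformerLens's `model.to_str_tokens(text)` (assuming BOS is prepended, to line up with `prepend_bos=True` above). Here it runs on random activations so it works stand-alone.

```python
import torch


def describe_top_features(top_v, top_i, str_tokens, batch_idx=0):
    """Hypothetical helper: list (token, [(feature_id, activation), ...]) per position.

    top_v, top_i: outputs of top_activations, shape [batch, seq_len, top_k].
    str_tokens: one string per position in the sequence.
    """
    rows = []
    for pos, token in enumerate(str_tokens):
        feats = [
            (int(f), round(float(v), 3))
            for f, v in zip(top_i[batch_idx, pos], top_v[batch_idx, pos])
            if v > 0  # skip features that did not fire
        ]
        rows.append((token, feats))
    return rows


# Stand-alone demo on random "activations" (1 sequence, 3 positions, 16 features).
gen = torch.Generator().manual_seed(0)
fake_acts = torch.relu(torch.randn(1, 3, 16, generator=gen))
demo_v, demo_i = torch.topk(fake_acts, 4, dim=-1)
for token, feats in describe_top_features(demo_v, demo_i, ["<bos>", "The", " quick"]):
    print(repr(token), feats)
```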

In [None]:
sae_acts = text_to_sae_feats(model, sae6, hp6, "The quick brown fox jumps over the lazy dog.")
print(sae_acts)

top_v, top_i = top_activations(sae_acts)

print(top_i)

In [None]:
import torch.nn.functional as F

# Normalise each decoder row (one per feature) to unit length
normalized_embeddings = F.normalize(sae6.W_dec, p=2, dim=1)

# Cosine similarity between every pair of features: [d_sae, d_sae].
# Stays on the SAE's device (the original .cuda() call failed on CPU-only machines).
cos_sim_matrix = normalized_embeddings @ normalized_embeddings.T

def find_similar_features(target_idx, top_k=10, return_rows=True, cos_sim_matrix=cos_sim_matrix):
    """
    Finds the top-k features most similar to a target feature in a cosine similarity matrix.

    Args:
        target_idx (int): The index of the target feature.
        top_k (int, optional): The number of most similar features to return. Default is 10.
        return_rows (bool, optional): If True, use the target's row; if False, its column.
            The cosine similarity matrix is symmetric, so both give the same result.
        cos_sim_matrix (torch.Tensor, optional): The cosine similarity matrix. Defaults to the one computed above.

    Returns:
        torch.Tensor: The indices of the top-k most similar features. The target itself
            normally comes first, since its similarity with itself is 1.
        torch.Tensor: The corresponding similarity scores.
    """
    similarities = cos_sim_matrix[target_idx] if return_rows else cos_sim_matrix[:, target_idx]
    topk_similarities, topk_indices = torch.topk(similarities, k=top_k, largest=True, sorted=True)
    return topk_indices, topk_similarities

In [None]:
topk_indices, topk_similarities = find_similar_features(10351)  # 10351 is the intelligence feature (1062 is anger)

print(topk_indices)


# # If you want to calculate the top-k neighbours of every feature at once and save them to files.
# # torch.topk along dim=1 gives the same result as calling find_similar_features on each row.
# similarities_tensor, indices_tensor = torch.topk(cos_sim_matrix, k=10, dim=1)

# # Save the indices and similarity values to .pt files
# torch.save(indices_tensor, 'cosine_sim_indices.pt')
# torch.save(similarities_tensor, 'cosine_sim_values.pt')



Running cosine similarity on feature `1062` (anger) surfaces feature `15989`, an abusive-language feature (https://www.neuronpedia.org/gemma-2b/6-res-jb/15989).

Feature `10351` (intelligence) surfaces `5827` (smart) and `4982` ([adept, skilled](https://www.neuronpedia.org/gemma-2b/6-res-jb/4982)).

Cosine similarity can surface related features, but the method is hit or miss: for example, it works less well for finding additional wedding features.
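The full similarity matrix for 16,384 features holds 16,384² float32 values, about 1 GiB. If you only need the neighbours of a few features, a lighter option is to compare one normalised decoder row against all the others and drop the feature itself (which always matches with similarity 1). A minimal sketch, assuming a decoder matrix of shape `[d_sae, d_model]` like `sae6.W_dec`; `nearest_decoder_features` is a hypothetical name, and the demo uses a random toy matrix.

```python
import torch
import torch.nn.functional as F


def nearest_decoder_features(W_dec, target_idx, top_k=10):
    """Hypothetical helper: top-k decoder rows by cosine similarity to row target_idx.

    Computes a single row of the similarity matrix instead of the full
    [d_sae, d_sae] matrix, and excludes target_idx itself from the results.
    """
    normed = F.normalize(W_dec, p=2, dim=1)  # unit-norm rows
    sims = normed @ normed[target_idx]  # shape [d_sae]
    sims[target_idx] = float("-inf")  # drop the trivial self-match
    values, indices = torch.topk(sims, k=top_k)
    return indices, values


# Toy check: row 3 is a scaled copy of row 0, so it should be row 0's nearest neighbour.
gen = torch.Generator().manual_seed(0)
W_toy = torch.randn(50, 16, generator=gen)
W_toy[3] = 2.5 * W_toy[0]
idx, vals = nearest_decoder_features(W_toy, target_idx=0, top_k=3)
print(idx.tolist(), vals.tolist())
```

In the notebook the call would be `nearest_decoder_features(sae6.W_dec, 10351)`.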

In [None]:
intelligence = sae6.W_dec[10351]   # intelligence and genius
writing = sae6.W_dec[1058]  # writing
anger = sae6.W_dec[1062]  # anger
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  # wedding
broad_wedding = sae6.W_dec[2378] # broad wedding

# Steering

In [None]:
# This is the simplest patching method:
# add the steering vector at every token position of every sequence in the batch
def patch_resid(resid, hook, steering, scale=1):
    resid[:, :, :] += steering * scale
    return resid


def mix_vectors(feature_pairs, sae=sae6):
    """
    Combines several SAE decoder directions into a single steering vector.

    Args:
        feature_pairs: Iterable of (feature_index, strength) pairs.
        sae: The SAE whose decoder rows are used. Defaults to sae6.

    Returns:
        torch.Tensor: Sum of each decoder row scaled by its strength. Shape: [d_model]
    """
    steer_dec_list = [sae.W_dec[feature, :] * strength for feature, strength in feature_pairs]
    return torch.stack(steer_dec_list, dim=0).sum(dim=0)


@torch.no_grad()
def generate_basic(
    model: HookedTransformer,
    steer: tuple[str, Callable], # includes the steering hook.
    prompt = "",
    n_samples=4,
    batch_size=4,
    max_new_tokens=40,
    top_k=50,
    top_p=0.3,
):
    tokens = model.to_tokens(prompt, prepend_bos=True)
    prompt_batch = tokens.expand(batch_size, -1)

    results = []
    num_batches = (n_samples + batch_size - 1) // batch_size  # Calculate number of batches

    with model.hooks(fwd_hooks=[steer]):
        for _ in tqdm(range(num_batches)):
            batch_results = model.generate(
                prompt_batch,
                prepend_bos=True,
                use_past_kv_cache=True,
                max_new_tokens=max_new_tokens,
                verbose=False,
                top_k=top_k,
                top_p=top_p,
            )
            batch_results = batch_results[:, 1:] # cut bos
            str_results = model.to_string(batch_results)
            results.extend(str_results)
    return results[:n_samples]

@torch.no_grad()
def generate(
    model: HookedTransformer,
    hooks: list[tuple[str, Callable]], # includes the steering hook.
    schedules: Optional[list[tuple[int, int]]] = None,
    prompt = "",
    n_samples=4,
    batch_size=4,
    max_new_tokens=40,
    top_k=50,
    top_p=0.3,
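):
    """
    Generates samples with one or more steering hooks, each active only within a window of token steps.

    Args:
        hooks: (hook_point, hook_fn) pairs, e.g. (hp6, partial(patch_resid, ...)).
        schedules: One (start, end) pair per hook, both inclusive, counted in forward passes
            (step 1 is the pass over the prompt, which produces the first new token).
            None for start or end means "from the beginning" or "until the end".
    """
    if schedules is None:
        schedules = [(None, None) for _ in hooks]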

    token_step = 0

    def counter_hook(resid, hook):
        # keeps track of which token we're up to
        nonlocal token_step
        token_step += 1
        return resid

    def updated_hook(resid, hook, hook_fn, start, end):
        nonlocal token_step
        if token_step >= start and token_step <= end:
            return hook_fn(resid, hook)
        return resid

    new_hooks = []
    for i, (hook_layer, hook_fn) in enumerate(hooks):
        # we modify every hook_fn to only run when the token_step is within the schedule
        start, end = schedules[i]
        start = start if start is not None else 0
        end = end if end is not None else max_new_tokens + 1
        new_hooks.append((hook_layer, partial(updated_hook, start=start, end=end, hook_fn=hook_fn)))
    new_hooks.append(("blocks.0.hook_resid_post", counter_hook))

    tokens = model.to_tokens(prompt, prepend_bos=True)
    prompt_batch = tokens.expand(batch_size, -1)

    results = []
    num_batches = (n_samples + batch_size - 1) // batch_size  # Calculate number of batches

    with model.hooks(fwd_hooks=new_hooks):
        for _ in range(num_batches):
            token_step = 0  # reset the step counter so every batch follows the same schedule
            batch_results = model.generate(
                prompt_batch,
                prepend_bos=True,
                use_past_kv_cache=True,
                max_new_tokens=max_new_tokens,
                verbose=False,
                top_k=top_k,
                top_p=top_p,
            )
            batch_results = batch_results[:, 1:] # cut bos
            str_results = model.to_string(batch_results)
            results.extend(str_results)
    return results[:n_samples]

The following two examples steer with the wedding feature defined above, scaled to 60. In these examples, `generate_basic` and `generate` do the same thing.

However, we are working on improving `generate` to allow steering on multiple features over time, for example steering in one direction for X steps and then in another.

In [None]:
generate_basic(
    model,
    (hp6, partial(patch_resid, steering=wedding, scale=60)),
    "I think",
    batch_size=64,
    n_samples=5,
)

In [None]:
generate(
    model,
    hooks=[(hp6, partial(patch_resid, steering=wedding, scale=60))],
    prompt="I think",
    batch_size=64,
    n_samples=5,
)

**Multi-feature steering**

Pass `mix_vectors` the features you want to steer with, each paired with a scale factor that sets how strongly to steer in that direction.

Optimal scale factors for a variety of feature combinations were studied over a one-week period and documented in this Figma: https://www.figma.com/design/7csDOF7mDg1OiF6Na7dfin/Matt-%2B-Slava-Research-Report?node-id=0-1&t=rFDXzEIFcRQK3DBL-1

This research has some interesting implications. The main one is that, with our current method, you are unlikely to be able to steer with more than two or three features at once and keep the model coherent. The model loses coherence once the total added scale factor exceeds ~80, yet a feature needs to be inserted at a scale factor of at least 30-40 to have a chance of being expressed. Hence, at most two to three features can be inserted at once.
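The rule of thumb above can serve as a quick pre-flight check before calling `mix_vectors`. This is a hypothetical helper: the thresholds (a total budget of about 80 and a per-feature minimum of about 30) are only the rough figures quoted above for the layer 6 SAE, not hard limits, and counting negative scales by their magnitude is an assumption.

```python
def check_steering_budget(feature_pairs, total_budget=80.0, min_scale=30.0):
    """Hypothetical pre-flight check for a list of (feature_index, scale) pairs.

    Returns a list of human-readable warnings; an empty list means the mix is
    within the rough limits described in the text.
    """
    warnings = []
    # Assumption: negative (suppressing) scales count toward the budget by magnitude.
    total = sum(abs(scale) for _, scale in feature_pairs)
    if total > total_budget:
        warnings.append(f"total scale {total:g} exceeds ~{total_budget:g}: expect incoherent text")
    for feature, scale in feature_pairs:
        if abs(scale) < min_scale:
            warnings.append(f"feature {feature} at scale {scale:g} is below ~{min_scale:g}: may not be expressed")
    return warnings


print(check_steering_budget([(10138, 40), (8406, 40)]))  # London + wedding at 40 each → []
print(check_steering_budget([(10138, 40), (8406, 40), (10351, 30)]))  # total 110: over budget
```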

Even at the optimal scale factors, the output still varies a lot and sometimes reflects only one of the features. We are actively working to improve this for both single- and multi-feature steering.

In [None]:
# single feature steering
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  #  wedding

# mixed feature steering
london_wedding = mix_vectors([[10138, 40], [8406, 40]])

In [None]:
# insert at scale factor 1 since the scale factor was added in during mixing
generate(
    model,
    hooks=[(hp6, partial(patch_resid, steering=london_wedding, scale=1))],
    prompt="I think",
    batch_size=64,
    n_samples=10,
)
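Because `patch_resid` simply adds a vector to the residual stream, mixing the features first and inserting at scale 1 yields the same residual as applying one hook per feature at its own scale. The check below uses random tensors in place of `sae6.W_dec` and a real residual stream (an assumption so it runs without the model); `patch_resid` is copied unchanged from above, and `mixed` mirrors what `mix_vectors` computes.

```python
import torch


def patch_resid(resid, hook, steering, scale=1):
    # Same additive patch as above: add the steering vector at every position.
    resid[:, :, :] += steering * scale
    return resid


gen = torch.Generator().manual_seed(0)
W_dec = torch.randn(100, 16, generator=gen)  # stand-in for sae6.W_dec: [d_sae, d_model]
resid = torch.randn(2, 5, 16, generator=gen)  # [batch, seq_len, d_model]
pairs = [(10, 40.0), (42, 40.0)]

# Option A: mix the directions first, then insert once at scale 1.
mixed = torch.stack([W_dec[f] * s for f, s in pairs]).sum(dim=0)
out_mixed = patch_resid(resid.clone(), None, steering=mixed, scale=1)

# Option B: one patch per feature, each at its own scale.
out_separate = resid.clone()
for f, s in pairs:
    out_separate = patch_resid(out_separate, None, steering=W_dec[f], scale=s)

print(torch.allclose(out_mixed, out_separate, atol=1e-4))
```

Mixing first keeps a single hook in the forward pass; separate hooks are easier to schedule independently with `generate`.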

**Scheduling**

Another goal for steering is the ability to steer in multiple directions in sequence.

Our `generate` function allows for this kind of steering by specifying for which range of tokens a particular steering vector should have influence.

To do this, you can pass in multiple hooks along with a list of schedules.

Each schedule is a `(start, end)` pair giving the range of token steps, inclusive, during which its hook steers. `None` for `start` or `end` extends that hook's influence to the very start or end of generation.
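To see which hooks a set of schedules switches on at each step, you can replay the window logic from `generate` without running the model. This sketch copies the defaults used there (a `None` start becomes 0, a `None` end becomes `max_new_tokens + 1`, both bounds inclusive); `active_hooks` is a hypothetical helper, and steps follow `generate`'s counter, which ticks once per forward pass.

```python
def active_hooks(schedules, step, max_new_tokens):
    """Hypothetical helper: indices of hooks whose (start, end) window contains step.

    Mirrors generate(): a None start means 0, a None end means max_new_tokens + 1,
    and both bounds are inclusive.
    """
    active = []
    for i, (start, end) in enumerate(schedules):
        start = start if start is not None else 0
        end = end if end is not None else max_new_tokens + 1
        if start <= step <= end:
            active.append(i)
    return active


# The London-then-wedding schedule from the example below: hook 0 for steps 1-15, hook 1 from 16 on.
schedules = [(1, 15), (16, None)]
for step in (1, 15, 16, 50):
    print(step, active_hooks(schedules, step, max_new_tokens=50))
```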

Because `generate` does not reset the KV cache (production systems would not reset it either), earlier steering vectors keep influencing later text, which makes it harder for later steering directions to take effect. We are working on improving this while keeping generation efficient.

In [None]:
hooks = [
    (hp6, partial(patch_resid, steering=london, scale=70)),
    (hp6, partial(patch_resid, steering=wedding, scale=70)),
]

generate(
    model,
    hooks=hooks,
    schedules=[(1, 15), (16, None)],
    max_new_tokens=50,
    prompt="I think",
)

# Evals

We have developed, and are still developing, several eval methods designed to complement each other.

TODO: document here