In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

# from steering.eval_utils import evaluate_completions
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Nothing in this notebook needs gradients: disable autograd globally for all following cells.
# The trailing `;` suppresses the repr of the returned grad-mode object in the cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f655ef0a020>

In [2]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Gemma-2B wrapped as a TransformerLens HookedTransformer so named hook points can be patched later.
model = HookedTransformer.from_pretrained(model_name="gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Residual-stream hook after transformer block 6; also used as the SAE id below.
hp6 = "blocks.6.hook_resid_post"

# Load the pretrained SAE on the CPU first, then move it to the model's device.
# from_pretrained returns a 3-tuple here; only the SAE module itself is kept.
sae6, _, _ = SAE.from_pretrained(
    release="gemma-2b-res-jb",  # other releases: sae_lens/pretrained_saes.yaml
    sae_id=hp6,                 # for this release the SAE id coincides with the hook name
    device="cpu",
)
sae6 = sae6.to(device)

# Project helper from steering.utils; its return value is unused, so it presumably
# rescales sae6's decoder in place — TODO confirm in steering.utils.
normalise_decoder(sae6)

In [4]:
# Steering directions: selected rows of the layer-6 SAE decoder, indexed by hard-coded feature id.
# Each trailing comment records what the feature appears to represent.
intelligence, writing, anger, london, wedding = (
    sae6.W_dec[feature_id]
    for feature_id in (
        10351,  # intelligence and genius
        1058,   # writing
        1062,   # anger
        10138,  # London
        8406,   # wedding
    )
)

In [6]:
# Baseline: 1024 completions of "I think" with no hooks attached.
# The completions vary across samples, so generation is stochastic. Seeding makes
# Restart & Run All reproducible. This assumes `generate` draws from torch's RNG —
# TODO confirm in steering.patch.
torch.manual_seed(0)
unsteered_texts = generate(
    model,
    hooks=[],
    max_new_tokens=25,
    prompt="I think",
    batch_size=64,
    n_samples=1024,
)

In [7]:

# 1024 completions of "I think" with the "wedding" SAE direction patched in at hp6
# (patch_resid presumably adds scale * steering to the residual stream — see steering.patch).
# Re-using the baseline's seed keeps runs reproducible and gives a paired comparison
# with the unsteered samples. This assumes `generate` samples via torch's RNG.
torch.manual_seed(0)
wedding_texts = generate(
    model,
    hooks=[(hp6, partial(patch_resid, steering=wedding, scale=60))],
    max_new_tokens=25,
    prompt="I think",
    batch_size=64,
    n_samples=1024,
)

In [8]:

# 1024 completions of "I think" with the "London" SAE direction patched in at hp6
# (patch_resid presumably adds scale * steering to the residual stream — see steering.patch).
# Re-using the baseline's seed keeps runs reproducible and gives a paired comparison
# with the unsteered samples. This assumes `generate` samples via torch's RNG.
torch.manual_seed(0)
london_texts = generate(
    model,
    hooks=[(hp6, partial(patch_resid, steering=london, scale=60))],
    max_new_tokens=25,
    prompt="I think",
    batch_size=64,
    n_samples=1024,
)

In [13]:
# Qualitative check: display the first 100 London-steered completions.
london_texts[0:100]

['I think with this London, the tube stations and underground stations will be very very busy during the Olympics event in 2012',
 "I think we can all agree that <em>Downton</em>'s second series is a bit patchy. The plot, and the character",
 'I think it needs to be London.\n\nLondon Bridge\n\nGreat, you’ve seen most of the top places to see in the',
 'I think the term I’m looking for is “post-it”. London is a fantastic city with so many museums, art galleries',
 'I think he is one of the best doctors to give his services for the people, also on the other side He is really good to',
 'I think the following is what you mean by the graph of an exponential function\n\n.25e^7, ,-34',
 'I think my mother would have preferred to live with my parents. But at twenty I wasn’t the perfect girl but she wasn’',
 "I think it's all a bit bonkers this week. Soaring gas prices are still with us, but I' has been",
 "I think this is my best work so far! (I'm so glad I saw the other reviews saying if you

In [None]:
@torch.no_grad()
def get_feature_freqs(texts: list[str], model: HookedTransformer, sae: SAE, hook_point: str):
    _, acts = model.run_with_cache(text, names_filter=hook_point)
    acts = acts[hook_point]

    all_sae_acts = []
    for batch in acts:
        sae_acts = sae(batch).feature_acts
        all_sae_acts.append(sae_acts)