In [14]:
import os
import sys
# Put the notebook's parent directory on the import path; presumably this is
# what makes the project-local `steering` package below importable — confirm
# against the repository layout.
sys.path.append(os.path.abspath('..'))

from dataclasses import asdict
from tqdm import tqdm

import plotly.express as px

import torch
from transformer_lens import HookedTransformer
from sae_lens import SparseAutoencoder, ActivationsStore

# Project-local helpers (resolved through the sys.path entry added above).
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder
from steering.patch import generate


# Globally switch off autograd: nothing in this notebook trains, so no
# gradients need to be tracked. As the last expression of the cell, the
# returned context-manager object is echoed as the cell output.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x43e5076a0>

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load gemma-2b as a HookedTransformer directly onto the selected device.
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Hook point on the residual stream after block 6. The same string is passed
# to from_pretrained below as the SAE ID within the release.
hp_6 = "blocks.6.hook_resid_post"

# First argument: the release name. The list of available releases is at
# https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
# Second argument: the specific SAE ID in that release; change it to pick
# another SAE if desired.
sae_6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp_6)

# Called for its side effect on sae_6 (the return value is discarded);
# presumably rescales the decoder weights — confirm in steering.utils.
normalise_decoder(sae_6)

# Activation/token source built from the model and the SAE's own config.
activation_store = ActivationsStore.from_config(model, sae_6.cfg)

Using Ghost Grads.


Resolving data files:   0%|          | 0/23032 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23032 [00:00<?, ?it/s]

In [4]:
# Probe prompt: a single emotion word whose SAE features we want to inspect.
text = "Anger"
# Encode the prompt into SAE feature activations at hp_6 and show the
# strongest ones. The displayed result is a pair of tensors (activation
# values, feature indices) with 10 entries for each of 2 token positions;
# presumably position 0 is a BOS token and position 1 is "Anger" — confirm.
top_activations(text_to_sae_feats(model, sae_6, hook_point=hp_6, text=text))

(tensor([[[24.4932, 23.9630, 23.2544, 21.0681, 20.9372, 20.6655, 20.3199,
           19.8953, 19.0510, 18.4377],
          [56.7609, 14.0818, 12.2185, 11.2966, 10.9838,  9.4307,  8.0595,
            7.1658,  7.1005,  5.8064]]]),
 tensor([[[ 3390, 15881,  5347, 16334,   556,  8704, 11785,  5624,  5396,  6877],
          [ 1062, 12753,  1213, 11968, 12167,  5915,  2491, 15173, 11912, 12312]]]))

In [18]:
# Feature 1062 appears to encode anger: it is the strongest feature at the
# second token position of the "Anger" prompt, with activation ~56.7.
# Scale its decoder row by that activation and add two leading singleton
# axes so the vector has three dimensions.
steering_vec = (sae_6.W_dec[1062] * 56.7)[None, None, :]
generate(model, hp_6, "I went up to my friend", steering_vec, scale=5)




[A
[A
[A
[A
[A
100%|██████████| 5/5 [01:10<00:00, 14.06s/it]


["I went up to my friend's house in his blazing hot is about how every time she goes home she's already angry",
 'I went up to my friend and literally had a temper tantrum when all I want to do is get extreme about data. She',
 "I went up to my friend's place with my anger warpped up in the rage with disses and showed in small amount",
 "I went up to my friend Jordan's house, and I was raging because she blocked my phone numbers, aka her 'best",
 "I went up to my friend's house to watch the outrageous, male-piled-in, testosterone-fueled,"]

In [19]:

def get_tokens(
    activation_store,
    n_batches_to_sample_from: int = 2**13,
    n_prompts_to_select: int = 4096 * 6,
):
    """Draw a random sample of token rows from an activation store.

    Pulls ``n_batches_to_sample_from`` batches through
    ``activation_store.get_batch_tokens()``, concatenates them along dim 0
    and returns rows picked uniformly at random without replacement.

    Args:
        activation_store: Any object exposing ``get_batch_tokens()`` that
            returns a tensor whose first dimension indexes rows (prompts).
        n_batches_to_sample_from: Number of batches to pull; must be >= 1.
        n_prompts_to_select: Upper bound on the number of rows returned.
            If the pooled batches hold fewer rows, all of them are returned.

    Returns:
        A tensor holding the selected rows in random order.

    Raises:
        ValueError: If ``n_batches_to_sample_from`` is smaller than 1
            (previously this surfaced as an opaque error from ``torch.cat``
            on an empty list).
    """
    if n_batches_to_sample_from < 1:
        raise ValueError(
            "n_batches_to_sample_from must be >= 1, "
            f"got {n_batches_to_sample_from}"
        )

    batches = [
        activation_store.get_batch_tokens()
        for _ in tqdm(range(n_batches_to_sample_from))
    ]
    pool = torch.cat(batches, dim=0)

    # One global permutation already yields a uniform random sample, so the
    # former per-batch shuffle (and its no-op `[:batch_size]` slice) is not
    # needed. Truncating the permutation *before* indexing copies only the
    # selected rows instead of the whole pool.
    chosen = torch.randperm(pool.shape[0])[:n_prompts_to_select]
    return pool[chosen]


# Build the prompt pool with the defaults: 2**13 batches are drawn from the
# store and at most 4096 * 6 rows are kept.
all_tokens = get_tokens(activation_store)


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

In [20]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

# Features to build dashboards for. These indices appear in the top-activation
# output for the "Anger" prompt: 3390 leads the first token position, 1062 and
# 12753 lead the second.
# NOTE(review): the `_gpt` suffix on the names in this cell looks like a
# carry-over from another example; the model loaded in this notebook is gemma-2b.
test_feature_idx_gpt = [3390, 1062, 12753]
bs = 8

# Dashboard configuration: same hook point the SAE was loaded for.
feature_vis_config_gpt = SaeVisConfig(
    hook_point=hp_6,
    features=test_feature_idx_gpt,
    batch_size=bs,
    minibatch_size_tokens=128,
    verbose=True,
)

# inference_mode is applied here in addition to the notebook-wide
# torch.set_grad_enabled(False).
with torch.inference_mode():
    sae_vis_data_gpt = SaeVisData.create(
        encoder=sae_6,
        model=model,
        tokens=all_tokens,  # type: ignore
        cfg=feature_vis_config_gpt,
    )

Forward passes to cache data for vis:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/3 [00:00<?, ?it/s]

In [21]:
import os
from IPython.display import FileLink, display

# Output directory for the per-feature HTML dashboards, relative to the
# notebook's working directory.
vis_dir = "feature_vis"
# exist_ok=True creates the directory only when it is missing, replacing the
# separate check-then-create step.
os.makedirs(vis_dir, exist_ok=True)

for idx, feature in enumerate(test_feature_idx_gpt):
    # Skip features whose recorded maximum is zero; presumably these never
    # activated on the sampled tokens, so there is nothing to show — confirm.
    if sae_vis_data_gpt.feature_stats.max[idx] == 0:
        continue
    filename = os.path.join(vis_dir, f"{feature}_feature_vis.html")
    sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    # Render a clickable link to the saved file in the cell output.
    display(FileLink(filename))

Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/3 [00:00<?, ?it/s]