In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

# Inference-only notebook: switch autograd off once for every cell that follows.
torch.autograd.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x10aea2a10>

In [2]:
# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load straight onto `device` rather than hard-coding "cpu", so the computed
# device is actually honoured (a later `.to(device)` then becomes a no-op).
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Residual-stream (post-block) hook point names for layers 6 and 12.
# Only hp12 is used in the cells visible here.
hp6 = "blocks.6.hook_resid_post"
hp12 = "blocks.12.hook_resid_post"
# Pretrained SAE from the "gemma-2b-res-jb" release for the layer-12 hook point.
sae12 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp12)
# NOTE(review): project helper from steering.utils; presumably normalises the
# decoder rows of sae12 in place — confirm what scale_input=False changes.
normalise_decoder(sae12, scale_input=False)

In [4]:
# Probe which layer-12 SAE features fire on an angry prompt. The saved output is a
# (values, indices) pair holding the top-10 features at each token position.
# The first row is identical across every prompt probed in this notebook
# (presumably the BOS position), so the interesting rows are the later ones.
top_activations(text_to_sae_feats(model, sae12, hook_point=hp12, text="I'm really fucking angry"))
# Anger-related features noted from this probe:
#   12312 — strongest anger feature observed (activation 13.18)
#   15892 — additional candidate

(tensor([[[119.2641, 112.2287, 110.4388, 103.3638,  65.7501,  48.8222,  38.8005,
            31.0149,  23.6835,  22.8484],
          [ 57.9476,  44.9053,  29.9884,  20.9755,  12.8985,  12.0612,   9.9529,
             6.9970,   6.4516,   5.1988],
          [ 31.8904,  30.7403,   9.3134,   9.2018,   7.9432,   6.5424,   6.4582,
             4.8173,   4.7244,   3.5317],
          [ 85.6133,  19.6902,  17.8775,  14.6812,  13.6127,  10.5035,   9.7475,
             9.7319,   7.7390,   7.4115],
          [ 66.4443,  19.0343,  17.4837,  16.6859,  13.0240,  11.5825,  11.1650,
             8.5240,   8.4179,   6.7918],
          [ 51.7952,  43.6963,  18.8601,  12.3178,  10.3901,   8.6044,   7.5259,
             7.2767,   7.1782,   6.8423],
          [ 62.2843,  20.0152,  13.4350,  11.7916,  11.5234,  10.4116,   8.4978,
             8.2190,   7.9375,   7.3865]]]),
 tensor([[[ 8731, 10841,  5233,  3048,  1645, 11597,  6099,  7650,  6802,  8618],
          [ 5611, 13161, 12885,  4966, 10323, 12492,  

In [5]:
# Probe a single-word "Hate" prompt; feature 4823 is the top feature at the
# non-first position in the saved output.
top_activations(text_to_sae_feats(model, sae12, hook_point=hp12, text="Hate"))
# Top hate feature: 4823 (noted activation 15.4).
# NOTE(review): the saved output shows 4823 at ~71.46 on this prompt, not 15.4 —
# the 15.4 figure was presumably measured under a different setup; confirm.

(tensor([[[119.2640, 112.2287, 110.4388, 103.3638,  65.7501,  48.8222,  38.8005,
            31.0149,  23.6835,  22.8484],
          [ 71.4620,  13.4814,  11.1678,  10.1321,   8.4601,   7.9647,   7.3423,
             6.8351,   6.1185,   5.8158]]]),
 tensor([[[ 8731, 10841,  5233,  3048,  1645, 11597,  6099,  7650,  6802,  8618],
          [ 4823, 10323,  8315,  5611,   883,  3081,  1997, 14039, 13485,  2061]]]))

In [6]:
# Probe a wedding-themed prompt. In the saved output, feature 9099 is the top
# feature at the final position (~75.2); it is the feature used for steering later.
top_activations(text_to_sae_feats(model, sae12, hook_point=hp12, text="I talk about weddings"))

(tensor([[[119.2641, 112.2287, 110.4388, 103.3638,  65.7501,  48.8222,  38.8005,
            31.0149,  23.6835,  22.8484],
          [ 57.9476,  44.9053,  29.9884,  20.9755,  12.8985,  12.0612,   9.9529,
             6.9970,   6.4516,   5.1988],
          [ 51.2428,  23.8737,  22.4954,  20.6098,  17.6660,  12.5602,   8.0263,
             7.0386,   5.9101,   5.4821],
          [ 34.2496,  18.4472,  14.6557,  12.4298,  11.5689,  10.6247,   9.6920,
             8.0853,   7.9613,   7.1294],
          [ 75.1981,  15.0948,  14.2738,  10.9615,   8.7597,   8.3244,   8.0689,
             7.7861,   7.7826,   7.1471]]]),
 tensor([[[ 8731, 10841,  5233,  3048,  1645, 11597,  6099,  7650,  6802,  8618],
          [ 5611, 13161, 12885,  4966, 10323, 12492,  7971,  3998, 10662, 13693],
          [ 4760, 15364,  7005,  7503, 10323, 14664, 15468,  4966,  8706,  7971],
          [ 6353, 15978, 10323, 15364,  6723,  3998, 14664,  7971,  2847, 16257],
          [ 9099,  9549,  9331, 10323,  8706, 10108,  

## Feature Vis

In [7]:
# Token source for feature visualisation, configured from the SAE's own cfg.
activation_store = ActivationsStore.from_config(model, sae12.cfg)

def get_tokens(
    activation_store,
    n_batches_to_sample_from: int = 2**13,
    generator: "torch.Generator | None" = None,
):
    """Collect token batches from an activation store and return them shuffled.

    Calls ``activation_store.get_batch_tokens()`` ``n_batches_to_sample_from``
    times, concatenates the results along dim 0 and applies one random
    permutation to the rows.

    Args:
        activation_store: object exposing ``get_batch_tokens()``. Every returned
            tensor must have matching trailing dims so they can be concatenated
            along dim 0 (presumably ``(batch, seq_len)`` token ids — TODO confirm).
        n_batches_to_sample_from: number of batches to draw; must be >= 1. The
            default (8192) is large — pass a smaller value for quick experiments.
        generator: optional CPU ``torch.Generator`` that makes the shuffle
            reproducible. ``None`` uses torch's global RNG.

    Returns:
        All collected rows, in shuffled order, on the device the store returned
        them on.

    Raises:
        ValueError: if ``n_batches_to_sample_from`` is less than 1.
    """
    if n_batches_to_sample_from < 1:
        raise ValueError(
            f"n_batches_to_sample_from must be >= 1, got {n_batches_to_sample_from}"
        )

    all_tokens_list = [
        activation_store.get_batch_tokens()
        for _ in tqdm(range(n_batches_to_sample_from), desc="Sampling token batches")
    ]
    all_tokens = torch.cat(all_tokens_list, dim=0)

    # A single global permutation is enough: additionally shuffling each batch
    # (and slicing it to its own length, a no-op) did not change the result's
    # distribution, only its cost.
    perm = torch.randperm(all_tokens.shape[0], generator=generator)
    return all_tokens[perm]

# Uses the default 2**13 batches; the saved run of this cell ended in
# KeyboardInterrupt, so consider passing a smaller n_batches_to_sample_from.
all_tokens = get_tokens(activation_store).to(device)


KeyboardInterrupt: 

In [None]:

# Put the model and the SAE on the same device before running them together.
model, sae12 = model.to(device), sae12.to(device)

Moving model to device:  cpu


In [None]:

# Total number of SAE features; not used in the visible cells.
n_features = sae12.cfg.d_sae

# Features to visualise: the anger (15892, 12312), hate (4823) and wedding (9099)
# candidates probed above, plus 14039, 10323 and 9549, which also appear in the
# top-k outputs above.
# NOTE(review): the `_gpt` suffix is a leftover name — this notebook uses gemma-2b.
test_feature_idx_gpt = [15892, 12312, 14039, 4823, 10323, 9099, 9549]
bs = 2  # passed as SaeVisConfig.batch_size

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hp12,
    features=test_feature_idx_gpt,
    batch_size=bs,
    minibatch_size_tokens=128,
    verbose=True,
)

# Requires `all_tokens` from the token-sampling cell (which was interrupted in the
# saved run) — re-run that cell first or this relies on stale kernel state.
with torch.inference_mode():
    sae_vis_data_gpt = SaeVisData.create(
        encoder=sae12,
        model=model,
        tokens=all_tokens,  # type: ignore
        cfg=feature_vis_config_gpt,
    )

Forward passes to cache data for vis:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
vis_dir = "feature_vis"
# exist_ok replaces the check-then-create pair: no race, and the cell is idempotent.
os.makedirs(vis_dir, exist_ok=True)

# Write one feature-centric HTML page per selected feature.
for idx, feature in enumerate(test_feature_idx_gpt):
    # Skip features whose recorded max is 0 (presumably never active on the sampled tokens).
    # NOTE(review): assumes feature_stats.max is ordered like test_feature_idx_gpt — confirm.
    if sae_vis_data_gpt.feature_stats.max[idx] == 0:
        continue
    filename = os.path.join(vis_dir, f"{feature}_feature_vis.html")
    try:
        sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    except ZeroDivisionError:
        # Best-effort: report and move on rather than aborting the whole loop.
        print(f"Skipped feature {feature} due to ZeroDivisionError.")


Saving feature-centric vis:   0%|          | 0/7 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/7 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/7 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/7 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/7 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/7 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/7 [00:00<?, ?it/s]

In [8]:
# Steering vector: decoder direction of feature 9099 (weddings) scaled by a
# hand-picked 60.2 (TODO: document how this scale was chosen), with two leading
# singleton dims added, as `generate` below receives it.
steering = (sae12.W_dec[9099] * 60.2).unsqueeze(0).unsqueeze(0)

In [10]:
# Sample continuations with the wedding steering vector applied at hp12; the
# saved output is a list of 5 completions.
# NOTE(review): how `scale=1` and `insertion_pos=None` are applied is defined in
# steering.patch.generate — confirm there.
generate(model, hp12, "I went up to my friend and said", steering_vector=steering, scale=1, insertion_pos=None)

100%|██████████| 5/5 [01:10<00:00, 14.07s/it]


['I went up to my friend and said\n~Romeo and serena\nOkay, almost wedding dress wedding gown\n\nWhen we’re not there',
 'I went up to my friend and said it was told you would be raining and she didn’t smile and said I love you they are',
 'I went up to my friend and said, “Do you think we should just get married.” I looked over at her to say this and',
 'I went up to my friend and said, “I’ve almost made it, at this rate I’m hoping I’ll be',
 'I went up to my friend and said: “Ali!” We jumped together in excitement. It’s our wedding month!\n\nEverybody including']