In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

# from steering.eval_utils import evaluate_completions
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f655ef0a020>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"

sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu'
)

sae6 = sae6.to(device)
normalise_decoder(sae6)

In [51]:
intelligence = sae6.W_dec[10351]   # intelligence and genius
writing = sae6.W_dec[1058]  # writing
anger = sae6.W_dec[1062]  # anger
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  # wedding
broad_wedding = sae6.W_dec[2378] # broad wedding

In [42]:
unsteered_texts = generate(model,
        hooks=[],
        max_new_tokens=25,
        prompt="I think",
        batch_size=64,
        n_samples=4096,
        )

In [43]:

wedding_texts = generate(model,
        hooks=[(hp6, partial(patch_resid, steering=wedding, scale=60))],
        max_new_tokens=25,
        prompt="I think",
        batch_size=64,
        n_samples=4096,
        )

In [44]:

london_texts = generate(model,
        hooks=[(hp6, partial(patch_resid, steering=london, scale=60))],
        max_new_tokens=25,
        prompt="I think",
        batch_size=64,
        n_samples=4096,
        )

In [45]:
london_texts[:100]

['I think this is the most beautiful dress you can buy. It has been photographed as a wedding dress but the skirt is a soft material',
 'I think about how I live the majority of my life with a different mind to London, at my home of London for the summer of',
 'I think you should be able to create an example app with this functionality already.\n\n@dmitrivb If we are going to',
 'I think this is a very good film adaptation of the book with Paul At Black which I originally read in my late teens to see what',
 'I think if this were on the floor I was in the pub in an evening wearing this with a leather jacket. I’m always',
 'I think it was in the 40s, or the end of 50s, 70s, and I',
 'I think for something to be called “world famous” it has to be really popular all over the world right? If the London Eye',
 'I think they would all love your son. They have so many activities! We are not very close to London yet, but we visit',
 'I think I have just found what was missing from London, so t

In [46]:
@torch.no_grad()
def get_feature_freqs(texts: list[str], model: HookedTransformer, sae: SAE, hook_point: str):
    all_sae_acts = torch.zeros(sae.cfg.d_sae, device=sae.W_enc.device)
    count = 0

    for text in tqdm(texts):
        _, acts = model.run_with_cache(text, names_filter=hook_point)
        acts = acts[hook_point]

        for batch in acts:
            sae_acts = sae.encode(batch)
            all_sae_acts += sae_acts.sum(dim=0)
            count += acts.shape[0]
    return all_sae_acts / count
    

In [59]:
@torch.no_grad()
def get_logit_distribution(texts: list[str], model: HookedTransformer):
    logit_distribution = torch.zeros(model.cfg.d_vocab, device=model.W_E.device)
    count = 0

    for text in tqdm(texts):
        logits = model.forward(text, return_type='logits') #shape is (batch_size, seq_len, d_vocab)

        logit_distribution += logits.sum(dim=(0, 1))
        count += logits.shape[0] * logits.shape[1]

    return logit_distribution / count

In [47]:
unsteered_freqs = get_feature_freqs(unsteered_texts, model, sae6, hp6)

100%|██████████| 4096/4096 [03:49<00:00, 17.82it/s]


In [48]:
wedding_freqs = get_feature_freqs(wedding_texts, model, sae6, hp6)
london_freqs = get_feature_freqs(london_texts, model, sae6, hp6)

100%|██████████| 4096/4096 [03:50<00:00, 17.81it/s]
100%|██████████| 4096/4096 [03:51<00:00, 17.73it/s]


In [55]:
w_diff = wedding_freqs - unsteered_freqs
top_v, top_i = torch.topk(w_diff, 10, dim=-1)
print(top_i)
print(top_v)

# print('bottom')
# bottom_v, bottom_i = torch.topk(-w_diff, 10, dim=-1)
# print(bottom_i)
# print(bottom_v)

tensor([ 2378,  8406, 13416, 15803, 12360,  2348,  6355,  7775,  1693,  5002],
       device='cuda:0')
tensor([1.2617, 1.1150, 0.7704, 0.6828, 0.6228, 0.5884, 0.5829, 0.5460, 0.5299,
        0.5272], device='cuda:0')


In [56]:
london_diff = london_freqs - unsteered_freqs
top_v, top_i = torch.topk(london_diff, 10, dim=-1)
print(top_i)
print(top_v)

# print('bottom')
# bottom_v, bottom_i = torch.topk(-london_diff, 10, dim=-1)
# print(bottom_i)
# print(bottom_v)

tensor([ 7775, 11851,  2813, 15831, 10138,   628, 10189,  6125, 16027, 15803],
       device='cuda:0')
tensor([0.9190, 0.9004, 0.6796, 0.6154, 0.6036, 0.5798, 0.5624, 0.5493, 0.5294,
        0.5090], device='cuda:0')


In [60]:
unsteered_logits = get_logit_distribution(unsteered_texts, model)
wedding_logits = get_logit_distribution(wedding_texts, model)
london_logits = get_logit_distribution(london_texts, model)


100%|██████████| 4096/4096 [03:48<00:00, 17.91it/s]
100%|██████████| 4096/4096 [03:48<00:00, 17.92it/s]
100%|██████████| 4096/4096 [03:48<00:00, 17.93it/s]
