In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

# from steering.eval_utils import evaluate_completions
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f8b94209f90>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"

sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu'
)

sae6 = sae6.to(device)
normalise_decoder(sae6)

In [4]:
intelligence = sae6.W_dec[10351]   # intelligence and genius
writing = sae6.W_dec[1058]  # writing
anger = sae6.W_dec[1062]  # anger
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  # wedding
broad_wedding = sae6.W_dec[2378] # broad wedding

In [17]:
unsteered_texts = generate(model,
        hooks=[],
        max_new_tokens=25,
        prompt="I think",
        batch_size=64,
        n_samples=1024,
        )

In [24]:

wedding_texts = generate(model,
        hooks=[(hp6, partial(patch_resid, steering=wedding, scale=60))],
        max_new_tokens=25,
        prompt="I think",
        batch_size=64,
        n_samples=1096,
        )

In [19]:

london_texts = generate(model,
        hooks=[(hp6, partial(patch_resid, steering=london, scale=60))],
        max_new_tokens=25,
        prompt="I think",
        batch_size=64,
        n_samples=1024,
        )

In [20]:
london_texts[:100]

["I think that I'm going to have to get my hands on the 1000 block if/when we ever get",
 'I think the idea from the very beginning was to turn the whole city into a tourist attraction, with attractions in every street and even a',
 'I think it would be a good idea to start a series of posts about London in the 1930s. These posts',
 'I think he may have been a bit of a con artist but certainly knew his cricket, as he knew the name of almost every player',
 'I think the last day of class was held in June of 2016. Now the whole thing was a day away from',
 'I think this is an incredible opportunity to talk and compare with other European cities as well. The future should be here and it is just',
 'I think most of us have a love/hate/flirting thing going on with our feet all year and when I say London',
 'I think the one at the back of the building has the shortest line and you can usually go and catchers without having to do a',
 'I think this would be great to see!\nYeah! They should ad

In [9]:
# @torch.no_grad()
# def get_feature_freqs(texts: list[str], model: HookedTransformer, sae: SAE, hook_point: str):
#     all_sae_acts = torch.zeros(sae.cfg.d_sae, device=sae.W_enc.device)
#     count = 0

#     for text in tqdm(texts):
#         _, acts = model.run_with_cache(text, names_filter=hook_point)
#         acts = acts[hook_point]

#         for batch in acts:
#             sae_acts = sae.encode(batch)
#             all_sae_acts += sae_acts.sum(dim=0)
#             count += acts.shape[0]
#     return all_sae_acts / count
    

In [21]:
@torch.no_grad()
def get_feature_freqs(texts: list[str], model: HookedTransformer, sae: SAE, hook_point: str):
    all_sae_acts = torch.zeros(sae.cfg.d_sae, device=sae.W_enc.device)
    count = 0
    for text in tqdm(texts):
        _, acts = model.run_with_cache(text, names_filter=hook_point)
        acts = acts[hook_point]
        for batch in acts:
            sae_acts = sae.encode(batch)
            all_sae_acts += (sae_acts > 10).sum(dim=0)  # Count non-zero elements
            count += acts.shape[0]
    return all_sae_acts / count
    

In [10]:
@torch.no_grad()
def get_logit_distribution(texts: list[str], model: HookedTransformer):
    logit_distribution = torch.zeros(model.cfg.d_vocab, device=model.W_E.device)
    count = 0

    for text in tqdm(texts):
        logits = model.forward(text, return_type='logits') #shape is (batch_size, seq_len, d_vocab)

        logit_distribution += logits.sum(dim=(0, 1))
        count += logits.shape[0] * logits.shape[1]

    return logit_distribution / count

In [25]:
wedding_freqs = get_feature_freqs(wedding_texts, model, sae6, hp6)

100%|██████████| 4096/4096 [03:49<00:00, 17.82it/s]


In [22]:
unsteered_freqs = get_feature_freqs(unsteered_texts, model, sae6, hp6)
# wedding_freqs = get_feature_freqs(wedding_texts, model, sae6, hp6)
london_freqs = get_feature_freqs(london_texts, model, sae6, hp6)

  0%|          | 2/1024 [00:00<00:56, 17.94it/s]

100%|██████████| 1024/1024 [00:57<00:00, 17.82it/s]
100%|██████████| 1024/1024 [00:57<00:00, 17.88it/s]


In [26]:
w_diff = wedding_freqs - unsteered_freqs
top_v, top_i = torch.topk(w_diff, 10, dim=-1)
print(top_i)
print(top_v)

# print('bottom')
# bottom_v, bottom_i = torch.topk(-w_diff, 10, dim=-1)
# print(bottom_i)
# print(bottom_v)

tensor([  613, 15231, 13416, 15208,  1481,  8663,  2378,  9655,  3921,  9699],
       device='cuda:0')
tensor([0.4387, 0.4155, 0.2979, 0.2185, 0.2144, 0.2124, 0.1719, 0.1604, 0.1545,
        0.1533], device='cuda:0')


In [23]:
london_diff = london_freqs - unsteered_freqs
top_v, top_i = torch.topk(london_diff, 10, dim=-1)
print(top_i)
print(top_v)

# print('bottom')
# bottom_v, bottom_i = torch.topk(-london_diff, 10, dim=-1)
# print(bottom_i)
# print(bottom_v)

tensor([15231,   613, 11803,  5002, 10655, 13903, 14812,  6134, 11575,  5108],
       device='cuda:0')
tensor([0.6104, 0.5664, 0.5068, 0.3008, 0.2666, 0.2393, 0.2354, 0.2266, 0.2158,
        0.2139], device='cuda:0')


In [14]:
# unsteered_logits = get_logit_distribution(unsteered_texts, model)
# wedding_logits = get_logit_distribution(wedding_texts, model)
# london_logits = get_logit_distribution(london_texts, model)


In [15]:
# london_logit_diff = london_logits - unsteered_logits
# top_v, top_i = torch.topk(london_logit_diff, 10, dim=-1)
# print(top_i)
# print(model.to_str_tokens(top_i))
# print(top_v)




In [16]:
# wedding_logit_diff = wedding_logits - unsteered_logits
# top_v, top_i = torch.topk(wedding_logit_diff, 10, dim=-1)
# print(top_i)
# print(model.to_str_tokens(top_i))
# print(top_v)
