In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

# from steering.eval_utils import evaluate_completions
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Inference-only notebook: turn autograd off globally.
# The trailing ';' stops Jupyter from echoing the returned context-manager object.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f9cd24df0a0>

In [2]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Residual stream after block 6; used both as the SAE id and as the steering hook point.
hp6 = "blocks.6.hook_resid_post"

# Load on CPU first, then move the SAE onto the model's device.
sae6, _, _ = SAE.from_pretrained(
    release="gemma-2b-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id=hp6,  # for this release the id matches the hook name; not true of every release
    device="cpu",
)
sae6 = sae6.to(device)

# NOTE(review): presumably rescales decoder rows to unit norm (see steering.utils) — confirm.
normalise_decoder(sae6)

In [18]:
# Hand-picked SAE feature indices (labels as annotated by the author).
steering_feature_ids = {
    "intelligence": 10351,   # intelligence and genius
    "writing": 1058,         # writing
    "anger": 1062,           # anger
    "london": 10138,         # London
    "wedding": 8406,         # wedding
    "broad_wedding": 2378,   # broad wedding
    "uk": 12090,             # UK
}

# Each direction is the corresponding row of the SAE decoder (basic indexing -> a view).
intelligence = sae6.W_dec[steering_feature_ids["intelligence"]]
writing = sae6.W_dec[steering_feature_ids["writing"]]
anger = sae6.W_dec[steering_feature_ids["anger"]]
london = sae6.W_dec[steering_feature_ids["london"]]
wedding = sae6.W_dec[steering_feature_ids["wedding"]]
broad_wedding = sae6.W_dec[steering_feature_ids["broad_wedding"]]
uk = sae6.W_dec[steering_feature_ids["uk"]]

In [29]:
# Empty prompt: samples are conditioned only on whatever `generate` itself supplies.
# NOTE(review): the saved broad_wedding_texts output begins every sample with "I think",
# so a different prompt was apparently active when that cell ran — re-run top-to-bottom.
prompt = ''

In [30]:
# Baseline samples: no hooks installed.
unsteered_texts = generate(
    model,
    prompt=prompt,
    hooks=[],
    n_samples=1024,
    batch_size=64,
    max_new_tokens=25,
)

In [40]:

# Samples with patch_resid applied at hp6 using the anger direction, scale 60.
anger_texts = generate(
    model,
    prompt=prompt,
    hooks=[(hp6, partial(patch_resid, steering=anger, scale=60))],
    n_samples=1024,
    batch_size=64,
    max_new_tokens=25,
)

In [41]:
# Samples with patch_resid applied at hp6 using the broad-wedding direction, scale 60.
broad_wedding_texts = generate(
    model,
    prompt=prompt,
    hooks=[(hp6, partial(patch_resid, steering=broad_wedding, scale=60))],
    n_samples=1024,
    batch_size=64,
    max_new_tokens=25,
)

In [32]:

# Samples with patch_resid applied at hp6 using the wedding direction, scale 60.
wedding_texts = generate(
    model,
    prompt=prompt,
    hooks=[(hp6, partial(patch_resid, steering=wedding, scale=60))],
    n_samples=1024,
    batch_size=64,
    max_new_tokens=25,
)

In [31]:

# Samples with patch_resid applied at hp6 using the London direction.
# Note: scale 80 here, versus 60 for the other features.
london_texts = generate(
    model,
    prompt=prompt,
    hooks=[(hp6, partial(patch_resid, steering=london, scale=80))],
    n_samples=1024,
    batch_size=64,
    max_new_tokens=25,
)

In [26]:
# Samples with patch_resid applied at hp6 using the UK direction, scale 60.
uk_texts = generate(
    model,
    prompt=prompt,
    hooks=[(hp6, partial(patch_resid, steering=uk, scale=60))],
    n_samples=1024,
    batch_size=64,
    max_new_tokens=25,
)

In [36]:
# Skim a handful of baseline samples; 100 strings floods the saved notebook output.
unsteered_texts[:10]

['The second in our series of discussions following the opening of the exhibition.\n\n<strong>The exhibition includes a number of sculptures by the',
 '<h1>About this item</h1>\n\nThis fits your .\n* Make sure this fits by entering your model number.\n* 【',
 '<h1><i>Potamogeton filiformis</i></h1>\n\n<i><b>Potamogeton filiformis</b></i> é uma espécie de',
 'I have an old Panasonic PV-V120SD VCR. All my memories are stored here...I hate to let',
 '<h1><i>Les Cahiers du cinéma</i></h1>\n\n<i><b>Les Cahiers du cinéma</b></i> ([le ˈka.',
 'In a surprise move, Amazon announced Tuesday the closing of its last brick-and-mortar “Amazon Books” bookstore in the',
 '<h2>1. Introduction</h2>\n\nHigh-resolution in-line mass spectrometry (HR-IMS) has an enormous potential for studying',
 '<h1>2012 in football</h1>\n\n<b>2012 in football</b> is a compilation of events in the year',
 '<h1>How to make a textfield automatically change size upon changing characters in SwiftUI?</h1>\n\nI have a custom vie

In [15]:
@torch.no_grad()
def get_feature_freqs(texts: list[str], model: HookedTransformer, sae: SAE, hook_point: str):
    """Average, over processed sequences, of the per-sequence summed SAE activations.

    Despite the name this is not a firing *frequency*. The SAE activation values at
    every token position of a sequence are summed, not counted, and the totals are
    divided by the number of sequences. Longer texts therefore contribute more.

    Args:
        texts: strings to analyse; each one is passed to ``model.run_with_cache`` on its own.
        model: model whose activations are read at ``hook_point``.
        sae: SAE applied to those activations via ``sae.encode``.
        hook_point: name of the hook whose cached activations are encoded.

    Returns:
        Tensor of shape ``(sae.cfg.d_sae,)`` on ``sae.W_enc.device``.

    Raises:
        ValueError: if no sequences were processed (e.g. ``texts`` is empty), instead
            of silently returning NaNs from a 0/0 division.
    """
    all_sae_acts = torch.zeros(sae.cfg.d_sae, device=sae.W_enc.device)
    count = 0

    for text in tqdm(texts):
        # NOTE(review): run_with_cache re-tokenizes `text`; whether a BOS token is prepended
        # (and hence included in the sum) follows the model's defaults — confirm.
        _, cache = model.run_with_cache(text, names_filter=hook_point)
        acts = cache[hook_point]  # presumably (batch, pos, d_model)

        for seq_acts in acts:  # iterate over the leading (batch) dimension
            sae_acts = sae.encode(seq_acts)
            all_sae_acts += sae_acts.sum(dim=0)  # sum over positions

        # Count every sequence exactly once. This used to run inside the loop above,
        # adding acts.shape[0] per sequence (shape[0]**2 per text).
        count += acts.shape[0]

    if count == 0:
        raise ValueError("get_feature_freqs: no sequences were processed (is `texts` empty?)")
    return all_sae_acts / count


@torch.no_grad()
def get_max_feature_freqs(
    texts: list[str],
    model: HookedTransformer,
    sae: SAE,
    hook_point: str,
    top_k: "int | None" = None,
):
    """Per-sequence summed SAE activations, optionally keeping only each token's top-k features.

    For every text, the activations at ``hook_point`` are encoded with ``sae``. The
    accumulated totals are then divided by the number of sequences processed.

    Args:
        texts: strings to analyse; each one is passed to ``model.run_with_cache`` on its own.
        model: model whose activations are read at ``hook_point``.
        sae: SAE applied to those activations via ``sae.encode``.
        hook_point: name of the hook whose cached activations are encoded.
        top_k: controls which activations are accumulated.
            - ``None`` (default): every SAE activation at every position is summed.
            - An int: only the ``top_k`` largest activations at each position are added
              to their features' totals.

    Returns:
        Tensor of shape ``(sae.cfg.d_sae,)`` on ``sae.W_enc.device``.

    Raises:
        ValueError: if ``top_k`` is not None and below 1, or if no sequences were
            processed (e.g. ``texts`` is empty).
    """
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be a positive int or None, got {top_k!r}")

    totals = torch.zeros(sae.cfg.d_sae, device=sae.W_enc.device)
    n_seqs = 0

    for text in tqdm(texts):
        _, cache = model.run_with_cache(text, names_filter=hook_point)
        acts = cache[hook_point]  # presumably (batch, pos, d_model)

        for seq_acts in acts:  # iterate over the leading (batch) dimension
            sae_acts = sae.encode(seq_acts)
            if top_k is None:
                totals += sae_acts.sum(dim=0)
            else:
                top_v, top_i = torch.topk(sae_acts, top_k, dim=-1)
                # index_add_ accumulates correctly when a feature index repeats across positions.
                totals.index_add_(0, top_i.flatten(), top_v.flatten().to(totals.dtype))

        # Count every sequence exactly once (previously counted shape[0] times per text).
        n_seqs += acts.shape[0]

    if n_seqs == 0:
        raise ValueError("get_max_feature_freqs: no sequences were processed (is `texts` empty?)")
    return totals / n_seqs
    

In [None]:
# Feature statistics for the wedding-steered samples.
wedding_freqs = get_feature_freqs(
    texts=wedding_texts, model=model, sae=sae6, hook_point=hp6
)

100%|██████████| 1024/1024 [00:57<00:00, 17.90it/s]


In [32]:
# Baseline feature statistics that the steered runs are compared against.
unsteered_freqs = get_feature_freqs(
    texts=unsteered_texts, model=model, sae=sae6, hook_point=hp6
)

100%|██████████| 1024/1024 [00:57<00:00, 17.88it/s]


In [33]:
# Feature statistics for the London- and UK-steered samples.
london_freqs = get_feature_freqs(texts=london_texts, model=model, sae=sae6, hook_point=hp6)
uk_freqs = get_feature_freqs(texts=uk_texts, model=model, sae=sae6, hook_point=hp6)

100%|██████████| 1024/1024 [00:57<00:00, 17.91it/s]
100%|██████████| 1024/1024 [00:57<00:00, 17.91it/s]


In [34]:
# The 10 features whose statistic increased most under London steering vs. baseline.
london_diff = london_freqs - unsteered_freqs
top_v, top_i = london_diff.topk(10, dim=-1)
print(top_i)
print(top_v)

tensor([10138, 10655, 12090,  4343, 11912,  8922,  9104,  5523, 11444,  3537],
       device='cuda:0')
tensor([44.4625, 10.5478,  8.9528,  8.0921,  7.4763,  6.7775,  5.7417,  5.5936,
         5.4051,  4.6859], device='cuda:0')


In [35]:
# UK: the 10 features whose statistic increased most vs. baseline.
diff = uk_freqs - unsteered_freqs
top_v, top_i = diff.topk(10, dim=-1)
print(top_i)
print(top_v)

tensor([12090, 11912, 10655,  1411,  2813,  4343, 13568, 10138, 13983,    83],
       device='cuda:0')
tensor([8.8876, 6.9210, 6.4022, 6.4021, 6.3709, 5.3240, 5.3157, 3.4528, 3.0663,
        2.6682], device='cuda:0')


In [42]:
# Feature statistics for the anger- and broad-wedding-steered samples.
anger_freqs = get_feature_freqs(texts=anger_texts, model=model, sae=sae6, hook_point=hp6)
broad_wedding_freqs = get_feature_freqs(texts=broad_wedding_texts, model=model, sae=sae6, hook_point=hp6)

100%|██████████| 1024/1024 [00:57<00:00, 17.86it/s]
100%|██████████| 1024/1024 [00:57<00:00, 17.91it/s]


In [38]:
# The 10 features whose statistic increased most under wedding steering vs. baseline.
w_diff = wedding_freqs - unsteered_freqs
top_v, top_i = w_diff.topk(10, dim=-1)
print(top_i)
print(top_v)


tensor([ 8406,  2378,  8663,  8356,  6355,  2314, 10655, 12624,  1945,  2107],
       device='cuda:0')
tensor([5.4229, 4.4328, 2.0368, 1.8041, 1.6394, 1.5383, 1.4845, 1.4366, 1.3488,
        1.3450], device='cuda:0')


In [43]:
# The 10 features whose statistic increased most under anger steering vs. baseline.
diff = anger_freqs - unsteered_freqs
top_v, top_i = diff.topk(10, dim=-1)
print(top_i)
print(top_v)

tensor([ 1062,  9040, 14146,  6355, 10871, 13903,  2107,  8663,  2482, 12624],
       device='cuda:0')
tensor([2.0746, 1.9471, 1.9208, 1.5855, 1.4714, 1.2603, 1.2029, 1.1858, 1.1164,
        1.1077], device='cuda:0')


In [44]:
# The 10 features whose statistic increased most under broad-wedding steering vs. baseline.
diff = broad_wedding_freqs - unsteered_freqs
top_v, top_i = diff.topk(10, dim=-1)
print(top_i)
print(top_v)


tensor([ 2378,  8406, 14841,  8663, 13005, 12624,  8356, 13432,  5190, 15155],
       device='cuda:0')
tensor([5.9250, 2.6099, 1.6168, 1.4911, 1.4295, 1.4248, 1.4066, 1.3861, 1.3248,
        1.2853], device='cuda:0')


In [14]:
# NOTE(review): disabled cell. `get_logit_distribution` is neither defined nor imported
# in the visible notebook, so uncommenting this raises NameError until that helper is
# restored (or this cell is deleted).
# unsteered_logits = get_logit_distribution(unsteered_texts, model)
# wedding_logits = get_logit_distribution(wedding_texts, model)
# london_logits = get_logit_distribution(london_texts, model)


In [15]:
# NOTE(review): disabled cell. It needs `london_logits` / `unsteered_logits`, which are
# only produced by the commented-out get_logit_distribution cell, so it cannot run as-is.
# london_logit_diff = london_logits - unsteered_logits
# top_v, top_i = torch.topk(london_logit_diff, 10, dim=-1)
# print(top_i)
# print(model.to_str_tokens(top_i))
# print(top_v)




In [16]:
# NOTE(review): disabled cell. It needs `wedding_logits` / `unsteered_logits`, which are
# only produced by the commented-out get_logit_distribution cell, so it cannot run as-is.
# wedding_logit_diff = wedding_logits - unsteered_logits
# top_v, top_i = torch.topk(wedding_logit_diff, 10, dim=-1)
# print(top_i)
# print(model.to_str_tokens(top_i))
# print(top_v)


In [45]:
# Skim a handful of broad-wedding-steered samples instead of dumping all of them.
broad_wedding_texts[:10]

['I think the most important elements of the wedding day are the photos and the moments that mean the most to you in life. And I',
 'I think anything is possible as long as you don’t limit all your guests! What is one thing in life you should do as',
 'I think that all brides want to get the fun stuff done and have fun moments on their wedding day so we could say. Our friends',
 'I think my favorite part of your day is when you are having some fun and laughing all the way through. And honestly I really do',
 'I think this is the best day in the world. The fun and beautiful moments that we will look back with excitement. That magical time',
 "I think this is just a silly trend that I fell in love with at my friend's wedding.\nIf you are on Pinterest",
 'I think all the couples that I work with should look flawless (right?).  These images are not a "standard" type of photo',
 'I think that’s a fair figure for most women, and one that allows you to incorporate just about anything and ever

In [75]:
# Inspect, for a few features, the cosine similarity between the feature's
# decoder row and its encoder column.
def _enc_dec_cosine(feature_id):
    dec_row = sae6.W_dec[feature_id]
    enc_col = sae6.W_enc[:, feature_id]
    return dec_row @ enc_col / (dec_row.norm() * enc_col.norm())


london_id = 10138  # London
london_cos = _enc_dec_cosine(london_id)
print(london_cos)

wedding_id = 8406  # wedding
wedding_cos = _enc_dec_cosine(wedding_id)
print(wedding_cos)

anger_id = 1062  # anger
anger_cos = _enc_dec_cosine(anger_id)
print(anger_cos)



tensor(0.2728, device='cuda:0')
tensor(0.2549, device='cuda:0')
tensor(0.3468, device='cuda:0')


In [93]:
# Cosine similarity between decoder row i and encoder column i, for every feature i.
# Both factors are normalised explicitly, so this is a true cosine even if the decoder
# rows are not exactly unit norm. It uses the same formula as the per-feature cell.
normed_encoder = sae6.W_enc / sae6.W_enc.norm(dim=0)  # W_enc: (d_model, n_ft) -> unit columns
normed_decoder = sae6.W_dec / sae6.W_dec.norm(dim=-1, keepdim=True)  # W_dec: (n_ft, d_model) -> unit rows
cosine_sims = einops.einsum(normed_decoder, normed_encoder, "n_ft d_model, d_model n_ft -> n_ft")

print(cosine_sims)

tensor([ 0.5445,  0.3429, -0.2695,  ...,  0.4988,  0.2962,  0.1917],
       device='cuda:0')


In [96]:
# Distribution over SAE features of cos(decoder row, encoder column).
fig = px.histogram(
    cosine_sims.to('cpu'),
    labels={'value': 'cosine sim between encoder and decoder'},
    title="Encoder/decoder cosine similarity per SAE feature",
)
fig.show()