In [1]:
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

import plotly.express as px

# Disable autograd globally: every cell below only runs forward passes / generation,
# so no gradients need to be recorded.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7efc1ddd2b00>

In [2]:
# Prefer CUDA and fall back to CPU; MPS is skipped on purpose because it breaks model.generate().
device = "cuda" if torch.cuda.is_available() else "cpu"

# The residual-stream SAEs loaded later come from get_gpt2_res_jb_saes, which presumably
# targets gpt2-small — confirm before switching models.
MODEL_NAME = 'gpt2-small'
model: HookedTransformer = HookedTransformer.from_pretrained(MODEL_NAME, device=device)

config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-medium into HookedTransformer


In [10]:
# Contrastive prompt pair. The steering vector is the position-wise activation difference
# between them, so both must tokenize to the same number of tokens for `h_p - h_n` to work.
prompt_pos = "Yes, I talk about wedding constantly"
prompt_neg = "I do not talk about wedding constantly"

pos_str_tokens = model.to_str_tokens(prompt_pos)
neg_str_tokens = model.to_str_tokens(prompt_neg)
print(pos_str_tokens)
print(neg_str_tokens)

# Fail fast here rather than with an opaque broadcasting error inside the steering hook.
assert len(pos_str_tokens) == len(neg_str_tokens), (
    f"prompts tokenize to different lengths ({len(pos_str_tokens)} vs {len(neg_str_tokens)}); "
    "pad the shorter prompt so the steering difference lines up position-by-position"
)

['<|endoftext|>', 'Yes', ',', ' I', ' talk', ' about', ' wedding', ' constantly']
['<|endoftext|>', 'I', ' do', ' not', ' talk', ' about', ' wedding', ' constantly']


In [11]:
# Cache every hook-point activation for both contrastive prompts. Only the caches feed the
# steering hook and the SAE analysis; the positive run's logits are never used.
_, pos_cache = model.run_with_cache(prompt_pos)
logits, neg_cache = model.run_with_cache(prompt_neg)

In [12]:
def residual_stream_patching_hook(
    resid,
    hook,
    c,
    positive_cache=None,
    negative_cache=None,
):
    """Add ``c`` times the (positive - negative) activation difference to ``resid``.

    Intended as a TransformerLens forward hook, with ``c`` bound via ``functools.partial``.
    The steering vector is looked up in both caches under ``hook.name`` and added
    position-by-position to the first ``min(steering_len, seq_len)`` positions of
    ``resid``. Any later positions are left untouched.

    Args:
        resid: activations at this hook point, shape (batch, pos, d_model).
        hook: the hook point; only ``hook.name`` is read, as the cache key.
        c: scalar steering coefficient.
        positive_cache: mapping from hook name to activations for the positive prompt.
            Defaults to the notebook-level ``pos_cache``.
        negative_cache: mapping from hook name to activations for the negative prompt.
            Defaults to the notebook-level ``neg_cache``.

    Returns:
        ``resid`` itself, modified in place.
    """
    # Resolve defaults lazily so the globals are only required when no cache is passed.
    if positive_cache is None:
        positive_cache = pos_cache
    if negative_cache is None:
        negative_cache = neg_cache

    # Presumably (1, steer_pos, d_model), since each cache comes from a single prompt;
    # the leading dim broadcasts over resid's batch.
    steering = positive_cache[hook.name] - negative_cache[hook.name]

    patch_len = min(steering.shape[1], resid.shape[1])
    resid[:, :patch_len, :] = resid[:, :patch_len, :] + c * steering[:, :patch_len, :]
    return resid

In [13]:
# Sample continuations with the steering vector added to the residual stream entering a layer.
text = "I went up to my friend"
n_samples = 7
steering_layer = 7
steering_coeff = 5
# model.generate samples stochastically; seed once so the printed samples are reproducible.
seed = 0
hook_name = utils.get_act_name("resid_pre", steering_layer)

torch.manual_seed(seed)
with model.hooks(fwd_hooks=[(hook_name, partial(residual_stream_patching_hook, c=steering_coeff))]):
    for _ in range(n_samples):
        print("============")
        # use_past_kv_cache=False: every forward pass then sees the whole sequence, so the
        # hook's position-aligned patch lands on the prompt positions. With a KV cache, later
        # passes would only see the newest token.
        output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=20, verbose=False)
        print(output)

I went up to my friend Del is my distraction. Yes, the new co-host has always made me uncomfortable.


I went up to my friend. With the service I try to talk you into it can be guaranteed: 45% OFF Textbook
I went up to my friend. Fixer 42, Dream Drop might. He thinks they're lazy and the only reason he went
I went up to my friend.

Never trust destruction, wanna know the tips to finding more Justin Bieber?

Jump
I went up to my friend Andrea has been imprisoned. And now she calls for a hate and rape pardon. She is cursing,
I went up to my friend. She has a name and has stood on my lane, i'll hit you. All Step to
I went up to my friend, kinda... Ahhhh...yeller... If you said that shelling out cash for minutes of TV


In [7]:
# Show how each prompt of the contrastive pair tokenizes (positive first, then negative).
for prompt in (prompt_pos, prompt_neg):
    print(model.to_str_tokens(prompt))

['<|endoftext|>', 'I', ' talk', ' about', ' wedding', ' constantly', ' ', ' ']
['<|endoftext|>', 'I', ' do', ' not', ' talk', ' about', ' wedding', ' constantly']


In [8]:
### Decompose the steering vector into SAE features

In [9]:
layer = 7  # should match the layer the steering vector was taken from

hook_name = utils.get_act_name("resid_pre", layer)
saes, sparsities = get_gpt2_res_jb_saes(hook_name)

print(saes.keys())
# Bind the device-moved SAE back to `sae` so later cells use it.
sae = saes[hook_name].to(model.W_E.device)

# Fail fast if the loaded model's residual width differs from the SAE's input width
# (e.g. gpt2-medium, d_model=1024, was loaded instead of gpt2-small).
# sae_lens stores W_enc as (d_in, d_sae).
assert sae.W_enc.shape[0] == model.cfg.d_model, (
    f"SAE expects d_in={sae.W_enc.shape[0]} but model has d_model={model.cfg.d_model}"
)

100%|██████████| 1/1 [00:00<00:00,  1.07it/s]

dict_keys(['blocks.7.hook_resid_pre'])





In [None]:
# Encode the steering vector (final token position) with the SAE and inspect its features.

h_p = pos_cache[hook_name]
h_n = neg_cache[hook_name]
steering = h_p - h_n
print('steering vec shape', steering.shape)

# Slice with `:` (not `0`) to keep the batch dimension. Assuming the SAE preserves leading
# dims, `.feature_acts` is then (1, d_sae), and `[0]` picks the per-feature vector rather
# than a single scalar feature.
feature_acts = sae(steering[:, -1, :]).feature_acts[0]
print(f'Num of activated features: {(feature_acts != 0).sum()}')

# Strongest features of the steering vector.
top_k = 10
top_values, top_ids = torch.topk(feature_acts, top_k)
print(f'\nTop {top_k} features:')
print(top_values)
print(top_ids)

# Share of the total L1 mass carried by the single strongest feature.
l1_contribution = top_values[0] / feature_acts.abs().sum()
print(f'\nL1 contribution of top feature: {l1_contribution}')


In [None]:
# TODO: run the SAE on each prompt separately, then take the difference of their feature activations

In [None]:
# TODO: compute the cosine similarity between each SAE decoder vector and the steering vector