In [54]:
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
import tqdm
import einops

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

import plotly.express as px

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb12ca8aec0>

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu' # mps will break when using model.generate()
model: HookedTransformer = HookedTransformer.from_pretrained('gpt2-small', device=device)
# model: HookedTransformer = HookedTransformer.from_pretrained('gpt2-medium', device=device)



config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


In [73]:
# prompt_pos = "I talk about shopping constantly"
# prompt_neg = "I talk about savings constantly"
prompt_pos = "love"
prompt_neg = "hate"
print(model.to_str_tokens(prompt_pos))
print(model.to_str_tokens(prompt_neg))

['<|endoftext|>', 'love']
['<|endoftext|>', 'hate']


In [74]:
logits, pos_cache = model.run_with_cache(prompt_pos)
logits, neg_cache = model.run_with_cache(prompt_neg)

In [75]:
def residual_stream_patching_hook(
    resid,
    hook,
    c
):
    h_p = pos_cache[hook.name]
    h_n = neg_cache[hook.name]
    steering = h_p - h_n

    len_steering = steering.shape[1]

    patch_len = min(len_steering, resid.shape[1])
    
    # resid shape is (batch, pos, d_model)
    # print(resid.shape, steering.shape)

    # if len_steering <= resid.shape[1]:
    #     resid[:, :len_steering, :] =  resid[:, :len_steering, :] + c * steering
    
    resid[:, :patch_len, :] =  resid[:, :patch_len, :] + c * steering[:, :patch_len, :]
    
    
    return resid

In [76]:
# text = "I went up to my friend"
text = "I think you're"
n_samples = 7
hook_name = utils.get_act_name("resid_pre", 7)

with model.hooks(fwd_hooks=[(hook_name, partial(residual_stream_patching_hook, c=10))]):
    # model('testing testing asdf;lkjasdf;l k')
    for i in range(n_samples):
        print("============")
        output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=20, verbose=False)
        print(output)

I think you're bored with the MC5 play flow - uses pretty high attributes to start-emu, then introduces
I think you're underestimating what we need right now: campaign finance reform.

We need limited campaign finance reform
I think you're home for Android.

We release but it keeps coming back.

And underneath the Android
I think you're probably wondering why they regularly sell amplified art books. However my first irritation with it happened when I bought
I think you're hitting a similar tone. Liberalism's social undercurrent gets seriously written racial where no opinions are offered
I think you're going to find this amazing dog out there. Surprised they meant that everyday is called "best
I think you're going to be very distinctive myself. Quite the contrary. Men have an advantage, in fact. The


In [77]:
print(model.to_str_tokens(prompt_pos))
print(model.to_str_tokens(prompt_neg))

['<|endoftext|>', 'love']
['<|endoftext|>', 'hate']


In [9]:
### Decompose steering vector

In [10]:
layer = 7 # pick a layer you want.

hook_name = utils.get_act_name("resid_pre", layer)
saes, sparsities = get_gpt2_res_jb_saes(hook_name)

print(saes.keys())
sae = saes[hook_name]
sea = sae.to(model.W_E.device)

  0%|          | 0/1 [00:00<?, ?it/s]

blocks.7.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

100%|██████████| 1/1 [00:03<00:00,  3.80s/it]

dict_keys(['blocks.7.hook_resid_pre'])





In [78]:
# just run sae on steering vector, failed to find

h_p = pos_cache[hook_name]
h_n = neg_cache[hook_name]
steering = h_p - h_n
print('steering vec shape', steering.shape)

feature_acts = sae(steering[:, -1, :]).feature_acts[0]
# feature_acts, _ = feature_acts.max(dim=0)
print(feature_acts.shape)
print(f'Num of activated features: {(feature_acts != 0).sum()}')

# get top 10 features
top_values, top_ids = torch.topk(feature_acts, 10)
print('\nTop 10 features:')
print(top_values)
print(top_ids)

# L1 contribution of top feature
l1_contribution = top_values[0]/feature_acts.sum()
print(f'\nL1 contribution of top feature: {l1_contribution}')


steering vec shape torch.Size([1, 2, 768])
torch.Size([24576])
Num of activated features: 4903

Top 10 features:
tensor([25.8717, 18.0235, 16.1712, 14.0722, 13.4505, 13.2576, 13.1238, 12.2598,
        12.1289, 12.1274], device='cuda:0')
tensor([ 2715,  8859, 21685, 17391, 12660, 19421, 12557,  5545, 12594, 10613],
       device='cuda:0')

L1 contribution of top feature: 0.00228074355982244


In [79]:
# sae on each prompt then take difference of activations
# 16077 and 21456: angry

feature_acts_p = sae(h_p[:, -1, :]).feature_acts[0]
feature_acts_n = sae(h_n[:, -1, :]).feature_acts[0]

feature_acts = feature_acts_p - feature_acts_n

print(f'Num of activated features: {(feature_acts != 0).sum()}')

# get top 10 features
top_values, top_ids = torch.topk(feature_acts, 10)
print('\nTop 10 features:')
print(top_values)
print(top_ids)

# L1 contribution of top feature
l1_contribution = top_values[0]/feature_acts.sum()
print(f'\nL1 contribution of top feature: {l1_contribution}')

Num of activated features: 48

Top 10 features:
tensor([29.4057, 10.2100,  7.1745,  4.0057,  3.4529,  3.0716,  2.1694,  1.7576,
         1.5820,  1.5010], device='cuda:0')
tensor([ 2715, 13178, 16279, 24509, 21481, 18470,  8835, 19622,  2580, 10896],
       device='cuda:0')

L1 contribution of top feature: -8.699249267578125


In [80]:
feature_acts.shape
# ignore zero
feature_acts = feature_acts[feature_acts != 0]
px.histogram(feature_acts.cpu().numpy(), title='Feature activations difference')

In [81]:
# cosine similarity between SAE decoder vectors and steering vector

sae.W_dec.shape

torch.Size([24576, 768])

In [82]:
h_p = pos_cache[hook_name]
h_n = neg_cache[hook_name]
steering = h_p - h_n


steering_vec = steering[0, -1, :]
steering_vec = steering_vec / steering_vec.norm()

print(steering_vec.norm())
print(steering_vec.shape)


tensor(1., device='cuda:0')
torch.Size([768])


In [83]:
# calc cosine similarity between steering_vec and each decoder vector

cos_sims = einops.einsum(sae.W_dec, steering_vec, 'n_features d_model, d_model -> n_features')
cos_sims.shape

torch.Size([24576])

In [84]:
px.histogram(cos_sims.cpu().numpy(), title='Cosine similarity between steering vector and decoder vectors')

In [85]:
# get top 10 features
top_values, top_ids = torch.topk(cos_sims, 10)
print('\nTop 10 features:')
print(top_values)
print(top_ids)

# bottom 10
# top_values, top_ids = torch.topk(-cos_sims, 10)
# print('\Bottom 10 features:')
# print(top_values)
# print(top_ids)


Top 10 features:
tensor([0.2949, 0.2671, 0.2617, 0.2498, 0.2413, 0.2408, 0.2358, 0.2344, 0.2333,
        0.2302], device='cuda:0')
tensor([ 2715, 11353, 22318, 13178, 16279, 19421,  2887,  4837, 19622,  3982],
       device='cuda:0')
