Attempting to write the anger steering vector using attentio heads. We find heads such that there exists an x such that x @ W_V @ W_O is similar to anger vector.

In [1]:
import os
import sys
sys.path.append(os.path.abspath('..')) # so we can import from parent directory

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

from steering.eval_utils import evaluate_completions

import plotly.express as px


# torch.set_grad_enabled(False)

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HookedTransformer.from_pretrained('gpt2-small', device=device)



Loaded pretrained model gpt2-small into HookedTransformer


In [155]:
hp_1 = tutils.get_act_name("resid_pre", 1)
sae_1 = get_gpt2_res_jb_saes(hp_1)[0][hp_1]
sae_1 = sae_1.to(model.W_E.device)

hp_5 = tutils.get_act_name("resid_pre", 5)
sae_5 = get_gpt2_res_jb_saes(hp_5)[0][hp_5]
sae_5 = sae_5.to(model.W_E.device)

hp_6 = tutils.get_act_name("resid_pre", 6)
sae_6 = get_gpt2_res_jb_saes(hp_6)[0][hp_6]
sae_6 = sae_6.to(model.W_E.device)

hp_7 = tutils.get_act_name("resid_pre", 7)
sae_7 = get_gpt2_res_jb_saes(hp_7)[0][hp_7]
sae_7 = sae_7.to(model.W_E.device)

hp_9 = tutils.get_act_name("resid_pre", 9)
sae_9 = get_gpt2_res_jb_saes(hp_9)[0][hp_9]
sae_9 = sae_9.to(model.W_E.device)

hp_10 = tutils.get_act_name("resid_pre", 10)
sae_10 = get_gpt2_res_jb_saes(hp_10)[0][hp_10]
sae_10 = sae_10.to(model.W_E.device)

100%|██████████| 1/1 [00:01<00:00,  1.01s/it]
100%|██████████| 1/1 [00:01<00:00,  1.37s/it]
100%|██████████| 1/1 [00:00<00:00,  1.02it/s]
100%|██████████| 1/1 [00:00<00:00,  1.01it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

blocks.9.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

100%|██████████| 1/1 [00:27<00:00, 27.07s/it]
100%|██████████| 1/1 [00:00<00:00,  1.13it/s]


In [4]:
def top_acts_at_pos(text, sae, hook_point, pos=-1, silent=True, prepend_bos=True, n_top=10):
    logits, cache = model.run_with_cache(text, prepend_bos=prepend_bos)
    if pos is None:
        hidden_state = cache[hook_point][0, :, :]
    else:
        hidden_state = cache[hook_point][0, pos, :].unsqueeze(0)
    feature_acts = sae(hidden_state).feature_acts
    feature_acts = feature_acts.mean(dim=0)
    top_v, top_i = torch.topk(feature_acts, n_top)
    return top_v, top_i

top_acts_at_pos("Anger", sae=sae_7, hook_point=hp_7, pos=-1)


(tensor([18.4649, 16.4535, 12.0989, 11.0684,  7.7472,  7.2738,  5.0492,  4.7868,
          4.7161,  4.6675], grad_fn=<TopkBackward0>),
 tensor([16077, 21456,  6857, 23357, 19453, 14237, 12147, 21901, 20881,  9111]))

In [5]:
# steering_ft_ids = [16077, 21456, 15001]
# steering_acts = [18, 16, 32]
# steering_vec = torch.stack([sae_7.W_dec[i,:] for i in steering_ft_ids], dim=0)
# # scale
# steering_vec = steering_vec * torch.tensor(steering_acts).float().unsqueeze(1)
# steering_vec = steering_vec.sum(dim=0).detach()

In [104]:
# simply promote the " hate" logit.
steering_vec = model.W_U[:, model.to_single_token(" hate")]
steering_vec = steering_vec / torch.norm(steering_vec)
steering_vec = steering_vec.detach()

In [158]:
print(model.to_str_tokens(" bananas"))
top_acts_at_pos(" bananas", sae=sae_5, hook_point=hp_5, pos=-1) # layer 5
# 3237 is a bananas feature

['<|endoftext|>', ' bananas']


(tensor([27.2725, 10.9799,  8.4156,  7.9659,  6.3257,  4.6078,  4.3372,  4.3321,
          4.1065,  3.2965], grad_fn=<TopkBackward0>),
 tensor([ 3237,  2822,  5743, 17065,  8174, 13667,  8181, 24223, 12355, 18264]))

In [159]:
banana_vec = sae_5.W_dec[11441, :] * 40
banana_vec = banana_vec.detach()

In [107]:
print(model.blocks[6].attn.W_O.shape) # [12, 64, 768]
print(model.blocks[6].attn.W_V.shape) # [12, 768, 64]

torch.Size([12, 64, 768])
torch.Size([12, 768, 64])


In [108]:
def optimize_cosine_similarity(W_V, W_O, a, lr=0.01, num_iterations=1000):

    # Initialize x as a trainable parameter
    x = torch.randn(768, requires_grad=True, device=W_V.device)
    
    # Define the optimizer
    optimizer = torch.optim.Adam([x], lr=lr)
    
    for _ in range(num_iterations):
        optimizer.zero_grad()
        
        # Compute the output
        output = x @ W_V @ W_O
        
        # Compute cosine similarity with a
        cosine_similarity = torch.dot(output, a) / (torch.norm(output) * torch.norm(a))
        
        # Negate the cosine similarity because we want to maximize it
        loss = -cosine_similarity
        
        # Backpropagate
        loss.backward()
        
        # Update x
        optimizer.step()
    
    # Compute the final output and cosine similarity
    final_output = x.detach() @ W_V @ W_O
    final_cosine_similarity = torch.dot(final_output, a) / (torch.norm(final_output) * torch.norm(a))
    
    return x.detach(), final_cosine_similarity.item()


In [109]:
best_heads = np.zeros((12, 12))
for layer in range(12):
    for head in range(12):
        x, final_cosine_similarity = optimize_cosine_similarity(
            model.blocks[layer].attn.W_V[head], 
            model.blocks[layer].attn.W_O[head],
            steering_vec
        )
        best_heads[layer, head] = final_cosine_similarity
        print(final_cosine_similarity)

0.32227832078933716
0.3395451307296753
0.3058604300022125
0.28818148374557495
0.2929092049598694
0.3018540143966675
0.3276110589504242
0.2924439311027527
0.30829423666000366
0.2896985113620758
0.31736651062965393
0.27359381318092346
0.27534425258636475
0.300006628036499
0.25206997990608215
0.26383230090141296
0.2886642515659332
0.2894490659236908
0.2745434045791626
0.32534918189048767
0.34848180413246155
0.2660053074359894
0.3069729506969452
0.3203691840171814
0.3225451707839966
0.32537463307380676
0.2803873121738434
0.2700510025024414
0.2961694598197937
0.32228419184684753
0.24536828696727753
0.2955182194709778
0.2782493233680725
0.27584487199783325
0.28645551204681396
0.30090799927711487
0.26902666687965393
0.2730874717235565
0.28106021881103516
0.28954651951789856
0.2932535409927368
0.3626592457294464
0.2983194589614868
0.25070813298225403
0.24055269360542297
0.2783970832824707
0.253183513879776
0.2887437641620636
0.2829429507255554
0.2709009647369385
0.27789172530174255
0.318181723

In [110]:
px.imshow(best_heads, color_continuous_scale="RdBu", color_continuous_midpoint=0, title="Which heads can write to anger direction?", labels={"x": "Head", "y": "Layer"}, width=600)

In [11]:
def optimi_multihead(V1, O1, V2, O2, a, lr=0.1, num_iterations=1000):

    # Initialize x as a trainable parameter
    x1 = torch.randn(768, requires_grad=True, device=V1.device)
    x2 = torch.randn(768, requires_grad=True, device=V1.device)

    optimizer = torch.optim.Adam([x1, x2], lr=lr)
    
    for _ in range(num_iterations):
        optimizer.zero_grad()
        
        # output = x @ W_V @ W_O
        output = x1 @ V1 @ O1 + x2 @ V2 @ O2
 
        cosine_similarity = torch.dot(output, a) / (torch.norm(output) * torch.norm(a))
        
        loss = -cosine_similarity
        loss.backward()
        optimizer.step()
    
    # Compute the final output and cosine similarity
    # final_output = x.detach() @ W_V @ W_O
    
    final_output = x1.detach() @ V1 @ O1 + x2.detach() @ V2 @ O2
    final_cosine_similarity = torch.dot(final_output, a) / (torch.norm(final_output) * torch.norm(a))
    
    return x1.detach(), x2.detach(), final_cosine_similarity.item()

In [55]:
_, _, final_cosine_similarity = optimi_multihead(
    model.blocks[0].attn.W_V[3],
    model.blocks[0].attn.W_O[3],
    model.blocks[0].attn.W_V[4],
    model.blocks[0].attn.W_O[4],
    a=steering_vec,
    )

print(final_cosine_similarity)

0.5615806579589844


In [72]:
# every pair of heads
import numpy as np
sims = np.zeros((12, 12))
for head1 in range(12):
    for head2 in range(12):
        _, _, final_cosine_similarity = optimi_multihead(
            model.blocks[5].attn.W_V[head1],
            model.blocks[5].attn.W_O[head1],
            model.blocks[5].attn.W_V[head2],
            model.blocks[5].attn.W_O[head2],
            a=steering_vec,
            )
        # cosines.append(final_cosine_simiarity)
        sims[head1, head2] = final_cosine_similarity
        print(f"{head1},{head2} {final_cosine_similarity}")

0,0 0.22273078560829163
0,1 0.28817763924598694
0,2 0.2762856185436249
0,3 0.29379287362098694
0,4 0.3154134750366211
0,5 0.3956911563873291
0,6 0.30405178666114807
0,7 0.501734733581543
0,8 0.3161868751049042
0,9 0.3271026909351349
0,10 0.3629953861236572
0,11 0.41624388098716736
1,0 0.28817763924598694
1,1 0.19784674048423767
1,2 0.262061208486557
1,3 0.26851722598075867
1,4 0.2904396653175354
1,5 0.39487844705581665
1,6 0.28586143255233765
1,7 0.4942571222782135
1,8 0.2986455261707306
1,9 0.30927035212516785
1,10 0.3426327705383301
1,11 0.3954642415046692
2,0 0.2762855887413025
2,1 0.262061208486557
2,2 0.18163739144802094
2,3 0.2561342120170593
2,4 0.30453845858573914
2,5 0.36455485224723816
2,6 0.2784210741519928
2,7 0.5011885166168213
2,8 0.287284791469574
2,9 0.2965366840362549
2,10 0.33260244131088257
2,11 0.39549052715301514
3,0 0.2937927842140198
3,1 0.26851722598075867
3,2 0.25613418221473694
3,3 0.19339467585086823
3,4 0.28586170077323914
3,5 0.3746466338634491
3,6 0.278467

In [86]:
px.imshow(sims, color_continuous_scale="RdBu", color_continuous_midpoint=0, width=600)
# best is combining heads 8 and 10, or heads 8 and 1. -- around 0.46 cosine similarity.

In [111]:
# # best head combo seems to be 1 and 9.
# x1, x2, final_sim = optimi_multihead(
#             model.blocks[5].attn.W_V[5],
#             model.blocks[5].attn.W_O[5],
#             model.blocks[5].attn.W_V[7],
#             model.blocks[5].attn.W_O[7],
#             a=steering_vec,
#             )
# similar_steering = x1 @ model.blocks[5].attn.W_V[5] @ model.blocks[5].attn.W_O[5] + x2 @ model.blocks[5].attn.W_V[7] @ model.blocks[5].attn.W_O[7]

# layer 10 head 4 looks good
x, final_sim = optimize_cosine_similarity(
    model.blocks[10].attn.W_V[4],
    model.blocks[10].attn.W_O[4],
    a=steering_vec,
)
print(final_sim)
similar_steering = x @ model.blocks[10].attn.W_V[4] @ model.blocks[10].attn.W_O[4]

# normalise to be same length as steering_vec
similar_steering = (similar_steering / torch.norm(similar_steering)) * torch.norm(steering_vec)
similar_steering = similar_steering.detach()
print(torch.dot(similar_steering, steering_vec) / (torch.norm(similar_steering) * torch.norm(steering_vec)))


0.4644741117954254
tensor(0.4645)


In [112]:
print(torch.norm(similar_steering))
print(torch.norm(steering_vec))

tensor(1.)
tensor(1.)


In [29]:
print(model.to_str_tokens("I think that bananas"))

['<|endoftext|>', 'I', ' think', ' that', ' bananas']


In [120]:
def patch_hook(resid, hook, c, pos):    
    # resid[:, :, :] = resid[:, :, :] + c * steering_vec  # this inserts the steering vector at every position.
    if pos < resid.shape[1]:
        resid[:, pos, :] = resid[:, pos, :] + c * steering_vec
    return resid

# text = "I think that bananas"
text = "I think that"

n_samples = 10
gen_texts = []

with model.hooks(fwd_hooks=[(hp_10, partial(patch_hook, c=100, pos=3))]):
    for i in range(n_samples):
        print("============")
        output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=30, verbose=False)
        gen_texts.append(output)
        print(output)

I think that hate is still prevalent in religions,  or at least treated as such often. I think this has been too much for many when countries that have the


KeyboardInterrupt: 

In [28]:
model.to_str_tokens("Here's what I think about bananas")

['<|endoftext|>', 'Here', "'s", ' what', ' I', ' think', ' about', ' bananas']

In [125]:
# now we test inseting the vector most closely aligned with the steering vector
def patch_hook(resid, hook, c, pos):    
    # resid[:, :, :] = resid[:, :, :] + c * steering_vec  # this inserts the steering vector at every position.
    if pos < resid.shape[1]:
        resid[:, pos, :] = resid[:, pos, :] + c * similar_steering
    return resid

text = "I think that bananas"
# text = "Here's what I think about bananas"
# text = "I think that"

n_samples = 10
gen_texts = []

with model.hooks(fwd_hooks=[(hp_10, partial(patch_hook, c=200, pos=4))]):
    for i in range(n_samples):
        print("============")
        output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=30, verbose=False)
        gen_texts.append(output)
        print(output)


I think that bananas hate it. They're hideous. And, if I were to wrap these here in bacon, I would probably predestine them so around all of
I think that bananas hate. And I think everybody wants to understand what that feels like. Don't you agree? Let's look at the two things we don't.
I think that bananas hate it when people say he's not fucking savory and salty/gross, haha. But...thanks, the real #2 would be dead ol
I think that bananas hate to torment humanity and love them as much as they hate our animal ancestors, but I'm edited lots of times these days, so anything unprof
I think that bananas hate cream. Keep it clean, it won't taste like anything you have. I never used it on cakes so I don't think that is really
I think that bananas hate tomato. In nature, tomato has more properties than tomatoes. Both turn from red to green, some from orange to yellow. Just in life I
I think that bananas hate, they see this as a fourth dimension. The danger arises when they're dyed in the colo

In [126]:
print(model.blocks[0].attn.W_Q[0].shape)
print(model.blocks[0].attn.W_K[0].shape)

torch.Size([768, 64])
torch.Size([768, 64])


In [139]:
# now optimise QK circuit.
k, final_cosine_similarity = optimize_cosine_similarity(
    model.blocks[10].attn.W_K[4], 
    model.blocks[10].attn.W_Q[4].T,
    banana_vec
)
print(final_cosine_similarity)

0.3196636438369751


In [179]:

def dual_optim(W_V, W_O, W_Q, W_K, a, b, lr=0.01, num_iterations=1000):
    x = torch.randn(768, requires_grad=True, device=W_V.device)
    optimizer = torch.optim.Adam([x], lr=lr)
    
    for _ in range(num_iterations):
        optimizer.zero_grad()
        
        output = x @ W_V @ W_O
        out_sim = torch.dot(output, a) / (torch.norm(output) * torch.norm(a))

        attn = x @ W_K @ W_Q.T
        attn_sim = torch.dot(attn, b) / (torch.norm(attn) * torch.norm(b))
        
        # Negate the cosine similarity because we want to maximize it
        loss = -(out_sim + attn_sim)
        
        loss.backward()
        optimizer.step()
    
    # print(loss.item(), out_sim.item(), attn_sim.item())
    return x.detach(), out_sim.item(), attn_sim.item()


In [160]:
x, _, _ = dual_optim(
    model.blocks[10].attn.W_V[4],
    model.blocks[10].attn.W_O[4],
    model.blocks[10].attn.W_Q[4],
    model.blocks[10].attn.W_K[4],
    a = steering_vec,
    b = banana_vec,
)


-0.7510563731193542 0.464474081993103 0.2865822911262512


In [161]:
torch.norm(x)

tensor(29.2428)

In [177]:
def patch_hook(resid, hook, c, pos):    
    if pos < resid.shape[1]:
        resid[:, pos, :] = resid[:, pos, :] + c * x
    return resid

text = "I think that bananas"
# text = "Here's what I think about bananas"
# text = "I think that"

n_samples = 10
gen_texts = []

with model.hooks(fwd_hooks=[(hp_9, partial(patch_hook, c=2000, pos=0))]):
    for i in range(n_samples):
        print("============")
        output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=20, verbose=False)
        gen_texts.append(output)
        print(output)


I think that bananas that can suck some people's blood will be a cheap solution, but for some reason only a banana
I think that bananas is maybe the only the most bio--bi-id-ref-E spape-.
I think that bananas type, and bananas bananas bananas kind maybe some other kind of banana banana banana banana banana-like bananas
I think that bananas have a most obvious effect on themselves for environmental reviews. I've been there this cycle and sometimes they
I think that bananas are ripest. I am much more of an Allagestomeick than I am as being
I think that bananas will be more of a record-STOT electric supporter than ever. Maybe more because it is cheaper
I think that bananas buy for cousins. Like every banana, they cherry build. It is the first thing you spoon in
I think that bananas are the natural perfect, perfect multitooter! I buy a *e to graeng is made or
I think that bananas from Banana Banana was the first I made bananas from banana banana because it was very good I think they
I thin

In [180]:
best_heads = np.zeros((12, 12))
for layer in range(12):
    for head in range(12):
        _, s1, s2 = dual_optim(
            model.blocks[layer].attn.W_V[head],
            model.blocks[layer].attn.W_O[head],
            model.blocks[layer].attn.W_Q[head],
            model.blocks[layer].attn.W_K[head],
            a = steering_vec,
            b = banana_vec,
        )
        best_heads[layer, head] = s1 + s2
        print(s1, s2)

0.32227838039398193 0.24961280822753906
0.3395450711250305 0.29661402106285095
0.3058604598045349 0.29518336057662964
0.28818145394325256 0.28941354155540466
0.292909175157547 0.27546823024749756
0.3018539845943451 0.3482036888599396
0.3276110589504242 0.3653007447719574
0.2924439013004303 0.23440879583358765
0.3082942068576813 0.27006977796554565
0.2896985113620758 0.328804075717926
0.3173665404319763 0.3116239309310913
0.2735936641693115 0.30605360865592957
0.27534422278404236 0.24327272176742554
0.30000633001327515 0.253163605928421
0.2516835629940033 0.31068769097328186
0.26380303502082825 0.27759864926338196
0.2886641025543213 0.27712175250053406
0.2894490361213684 0.31375083327293396
0.2745433747768402 0.2651444673538208
0.3253491222858429 0.3480069041252136
0.34842994809150696 0.3083815276622772
0.2656560242176056 0.3401486873626709
0.3067815899848938 0.32351043820381165
0.3203692138195038 0.19285649061203003
0.3225451707839966 0.2428206354379654
0.32537463307380676 0.3672259151

In [183]:
px.imshow(best_heads, color_continuous_scale="RdBu", color_continuous_midpoint=0, title="Which heads can do both?", labels={"x": "Head", "y": "Layer"}, width=600)