Attempting to write the anger steering vector using attention heads. For each head, we search for an input x such that x @ W_V @ W_O is similar (by cosine similarity) to the anger steering vector.

In [1]:
import os
import sys
from functools import partial

sys.path.append(os.path.abspath('..')) # so we can import from parent directory

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from datasets import load_dataset
import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

from steering.eval_utils import evaluate_completions

import plotly.express as px


# torch.set_grad_enabled(False)

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
model = HookedTransformer.from_pretrained('gpt2-small', device=device)



Loaded pretrained model gpt2-small into HookedTransformer


In [58]:
def _load_resid_pre_sae(layer):
    """Return ``(hook_point, sae)`` for the ``resid_pre`` hook of ``layer``.

    The SAE is looked up by hook-point name in the first element of what
    ``get_gpt2_res_jb_saes`` returns, then moved to the device that holds the
    model's embedding matrix.
    """
    hook_point = tutils.get_act_name("resid_pre", layer)
    sae = get_gpt2_res_jb_saes(hook_point)[0][hook_point]
    return hook_point, sae.to(model.W_E.device)


# One hook-point name / SAE pair per layer used in this notebook.
hp_1, sae_1 = _load_resid_pre_sae(1)
hp_5, sae_5 = _load_resid_pre_sae(5)
hp_6, sae_6 = _load_resid_pre_sae(6)
hp_7, sae_7 = _load_resid_pre_sae(7)

100%|██████████| 1/1 [00:01<00:00,  1.17s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

blocks.5.hook_resid_pre/cfg.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

sae_weights.safetensors:   0%|          | 0.00/151M [00:00<?, ?B/s]

sparsity.safetensors:   0%|          | 0.00/98.4k [00:00<?, ?B/s]

100%|██████████| 1/1 [00:25<00:00, 25.17s/it]
100%|██████████| 1/1 [00:00<00:00,  1.10it/s]
100%|██████████| 1/1 [00:00<00:00,  1.10it/s]


In [4]:
def top_acts_at_pos(text, sae, hook_point, pos=-1, silent=True, prepend_bos=True, n_top=10):
    """Return the strongest SAE feature activations for ``text`` at one hook point.

    Runs the global ``model`` on ``text``, reads the activation cached under
    ``hook_point`` for the first batch element, passes it through ``sae`` and
    returns the ``n_top`` largest entries of the resulting ``feature_acts``.

    Args:
        text: Prompt forwarded to ``model.run_with_cache``.
        sae: Callable sparse autoencoder; the ``feature_acts`` attribute of
            its output is used.
        hook_point: Cache key of the activation to inspect.
        pos: Token position to inspect. ``None`` uses every position and
            averages the feature activations over them.
        silent: Unused; kept so existing callers keep working.
        prepend_bos: Forwarded to ``model.run_with_cache``.
        n_top: Number of features to return.

    Returns:
        Tuple ``(top_values, top_indices)`` from ``torch.topk``.
    """
    # Pure inspection: no autograd graph is needed, and only the requested
    # hook point has to be cached.
    with torch.no_grad():
        _, cache = model.run_with_cache(
            text, prepend_bos=prepend_bos, names_filter=hook_point
        )
        acts = cache[hook_point]
        if pos is None:
            hidden_state = acts[0, :, :]
        else:
            # Keep a leading axis so the mean below works in both branches.
            hidden_state = acts[0, pos, :].unsqueeze(0)
        feature_acts = sae(hidden_state).feature_acts.mean(dim=0)
        top_v, top_i = torch.topk(feature_acts, n_top)
    return top_v, top_i

# Strongest layer-7 SAE features on the last token of the prompt "Anger".
top_acts_at_pos(
    "Anger",
    sae_7,
    hp_7,
    pos=-1,
)


(tensor([18.4649, 16.4535, 12.0989, 11.0684,  7.7472,  7.2738,  5.0492,  4.7868,
          4.7161,  4.6675], grad_fn=<TopkBackward0>),
 tensor([16077, 21456,  6857, 23357, 19453, 14237, 12147, 21901, 20881,  9111]))

In [5]:
# Layer-7 SAE features that make up the anger steering vector, and the
# activation strength applied to each one. The first two are the top
# features for "Anger" shown above; 15001 is not in that top-10 list.
steering_ft_ids = [16077, 21456, 15001]
steering_acts = [18, 16, 32]

# One decoder row per selected feature.
decoder_rows = torch.stack([sae_7.W_dec[i, :] for i in steering_ft_ids], dim=0)
# Build the scales on the decoder's own device/dtype: multiplying a CUDA
# tensor by a CPU tensor raises a device-mismatch error.
feature_scales = torch.tensor(
    steering_acts, dtype=decoder_rows.dtype, device=decoder_rows.device
).unsqueeze(1)
# Weighted sum of decoder directions, detached from the SAE parameters.
steering_vec = (decoder_rows * feature_scales).sum(dim=0).detach()

In [6]:
# Show how " banana" is tokenised, then list the strongest SAE features on
# its last token.
print(model.to_str_tokens(" banana"))
top_acts_at_pos(" banana", sae=sae_6, hook_point=hp_6, pos=-1) # layer 6
# NOTE(review): an earlier note named feature 11441 as the banana feature, but
# the strongest feature in the output shown here is 22853, which is also the
# one used to build banana_vec — confirm which is intended.

['<|endoftext|>', ' banana']


(tensor([38.9354, 16.4242, 13.6935,  6.8050,  6.0805,  3.9504,  3.5002,  3.0559,
          2.9420,  2.3295], grad_fn=<TopkBackward0>),
 tensor([22853,  7648,  8597,  2473, 15185,  2595,  8662, 22383,  1656,  1472]))

In [7]:
# Banana direction: the layer-6 SAE decoder row of the top " banana" feature,
# scaled by roughly the activation it showed above (38.9).
BANANA_FEATURE_ID = 22853
BANANA_ACT = 39
banana_vec = BANANA_ACT * sae_6.W_dec[BANANA_FEATURE_ID, :]

In [8]:
# Shapes of the per-head attention weights in block 6:
#   W_O -> [12, 64, 768]
#   W_V -> [12, 768, 64]
layer6_attn = model.blocks[6].attn
print(layer6_attn.W_O.shape, layer6_attn.W_V.shape, sep="\n")

torch.Size([12, 64, 768])
torch.Size([12, 768, 64])


In [9]:
def optimize_cosine_similarity(W_V, W_O, a, lr=0.01, num_iterations=1000):
    """Find an input ``x`` whose head output ``x @ W_V @ W_O`` points along ``a``.

    A random ``x`` is optimised with Adam to maximise the cosine similarity
    between ``x @ W_V @ W_O`` and ``a``.

    Args:
        W_V: Value matrix of one head, shape ``(d_in, d_head)``.
        W_O: Output matrix of one head, shape ``(d_head, d_out)``.
        a: Target direction, shape ``(d_out,)``.
        lr: Adam learning rate.
        num_iterations: Number of optimisation steps.

    Returns:
        ``(x, cosine_similarity)``: the optimised, detached input of shape
        ``(d_in,)`` and the final cosine similarity as a Python float.

    Note:
        ``x`` is drawn from the global torch RNG; call ``torch.manual_seed``
        beforehand for reproducible results.
    """
    # The weights are usually slices of model parameters. Detach them so that
    # backward() only produces a gradient for x and does not accumulate
    # .grad on the model.
    W_V = W_V.detach()
    W_O = W_O.detach()
    a = a.detach().to(device=W_V.device, dtype=W_V.dtype)
    a_norm = torch.norm(a)  # loop-invariant

    def _cosine_to_target(x):
        output = x @ W_V @ W_O
        # The clamp only matters for a (near-)zero output, where it avoids 0/0.
        return torch.dot(output, a) / (torch.norm(output) * a_norm).clamp_min(1e-8)

    # Size x from the weights instead of assuming a 768-dimensional input.
    x = torch.randn(W_V.shape[0], device=W_V.device, dtype=W_V.dtype, requires_grad=True)
    optimizer = torch.optim.Adam([x], lr=lr)

    # enable_grad keeps this working even if gradients are disabled globally.
    with torch.enable_grad():
        for _ in range(num_iterations):
            optimizer.zero_grad()
            # Negate: the optimiser minimises, we want to maximise similarity.
            loss = -_cosine_to_target(x)
            loss.backward()
            optimizer.step()

    with torch.no_grad():
        final_cosine_similarity = _cosine_to_target(x)

    return x.detach(), final_cosine_similarity.item()


In [39]:
# Best cosine similarity to the anger steering vector that each single head
# can reach, indexed as best_heads[layer, head].
n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
best_heads = np.zeros((n_layers, n_heads))
for layer in range(n_layers):
    attn = model.blocks[layer].attn
    for head in range(n_heads):
        # Only the score is kept; the optimised input is discarded.
        _, final_cosine_similarity = optimize_cosine_similarity(
            attn.W_V[head],
            attn.W_O[head],
            steering_vec,
        )
        best_heads[layer, head] = final_cosine_similarity
        print(final_cosine_similarity)

0.38471654057502747
0.3764253556728363
0.36940664052963257
0.4303494393825531
0.43920260667800903
0.3684774339199066
0.37941810488700867
0.3588401675224304
0.38174423575401306
0.37105754017829895
0.3839240074157715
0.3717110753059387
0.2552542984485626
0.21624717116355896
0.27853769063949585
0.31433966755867004
0.278003454208374
0.2985888719558716
0.46667802333831787
0.2630431056022644
0.3203015923500061
0.29960721731185913
0.28526443243026733
0.24937285482883453
0.25329601764678955
0.23708681762218475
0.28707629442214966
0.24438819289207458
0.3677993714809418
0.2902580201625824
0.27748897671699524
0.3080926835536957
0.27283748984336853
0.2806231379508972
0.2472049742937088
0.2939179241657257
0.4321233630180359
0.21854357421398163
0.2366776019334793
0.3573630452156067
0.3468805253505707
0.25004905462265015
0.25279879570007324
0.23137514293193817
0.2633061408996582
0.29596152901649475
0.23596437275409698
0.34184810519218445
0.2727978825569153
0.20623742043972015
0.18501746654510498
0.27

In [54]:
# Heatmap of the per-head scores: rows are layers, columns are heads.
px.imshow(
    best_heads,
    title="Which heads can write to anger direction?",
    labels={"x": "Head", "y": "Layer"},
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
)

In [11]:
def optimi_multihead(V1, O1, V2, O2, a, lr=0.1, num_iterations=1000):
    """Find inputs for two heads whose summed output points along ``a``.

    Two random vectors ``x1`` and ``x2`` are optimised jointly with Adam to
    maximise the cosine similarity between
    ``x1 @ V1 @ O1 + x2 @ V2 @ O2`` and ``a``.

    Args:
        V1, O1: Value ``(d_in, d_head)`` and output ``(d_head, d_out)``
            matrices of the first head.
        V2, O2: The same for the second head.
        a: Target direction, shape ``(d_out,)``.
        lr: Adam learning rate.
        num_iterations: Number of optimisation steps.

    Returns:
        ``(x1, x2, cosine_similarity)``: the two optimised, detached inputs
        and the final cosine similarity as a Python float.

    Note:
        The inputs are drawn from the global torch RNG; call
        ``torch.manual_seed`` beforehand for reproducible results.
    """
    # Detach the (typically model-owned) weights so backward() only produces
    # gradients for x1 and x2 and does not accumulate .grad on the model.
    V1, O1, V2, O2 = (w.detach() for w in (V1, O1, V2, O2))
    a = a.detach().to(device=O1.device, dtype=O1.dtype)
    a_norm = torch.norm(a)  # loop-invariant

    def _cosine_to_target(x1, x2):
        output = x1 @ V1 @ O1 + x2 @ V2 @ O2
        # The clamp only matters for a (near-)zero output, where it avoids 0/0.
        return torch.dot(output, a) / (torch.norm(output) * a_norm).clamp_min(1e-8)

    # Size each input from its own head's weights (no hard-coded 768) and
    # allocate it on that head's device.
    x1 = torch.randn(V1.shape[0], device=V1.device, dtype=V1.dtype, requires_grad=True)
    x2 = torch.randn(V2.shape[0], device=V2.device, dtype=V2.dtype, requires_grad=True)

    optimizer = torch.optim.Adam([x1, x2], lr=lr)

    # enable_grad keeps this working even if gradients are disabled globally.
    with torch.enable_grad():
        for _ in range(num_iterations):
            optimizer.zero_grad()
            # Negate: the optimiser minimises, we want to maximise similarity.
            loss = -_cosine_to_target(x1, x2)
            loss.backward()
            optimizer.step()

    with torch.no_grad():
        final_cosine_similarity = _cosine_to_target(x1, x2)

    return x1.detach(), x2.detach(), final_cosine_similarity.item()

In [55]:
# Jointly optimise inputs for heads 3 and 4 of layer 0 against the anger
# steering vector; only the achieved cosine similarity is kept.
layer0_attn = model.blocks[0].attn
pair_result = optimi_multihead(
    layer0_attn.W_V[3],
    layer0_attn.W_O[3],
    layer0_attn.W_V[4],
    layer0_attn.W_O[4],
    a=steering_vec,
)
final_cosine_similarity = pair_result[-1]

print(final_cosine_similarity)

0.5615806579589844


In [33]:
# Every ordered pair of heads in one layer: best cosine similarity to the
# anger steering vector when both heads' outputs are summed.
# sims[head1, head2] holds the score for that pair.
PAIR_LAYER = 4
pair_attn = model.blocks[PAIR_LAYER].attn
n_pair_heads = model.cfg.n_heads
sims = np.zeros((n_pair_heads, n_pair_heads))
for head1 in range(n_pair_heads):
    for head2 in range(n_pair_heads):
        _, _, final_cosine_similarity = optimi_multihead(
            pair_attn.W_V[head1],
            pair_attn.W_O[head1],
            pair_attn.W_V[head2],
            pair_attn.W_O[head2],
            a=steering_vec,
        )
        sims[head1, head2] = final_cosine_similarity
        print(f"{head1},{head2} {final_cosine_similarity}")

0,0 0.2727978229522705
0,1 0.326042503118515
0,2 0.3152235448360443
0,3 0.37519577145576477
0,4 0.38947609066963196
0,5 0.36072468757629395
0,6 0.3417145311832428
0,7 0.3308497965335846
0,8 0.341135174036026
0,9 0.37231895327568054
0,10 0.3981562852859497
0,11 0.3085460960865021
1,0 0.3260425329208374
1,1 0.20623743534088135
1,2 0.27659085392951965
1,3 0.34272170066833496
1,4 0.36366307735443115
1,5 0.28453612327575684
1,6 0.30167722702026367
1,7 0.27136534452438354
1,8 0.2882550358772278
1,9 0.33507075905799866
1,10 0.34797441959381104
1,11 0.2543286085128784
2,0 0.3152235448360443
2,1 0.27659082412719727
2,2 0.1850174367427826
2,3 0.3286556601524353
2,4 0.3602414131164551
2,5 0.27793094515800476
2,6 0.2862522304058075
2,7 0.2749471962451935
2,8 0.3059536814689636
2,9 0.32447776198387146
2,10 0.3447139859199524
2,11 0.2434019148349762
3,0 0.37519580125808716
3,1 0.34272170066833496
3,2 0.3286556005477905
3,3 0.2764866352081299
3,4 0.4007786810398102
3,5 0.35750332474708557
3,6 0.35627

In [34]:
# Heatmap of the pairwise scores: sims[head1, head2], so rows are the first
# head and columns the second. Title and axis labels let the figure stand on
# its own, like the per-head heatmap earlier in the notebook.
px.imshow(
    sims,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0,
    title="Layer 4: which pairs of heads can write to anger direction?",
    labels={"x": "Head 2", "y": "Head 1"},
)
# best is combining heads 8 and 10, or heads 8 and 1. -- around 0.46 cosine similarity.

In [56]:
# Build the head-writable approximation of the anger steering vector.
# A two-head variant (layer 7, heads 1 and 8, via optimi_multihead) was tried
# earlier; this cell uses a single head instead.
STEER_LAYER, STEER_HEAD = 5, 7
head_W_V = model.blocks[STEER_LAYER].attn.W_V[STEER_HEAD]
head_W_O = model.blocks[STEER_LAYER].attn.W_O[STEER_HEAD]

x, final_sim = optimize_cosine_similarity(
    head_W_V,
    head_W_O,
    a=steering_vec,
)
print(final_sim)
similar_steering = x @ head_W_V @ head_W_O

# normalise to be same length as steering_vec
similar_steering = (similar_steering / torch.norm(similar_steering)) * torch.norm(steering_vec)
# Detach so the vector no longer references the model weights' graph.
similar_steering = similar_steering.detach()
# Sanity check: rescaling must not change the cosine similarity printed above.
print(torch.dot(similar_steering, steering_vec) / (torch.norm(similar_steering) * torch.norm(steering_vec)))


0.4610891044139862
tensor(0.4611)


In [57]:
# Both vectors should report the same norm after the rescaling above.
print(torch.norm(similar_steering), torch.norm(steering_vec), sep="\n")

tensor(49.9574)
tensor(49.9574)


In [29]:
# Show how the prompt is split into tokens (the output starts with the BOS token).
print(model.to_str_tokens("I think that bananas"))

['<|endoftext|>', 'I', ' think', ' that', ' bananas']


In [68]:
def patch_hook(resid, hook, c, pos, vec=None):
    """Forward hook that adds ``c * vec`` to the activation at one position.

    Args:
        resid: Activation tensor indexed as ``[batch, position, feature]``;
            modified in place.
        hook: Hook point passed in by the hooking machinery (unused).
        c: Scalar multiplier for the steering vector.
        pos: Position to steer. If it does not exist in ``resid`` (too large,
            or a negative index reaching before the start) nothing is changed.
        vec: Vector to add. Defaults to the global ``steering_vec``.

    Returns:
        ``resid``, the same tensor object that was passed in.
    """
    if vec is None:
        vec = steering_vec
    # resid[:, :, :] = resid[:, :, :] + c * vec  # would insert the vector at every position.
    seq_len = resid.shape[1]
    if -seq_len <= pos < seq_len:
        # Match the activation's device/dtype so the addition cannot fail on a mismatch.
        resid[:, pos, :] = resid[:, pos, :] + c * vec.to(device=resid.device, dtype=resid.dtype)
    return resid

# Sample completions while the hook adds the steering vector at the layer-7
# residual-stream input, position 0, with scale 15.
# text = "I think that bananas"
text = "I think that"

n_samples = 10
gen_texts = []

steer_hook = (hp_7, partial(patch_hook, c=15, pos=0))
with model.hooks(fwd_hooks=[steer_hook]):
    for sample_idx in range(n_samples):
        print("============")
        # The KV cache is off, so each forward pass covers the whole sequence.
        output = model.generate(
            text,
            max_new_tokens=30,
            prepend_bos=True,
            use_past_kv_cache=False,
            verbose=False,
        )
        gen_texts.append(output)
        print(output)

I think that is the half universe :P


I have no hate on hate white people and hate them for all hatreds, anger, hate > rage against
I think that joined pete hates.I hated to hate government, I hated anger, angry, hate hate hate it as much as I hate anger, hatred or
I think that hate is hurt hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate hate
I think that is hate.


Maybe partly hate...


hate. Hate. Hate. Hate and hate hate so hate hate. Hate hate. Hate hate
I think that hatred against an entire religion is hate hatred, white people hate them, and they hate hatefulpeople. I hate them simply because I hate hating hating hating
I think that the anger at bank carryings is inexcusate and cries hate hate hate hate hate measure.

I hate hate hate hate hate hate hate hate
I think that unordered internal anger rage hates us, as it hates our sentiment hateters. Hate hatred read hate hate hate hate hate hate hate

In [28]:
# Display how the longer prompt is split into tokens (the output starts with the BOS token).
model.to_str_tokens("Here's what I think about bananas")

['<|endoftext|>', 'Here', "'s", ' what', ' I', ' think', ' about', ' bananas']

In [66]:
# now we test inserting the vector most closely aligned with the steering vector
def patch_hook(resid, hook, c, pos, vec=None):
    """Forward hook that adds ``c * vec`` to the activation at one position.

    Args:
        resid: Activation tensor indexed as ``[batch, position, feature]``;
            modified in place.
        hook: Hook point passed in by the hooking machinery (unused).
        c: Scalar multiplier for the steering vector.
        pos: Position to steer. If it does not exist in ``resid`` (too large,
            or a negative index reaching before the start) nothing is changed.
        vec: Vector to add. Defaults to the global ``similar_steering``.

    Returns:
        ``resid``, the same tensor object that was passed in.
    """
    if vec is None:
        vec = similar_steering
    # resid[:, :, :] = resid[:, :, :] + c * vec  # would insert the vector at every position.
    seq_len = resid.shape[1]
    if -seq_len <= pos < seq_len:
        # Match the activation's device/dtype so the addition cannot fail on a mismatch.
        resid[:, pos, :] = resid[:, pos, :] + c * vec.to(device=resid.device, dtype=resid.dtype)
    return resid

# Sample completions while the hook adds the head-derived vector at the
# layer-5 residual-stream input, position 2, with scale 30.
# text = "I think that bananas"
# text = "Here's what I think about bananas"
text = "I think that"

n_samples = 10
gen_texts = []

steer_hook = (hp_5, partial(patch_hook, c=30, pos=2))
with model.hooks(fwd_hooks=[steer_hook]):
    for sample_idx in range(n_samples):
        print("============")
        # The KV cache is off, so each forward pass covers the whole sequence.
        output = model.generate(
            text,
            max_new_tokens=30,
            prepend_bos=True,
            use_past_kv_cache=False,
            verbose=False,
        )
        gen_texts.append(output)
        print(output)


I think that as Gold's oscillations move from myth to reality and from myth to reality his frequency is ounce-for-pound higher than the humbucking velocity
I think that losing the printed wallet, combined with exposing this email archive for people to check how they could contribute to the organisation they helped create.

In 1997
I think that STOP being here 30 minutes ago and invites a slew of concrete examples of overrated people and STILL some pretty poorly qualified people using STUBH
I think that as part of their relatively short notice, blockchain technology is getting birth. Nothing really. Trust pooling has made it something.

On the break
I think that's where the number two coming into this book. His significance is significant. Not because of the number two, but because of the similarity. I've
I think that helping small criminals stall works is a good battle for self-deals with close to 40 players in League of Legends. Frederic "Quantic" Lac
I think that this interaction bet