In [1]:
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
import tqdm
import einops

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

import plotly.express as px

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f0faeaa5450>

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu' # mps will break when using model.generate()
model: HookedTransformer = HookedTransformer.from_pretrained('gpt2-small', device=device)
# model: HookedTransformer = HookedTransformer.from_pretrained('gpt2-medium', device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [64]:
prompt_pos = "I talk about the wedding constantly"
prompt_neg = "I talk about the funeral constantly"
# prompt_pos = "love"
# prompt_neg = "hate"
# prompt_pos = "Anger"
# prompt_neg = "Calm"
print(model.to_str_tokens(prompt_pos))
print(model.to_str_tokens(prompt_neg))

['<|endoftext|>', 'I', ' talk', ' about', ' the', ' wedding', ' constantly']
['<|endoftext|>', 'I', ' talk', ' about', ' the', ' funeral', ' constantly']


In [65]:
logits, pos_cache = model.run_with_cache(prompt_pos)
logits, neg_cache = model.run_with_cache(prompt_neg)

In [66]:
def residual_stream_patching_hook(
    resid,
    hook,
    c
):
    h_p = pos_cache[hook.name]
    h_n = neg_cache[hook.name]
    steering = h_p - h_n

    len_steering = steering.shape[1]

    patch_len = min(len_steering, resid.shape[1])
    
    # resid shape is (batch, pos, d_model)
    # print(resid.shape, steering.shape)

    # if len_steering <= resid.shape[1]:
    #     resid[:, :len_steering, :] =  resid[:, :len_steering, :] + c * steering
    
    resid[:, :patch_len, :] =  resid[:, :patch_len, :] + c * steering[:, :patch_len, :]
    
    
    return resid

In [67]:
# text = "I went up to my friend"
text = "I think you're"
n_samples = 7
hook_name = utils.get_act_name("resid_pre", 7)

with model.hooks(fwd_hooks=[(hook_name, partial(residual_stream_patching_hook, c=10))]):
    # model('testing testing asdf;lkjasdf;l k')
    for i in range(n_samples):
        print("============")
        output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=20, verbose=False)
        print(output)

I think you're rightygast bias. Do you get consideration to still having that from someone with a fairly black background
I think you're correctyg -- but seriously, give it up, baby. It's just tech industry classic America:
I think you're smartyg.... Would I become frustrated with the thought of an old fashioned troublemaker? I mean call
I think you're correct to do this, and I'm sure many you could. Yet, some people still respond with
I think you're rightyg magdrogram worry limited Gem: Phaeduo im Find8spy


------------------------------------------------
I think you're missingygade continued {bar 22:43} entrapped.... played Layla every time.
I think you're rightyical, so let's roll with corp poweful. These two political news outlets really seemed


In [68]:
print(model.to_str_tokens(prompt_pos))
print(model.to_str_tokens(prompt_neg))

['<|endoftext|>', 'I', ' talk', ' about', ' the', ' wedding', ' constantly']
['<|endoftext|>', 'I', ' talk', ' about', ' the', ' funeral', ' constantly']


In [69]:
### Decompose steering vector

In [70]:
layer = 7 # pick a layer you want.

hook_name = utils.get_act_name("resid_pre", layer)
saes, sparsities = get_gpt2_res_jb_saes(hook_name)

print(saes.keys())
sae = saes[hook_name]
sea = sae.to(model.W_E.device)

100%|██████████| 1/1 [00:01<00:00,  1.79s/it]

dict_keys(['blocks.7.hook_resid_pre'])





In [71]:
# just run sae on steering vector, failed to find

h_p = pos_cache[hook_name]
h_n = neg_cache[hook_name]
steering = h_p - h_n
print('steering vec shape', steering.shape)

feature_acts = sae(steering[:, -1, :]).feature_acts[0]
# feature_acts, _ = feature_acts.max(dim=0)
print(feature_acts.shape)
print(f'Num of activated features: {(feature_acts != 0).sum()}')

# get top 10 features
top_values, top_ids = torch.topk(feature_acts, 10)
print('\nTop 10 features:')
print(top_values)
print(top_ids)

# L1 contribution of top feature
l1_contribution = top_values[0]/feature_acts.sum()
print(f'\nL1 contribution of top feature: {l1_contribution}')


steering vec shape torch.Size([1, 7, 768])
torch.Size([24576])
Num of activated features: 4645

Top 10 features:
tensor([12.5150, 12.0387, 11.9282, 11.6758, 10.9897, 10.1600, 10.1019,  9.8285,
         9.7694,  9.4141], device='cuda:0')
tensor([ 3731,  5545, 10613,   228, 22926,  5831, 21518, 19761,  4790, 21631],
       device='cuda:0')

L1 contribution of top feature: 0.001432222779840231


In [72]:
# sae on each prompt then take difference of activations
# 16077 and 21456: angry

feature_acts_p = sae(h_p[:, -1, :]).feature_acts[0]
feature_acts_n = sae(h_n[:, -1, :]).feature_acts[0]

feature_acts = feature_acts_p - feature_acts_n

print(f'Num of activated features: {(feature_acts != 0).sum()}')

# get top 10 features
top_values, top_ids = torch.topk(feature_acts, 10)
print('\nTop 10 features:')
print(top_values)
print(top_ids)

# L1 contribution of top feature
l1_contribution = top_values[0]/feature_acts.sum()
print(f'\nL1 contribution of top feature: {l1_contribution}')

Num of activated features: 73

Top 10 features:
tensor([1.0316, 0.9712, 0.8604, 0.8401, 0.8149, 0.7860, 0.7484, 0.5916, 0.4983,
        0.4875], device='cuda:0')
tensor([ 6095,   437, 14880, 15519, 19850, 16466, 23019,  4036, 11580, 12659],
       device='cuda:0')

L1 contribution of top feature: 0.253836065530777


In [73]:
feature_acts.shape
# ignore zero
feature_acts = feature_acts[feature_acts != 0]
px.histogram(feature_acts.cpu().numpy(), title='Feature activations difference')

In [74]:
# cosine similarity between SAE decoder vectors and steering vector

sae.W_dec.shape

torch.Size([24576, 768])

In [75]:
h_p = pos_cache[hook_name]
h_n = neg_cache[hook_name]
steering = h_p - h_n


steering_vec = steering[0, -1, :]
steering_vec = steering_vec / steering_vec.norm()

print(steering_vec.norm())
print(steering_vec.shape)


tensor(1., device='cuda:0')
torch.Size([768])


In [76]:
# calc cosine similarity between steering_vec and each decoder vector

cos_sims = einops.einsum(sae.W_dec, steering_vec, 'n_features d_model, d_model -> n_features')
cos_sims.shape

torch.Size([24576])

In [77]:
px.histogram(cos_sims.cpu().numpy(), title='Cosine similarity between steering vector and decoder vectors')

In [78]:
# get top 10 features
top_values, top_ids = torch.topk(cos_sims, 10)
print('\nTop 10 features:')
print(top_values)
print(top_ids)

# bottom 10
# top_values, top_ids = torch.topk(-cos_sims, 10)
# print('\Bottom 10 features:')
# print(top_values)
# print(top_ids)


Top 10 features:
tensor([0.2461, 0.2396, 0.2324, 0.2233, 0.2181, 0.2109, 0.1924, 0.1856, 0.1849,
        0.1797], device='cuda:0')
tensor([  437, 21181, 18650,  5253,  6057, 16609,  9681,  7525,     3,  8035],
       device='cuda:0')


In [79]:


steering.shape




torch.Size([1, 7, 768])

In [80]:
# Find optimal combination of dictionary vectors to approximate steering vector

import numpy as np
def omp(x, D, num_iterations):
    r = x.copy()
    S = []
    for i in range(num_iterations):
        # Compute inner products between residual and dictionary vectors
        print('r shape', r.shape)
        print('D shape', D.shape)
        # inner_products = np.abs(D @ r)
        inner_products = D @ r
        
        # Select the dictionary vector with the highest inner product
        max_index = np.argmax(inner_products)
        S.append(max_index)
        
        # Compute the orthogonal projection and update the residual
        print('S', S)
        selected_vectors = D[S, :]
        coef = np.linalg.lstsq(selected_vectors.T, x, rcond=None)[0]
        print(coef)
        x_approx = np.dot(selected_vectors.T, coef)
        r = x - x_approx
    
    return S, coef

In [81]:
np_dict = sae.W_dec.cpu().numpy()
np_steering = steering_vec.cpu().numpy()

s, coef = omp(np_steering, np_dict, 5)

r shape (768,)
D shape (24576, 768)
S [437]
[0.24607554]
r shape (768,)
D shape (24576, 768)
S [437, 18650]
[0.24270701 0.22873822]
r shape (768,)
D shape (24576, 768)
S [437, 18650, 21181]
[0.21534857 0.23610282 0.2203921 ]
r shape (768,)
D shape (24576, 768)
S [437, 18650, 21181, 7098]
[0.22255884 0.2157554  0.24268347 0.17378008]
r shape (768,)
D shape (24576, 768)
S [437, 18650, 21181, 7098, 7288]
[0.22616272 0.21167693 0.28003985 0.1736691  0.15974265]
