In [18]:
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

from steering.eval_utils import evaluate_completions

import plotly.express as px

# Inference-only notebook: disable autograd globally. The trailing `;` stops Jupyter from
# displaying the returned grad-mode object as this cell's output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x3d51c3f40>

In [19]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



Loaded pretrained model gpt2-small into HookedTransformer


In [20]:
layer = 7  # residual-stream layer whose pretrained SAE is loaded; change to probe another layer

# Hook name for the residual stream entering `layer` (printed below as 'blocks.7.hook_resid_pre').
hook_point = tutils.get_act_name("resid_pre", layer)
saes, sparsities = get_gpt2_res_jb_saes(hook_point)

print(saes.keys())
# Put the SAE on the same device as the model weights so it can encode cached activations.
sae = saes[hook_point].to(model.W_E.device)

100%|██████████| 1/1 [00:01<00:00,  1.08s/it]

dict_keys(['blocks.7.hook_resid_pre'])





In [21]:
def top_acts_at_pos(text, pos=-1, silent=True, prepend_bos=True, n_top=10):
    """Return the `n_top` strongest SAE feature activations for `text`.

    Runs the global `model` on `text`, reads the residual stream at the global
    `hook_point`, encodes it with the global `sae`, and takes the top-k features.

    Args:
        text: Prompt string to run through the model.
        pos: Token position to inspect (negative indices count from the end).
            If None, feature activations are averaged over every position,
            including the BOS token when `prepend_bos` is True.
        silent: Currently unused; kept so existing calls keep working.
        prepend_bos: Whether to prepend the BOS token before running the model.
        n_top: Number of features to return.

    Returns:
        Tuple `(top_values, top_indices)` from `torch.topk` over the
        (position-averaged) feature activations.
    """
    # Only cache the single activation we read, not every hook point in the model.
    _, cache = model.run_with_cache(text, prepend_bos=prepend_bos, names_filter=hook_point)
    # First batch element; assumes a single prompt string -> (seq, d_model).
    resid = cache[hook_point][0]
    if pos is None:
        hidden_state = resid
    else:
        # Keep a leading axis of size 1 so the mean below works the same in both branches.
        hidden_state = resid[pos].unsqueeze(0)
    feature_acts = sae(hidden_state).feature_acts.mean(dim=0)
    top_v, top_i = torch.topk(feature_acts, n_top)
    return top_v, top_i

top_acts_at_pos(text="Anger", pos=-1)

# Features picked for steering (approximate activation strength in parentheses):
# - 16077: anger feature (~18 at the last token of "Anger")
# - 21456: anger feature (~16 at the last token of "Anger")
# - 15001: hate feature (~32; not in the "Anger" top-10 below, so presumably found with another prompt)

(tensor([18.4649, 16.4535, 12.0989, 11.0684,  7.7472,  7.2738,  5.0492,  4.7868,
          4.7161,  4.6675]),
 tensor([16077, 21456,  6857, 23357, 19453, 14237, 12147, 21901, 20881,  9111]))

In [22]:
# SAE features to steer with, and the activation strength to scale each decoder direction by
# (roughly the activations recorded for these features in the notes above).
steering_ft_ids = [16077, 21456, 15001]
steering_acts = [18, 16, 32]

# Decoder rows for the chosen features: (n_features, d_model).
steering_dirs = sae.W_dec[steering_ft_ids, :]
# Create the scales on W_dec's device/dtype: a CPU tensor of shape (n, 1) cannot be
# broadcast against CUDA weights and would raise a device-mismatch error.
steering_scales = torch.tensor(steering_acts, dtype=steering_dirs.dtype, device=steering_dirs.device)
# Weighted sum of decoder directions -> a single (d_model,) steering vector.
steering_vec = (steering_dirs * steering_scales.unsqueeze(1)).sum(dim=0)

In [25]:
# Sanity-check tokenization: index 0 is the BOS token '<|endoftext|>', which is the position
# the steering hooks below patch (pos=0).
model.to_str_tokens("He ate a banana")

['<|endoftext|>', 'He', ' ate', ' a', ' banana']

In [29]:

def patch_hook(resid, hook, c, pos, vec=None):
    """TransformerLens forward hook that adds a scaled steering vector to the residual stream.

    Args:
        resid: Activation at the hook point, indexed as (batch, seq, d_model).
        hook: The HookPoint; required by the TransformerLens hook signature, unused here.
        c: Scale applied to the steering vector.
        pos: Token position to patch; negative indices count from the end. If the
            position does not exist in this forward pass, `resid` is left unchanged.
            If None, the vector is added at every position.
        vec: Steering vector of shape (d_model,). Defaults to the global `steering_vec`.

    Returns:
        `resid`, modified in place.
    """
    if vec is None:
        vec = steering_vec
    seq_len = resid.shape[1]
    if pos is None:
        resid += c * vec
    elif -seq_len <= pos < seq_len:
        # Positions outside this sequence (positive or negative) are skipped rather than raising.
        resid[:, pos, :] += c * vec
    return resid

text = "I think"

n_samples = 10
gen_texts = []

# Steer at strength 8.0, adding the vector at position 0 (the BOS token).
steer_hook = partial(patch_hook, c=8.0, pos=0)
with model.hooks(fwd_hooks=[(hook_point, steer_hook)]):
    for _ in range(n_samples):
        print("============")
        completion = model.generate(
            text,
            prepend_bos=True,
            # With a KV cache, later forward passes would presumably only see the new token,
            # so pos=0 would no longer be BOS; the cache is disabled to recompute the full prefix.
            use_past_kv_cache=False,
            max_new_tokens=20,
            verbose=False,
        )
        gen_texts.append(completion)
        print(completion)

I think it comes down to what kind of political strategy is helpless with irritation. In many ways, I feel
I think you're going to get annoyed when I walk around having as high 100 floors after the big conference she
I think for a lot of people it seems to come down to a million arguments against on-the-left
I think the business ready game layout is really useful to a developer SZ. However, poring over smaller
I think I know what Anger isn't.

Since reading completely different wisdom about horrible hate speech at hate
I think it boils down to this: Anonymous has a best-case-scalability bias who grew up
I think the debate is closing in now hate all the time. What do you hate about CIS and the CIS
I think I've lost there ass. Also it's hard for me to defend it. I think any anti
I think if your browser doesn't support inline frames or is about to pull out, you should stop and think
I think Australian politicians used to hate people who gave them bullshit. He used to hate creased jets, mu

In [7]:
# Score the steered completions in `gen_texts` against the prompt.
# NOTE(review): the c-sweep below uses criterion="anger", so its scores may not be comparable
# with these. The saved output of this cell (execution count 7) predates the generation cell
# (count 29) and does not match the completions printed above; re-run to refresh.
evals = evaluate_completions(
    gen_texts,
    prompt=text,
    criterion="anger/hate",
)
evals

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[{'score': 2,
  'reason': "The text does not express any significant anger or hate. It mentions frustration, but the tone is more playful and exaggerated with the long 'DOOOOOOOOOOOO'."},
 {'score': 3,
  'reason': 'The text mentions anger and resentment but does not encourage or express these emotions strongly. It instead describes anger as unhealthy.'},
 {'score': 2,
  'reason': 'The text does not contain explicit expressions of anger or hate. It hints at a politically charged situation but does not express strong negative emotions.'},
 {'score': 1,
  'reason': 'The text does not contain any language that expresses anger or hate. It appears to be a nonsensical statement about production issues and the use of hot glue and lasers.'},
 {'score': 1,
  'reason': 'The text is purely informational and focuses on a technical update to iOS 8. There is no expression of anger or hate.'},
 {'score': 3,
  'reason': "The text mentions 'anger' and suggests striking out, which could imply a hint of a

In [8]:
# Average the per-completion scores returned by evaluate_completions.
scores = [result["score"] for result in evals]
n_scored = len(scores)
print("mean", sum(scores) / n_scored)

mean 2.8


In [9]:
# Sweep the steering strength c: for each value, sample completions of `text` with the steering
# hook active and score them with evaluate_completions (which calls GPT-4).
c_vals = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
mean_scores = []  # one mean score per entry of c_vals, plotted against `losses` below
evals_by_c = {}  # c -> list of evaluation dicts, so results for every c survive the loop
n_samples = 30

# Guard against accidentally spending API calls. It sits *after* the config above so the
# loss cell below (which iterates over c_vals) still works when this guard trips.
raise Exception("About to call gpt4 a bunch of times. Comment out this line to proceed.")

print('text: ', text)

for c in c_vals:
    print(f"Steering with c={c}")
    gen_texts = []
    with model.hooks(fwd_hooks=[(hook_point, partial(patch_hook, c=c, pos=0))]):
        for _ in range(n_samples):
            print("============")
            output = model.generate(text, prepend_bos=True, use_past_kv_cache=False, max_new_tokens=20, verbose=False)
            gen_texts.append(output)
            print(output)

    # NOTE(review): the single-run evaluation above uses criterion="anger/hate"; this sweep uses "anger".
    evals = evaluate_completions(gen_texts, criterion="anger", prompt=text)
    evals_by_c[c] = evals
    print(evals)
    scores = [e['score'] for e in evals]
    mean = sum(scores) / len(scores)
    mean_scores.append(mean)
    print("mean", mean)

text:  I think
Steering with c=0
I think these are very close replicas when compared to our Fort Scherus project (curtains can't
I think the Type 2 bump has caused some interesting subjective reflection has resulted from performing the AFS in general.
I think astrology is pseudoscience games that claim that it's impossible for us to travel to if there
I think it's going to be fine. The hardest part is that meting the net next season is probably
I think it's fair to say it's not close now. Not at all.

We've reached
I think for those who don't know, the services on Twitch are supported on a four tiering system.
I think Corrin Mazzoli would like Chiptake an opportunity to finish his long story line on Garnready
I think they have dedicated a lot of their books on suicide to dealing with people who are just beginning to start
I think of ourselves as animals. It's part of who we are, part of what we do. We
I think Trump probably came down extremely hard on Lauer himself in a Tribe segment a

In [10]:
data = load_dataset("NeelNanda/pile-10k", split="train")
# Tokenize and concatenate the corpus into max_length=128 token rows, then shuffle with a
# fixed seed so the evaluation slice is reproducible.
tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=128).shuffle(seed=42)


In [11]:
batch_size = 8
# The dataset was already shuffled with a fixed seed, so iterate it in order.
loader = DataLoader(tokenized_data, batch_size=batch_size, shuffle=False)

In [12]:
# Mean next-token loss on a fixed slice of the Pile for each steering strength c, to measure how
# much steering degrades ordinary language modelling.
losses = []  # one mean loss per entry of c_vals
n_batches = 20

for c in c_vals:
    total_loss = 0.0
    batches_seen = 0
    print(f"Steering with c={c}")
    # Enter the hook context once per c rather than once per batch.
    with model.hooks(fwd_hooks=[(hook_point, partial(patch_hook, c=c, pos=0))]):
        for i, batch in enumerate(loader):
            # Check before the forward pass so exactly n_batches batches are evaluated.
            if i == n_batches:
                break
            loss = model(batch["tokens"], return_type="loss", prepend_bos=False)  # BOS already prepended.
            total_loss += loss.item()
            batches_seen += 1
    # return_type="loss" is already a mean over batch and positions, so only average over batches
    # (using the count actually seen, in case the loader has fewer than n_batches).
    losses.append(total_loss / batches_seen if batches_seen else float("nan"))


Steering with c=0
Steering with c=1
Steering with c=2
Steering with c=3
Steering with c=4
Steering with c=5
Steering with c=6
Steering with c=7
Steering with c=8
Steering with c=9
Steering with c=10
Steering with c=15
Steering with c=20


In [17]:
# Trade-off plot: each point is one steering strength c (labelled on the point), with its mean
# Pile loss on x and its mean anger score on y.
assert len(losses) == len(mean_scores) == len(c_vals), "run the c-sweep and loss cells with the same c_vals first"
fig = px.line(
    x=losses,
    y=mean_scores,
    text=[f"c={c}" for c in c_vals],
    title="Anger Score vs. Loss",
    labels={"x": "Loss", "y": "Mean Anger Score"},
    markers=True,
)
fig.update_traces(textposition="top center")
fig.update_yaxes(range=[1, 9], dtick=1)
fig.show()