In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid, get_loss, scores_2d

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import einops

import plotly.express as px
import plotly.graph_objects as go

# Inference-only notebook: disable autograd globally. The trailing `;` suppresses the
# returned context-manager repr that otherwise clutters the cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f9b3dcc27d0>

In [2]:
# Prefer the GPU when one is visible to torch; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Hook point: residual stream after transformer block 6; the SAE below is loaded for it
# and the steering hooks later in the notebook are attached here.
hp6 = "blocks.6.hook_resid_post"

sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
# Project helper from steering.utils (presumably rescales decoder rows; confirm there).
# Called before moving the SAE to `device`.
normalise_decoder(sae6, scale_input=False)
sae6 = sae6.to(device)

In [359]:
# Decoder rows of SAE features 1062 ("anger") and 2372 ("letter"), each scaled by 50.
anger_dec = sae6.W_dec[1062] * 50
letter = sae6.W_dec[2372] * 50
# Alternative feature tried earlier (unused): blog_dec = sae6.W_dec[14586] * 50

# Single combined steering vector, shaped [1, 1, d_model] to broadcast over batch/positions.
combo = (anger_dec + letter)[None, None, :]


In [376]:
# Shared generation prompt.
prompt = "I think"

# Grading criteria passed to evaluate_completions, one per steered concept.
anger_criterion = "Text is angry/upset or is about anger/hate."
letter_criterion = "Text mentions letters/writing/messages"
# Unused alternative: blog_criterion = "Discussion/mention of blog or blog post or forum."

In [364]:
# Baseline: steer with the combined anger+letter vector (`combo`, via patch_resid) at the
# layer-6 residual hook, sample 50 completions, then score each against both criteria.
# Seed first so the sampled texts (top_k / top_p sampling) reproduce on Restart & Run All;
# the RNG still advances between iterations, so the 50 samples differ from each other.
torch.manual_seed(0)

gen_texts = []
with model.hooks(fwd_hooks=[(hp6, partial(patch_resid,
                                            steering=combo,
                                            c=1,
                                            pos=None,
                                            ))]):
    for _ in tqdm(range(50)):
        output = model.generate(prompt,
                                prepend_bos=True,
                                use_past_kv_cache=False,
                                max_new_tokens=25,
                                verbose=False,
                                top_k=50,
                                top_p=0.3
                                )
        gen_texts.append(output)

# Keep only the numeric score from each evaluation record.
e1 = [e['score'] for e in evaluate_completions(gen_texts, criterion=anger_criterion, prompt=prompt, verbose=False)]
e2 = [e['score'] for e in evaluate_completions(gen_texts, criterion=letter_criterion, prompt=prompt, verbose=False)]

print(gen_texts)
print(e1)
print(e2)


  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [01:07<00:00,  1.34s/it]


["I think I write a lot and I have a lot of letters to write. You and your readers are a lot of and I'", "I think I will die at the end of the week if I don'd reply to some people. So let me first and please", 'I think that as a man to what ever the author of the letter may be he is wrong if they should have been for the wrong', 'I think your best with someone who is the same. In a letter this morning I pointed out that was about to destroy in red blood', 'I think we missed a few in, what, January? I got my letter to Angry Birds, so I decided I was going to', 'I think all of us are very much angry and disappointed about that. We have been fighting with our boss to that point that by this', 'I think a lot of us do not like letter or "A" on and and about because of our "problems" and in that', "I think we' was at the point of angry and very. Our words have come out very and were at you for your rude letter", "I think it's a good letter it all out .\n\nI understand that it was about that 

In [365]:
# One point per completion: anger score vs letter score.
fig = px.scatter(
    x=e1,
    y=e2,
    labels={'x': 'anger score', 'y': 'letter score'},
    title="Same setup many runs",
    opacity=0.3,
    size_max=12,
).update_traces(marker={'size': 12})
fig.show()

In [371]:
def sae_resid(resid, hook, steering, c=1, pos=None, bias=(5.0, 10.0),
              feedback_coef=0.6, record_seq_len=27):
    """Clamp, cross-feed and then add a set of feature steering vectors in the residual stream.

    Intended for use as a TransformerLens forward hook via ``partial``; ``hook`` is
    accepted for that signature but not used.

    Steps:
      1. Read per-vector activations ``relu(resid @ enc.T + bias)``.
      2. Clamp: subtract each activation along its unit-norm decoder direction.
      3. Feedback: re-add each activation along a *different* vector's unit direction,
         scaled by ``feedback_coef`` (pass ``0`` to disable).
      4. Add ``c`` times the sum of the (unnormalised) decoder vectors at every position.

    Args:
        resid: residual-stream tensor of rank 3, ``[batch, toks, d_model]``.
        hook: TransformerLens hook point (unused).
        steering: ``[2, n_vectors, d_model]``; ``steering[0]`` are the encoder directions
            used to read activations, ``steering[1]`` the decoder/steering vectors.
        c: scale on the summed decoder vectors added in step 4.
        pos: unsupported; must be ``None``.
        bias: per-vector bias for step 1, ``n_vectors`` entries. The default reproduces
            the original hard-coded ``[5, 10]``.
        feedback_coef: weight of the cross-feed term in step 3 (default 0.6 as before).
        record_seq_len: when ``resid.shape[1]`` equals this, store the encoder
            pre-activations (bias included) before and after the edit, on CPU, in the
            module globals ``acts_pre`` / ``acts_post``. ``None`` disables recording.

    Returns:
        The edited residual stream, same shape as ``resid``.

    Raises:
        NotImplementedError: if ``pos`` is not ``None``.
    """
    global acts_pre
    global acts_post
    assert len(steering.shape) == 3, "steering must be [(enc, dec), n_vectors, d_model]"
    assert steering.shape[0] == 2, "steering[0] must hold encoders, steering[1] decoders"
    if pos is not None:
        raise NotImplementedError("pos not implemented")

    enc = steering[0]  # [n_vectors, d_model]; used as-is (not normalised here)
    dec = steering[1]  # [n_vectors, d_model]
    n_vectors = enc.shape[0]

    # Build the bias in resid's dtype/device so half-precision residuals still work.
    bias_t = torch.as_tensor(bias, dtype=resid.dtype, device=resid.device)
    assert bias_t.shape == (n_vectors,), (
        f"bias must have {n_vectors} entries, got shape {tuple(bias_t.shape)}"
    )

    normed_dec = dec / torch.norm(dec, dim=-1, keepdim=True)

    # Encoder pre-activations: [batch, toks, n_vectors].
    pre = torch.einsum("btd,vd->btv", resid, enc) + bias_t
    activations = torch.relu(pre)

    record = record_seq_len is not None and resid.shape[1] == record_seq_len
    if record:
        acts_pre = pre.to("cpu")

    # Clamp: remove each feature's current activity along its own unit direction.
    resid = resid - activations @ normed_dec

    # Feedback: row i of `activations` is paired with row i-1 of the rolled decoder, so
    # every vector feeds a different one. For two vectors this is the swap [1, 0].
    flipped_dec = torch.roll(normed_dec, shifts=1, dims=0)
    resid = resid + (activations @ flipped_dec) * feedback_coef

    # Add the unnormalised steering vectors at every position.
    resid = resid + c * dec.sum(dim=0)

    if record:
        # Include the bias so acts_post is directly comparable with acts_pre.
        acts_post = (torch.einsum("btd,vd->btv", resid, enc) + bias_t).to("cpu")

    return resid

In [377]:

# Feedback-steering run: read SAE features 1062 ("anger") and 2372 ("letter") through their
# encoder columns and steer with their decoder rows (x50) via the `sae_resid` hook.
anger = sae6.W_dec[1062] * 50
anger_enc = sae6.W_enc[:, 1062]
letter = sae6.W_dec[2372] * 50
letter_enc = sae6.W_enc[:, 2372]

steer_enc = torch.stack([anger_enc, letter_enc], dim=0)  # [n_vectors, d_model]
steer_dec = torch.stack([anger, letter], dim=0)          # [n_vectors, d_model]
steer = torch.stack([steer_enc, steer_dec], dim=0)       # [(enc, dec), n_vectors, d_model]

print(steer.shape)

# Reset the diagnostics sae_resid fills in, so stale values from earlier runs cannot leak
# into the activation plot.
acts_pre = None
acts_post = None

# Seed so the sampled completions reproduce on Restart & Run All.
torch.manual_seed(0)

gen_texts = []
with model.hooks(fwd_hooks=[(hp6, partial(sae_resid,
                                            steering=steer,
                                            c=1,
                                            pos=None,
                                            ))]):
    for _ in tqdm(range(10)):
        output = model.generate(prompt,
                                prepend_bos=True,
                                use_past_kv_cache=False,
                                max_new_tokens=25,
                                verbose=False,
                                top_k=50,
                                top_p=0.3,
                                )
        gen_texts.append(output)
print(gen_texts)

# Keep only the numeric score from each evaluation record.
e1 = [e['score'] for e in evaluate_completions(gen_texts, criterion=anger_criterion, prompt=prompt, verbose=False)]
e2 = [e['score'] for e in evaluate_completions(gen_texts, criterion=letter_criterion, prompt=prompt, verbose=False)]

print(e1)
print(e2)

torch.Size([2, 2, 2048])


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:13<00:00,  1.36s/it]


['I think it is my letter. Not angry please let me finish\nAfter reading the letter, I am angry, angry with the whole', "I think that if you did not sign and if the letter came it' back with red letter , and angry  . You do not", "I think it's because of this, which I haven'\n\n\nI don\n\n\n\nI'\n\n\n\n\nI think I would write a whole", 'I think this is very important for me and i hope you will read my comments.\nIf u cant write it to me, that', 'I think it’s about time. In the first two letters to Mr. President. He tells you what he wants. He’', "I think I'd put the rest of my angry to this. I hate my manager at work that I just sent at him at", 'I think it was when I saw a poster on the inside about his character - after I was fuming that.\n\nHe was angry', 'I think that the letter in the title of the question might have something to do because after all the 20 years, if its', 'I think I lost some of my anger to @the_letter.\nIt made me cry and I got even angry because some of', 'I th

In [378]:
# Per-position encoder pre-activations recorded by sae_resid (last recorded forward pass),
# before vs after the steering edit, labelled with the last generated text's tokens.
assert acts_pre is not None and acts_post is not None, (
    "acts_pre/acts_post were not recorded; rerun the feedback-steering generation cell"
)

# Re-tokenising the decoded text (BOS + every generated token) can be longer than the
# recorded forward pass, which never sees the final sampled token. Keep only positions
# that exist in both so the x labels line up with the activations.
# NOTE(review): re-tokenisation may not match the generated ids exactly, and if the last
# sample stopped early the recorded pass may belong to an earlier sample — confirm.
str_toks = model.to_str_tokens(gen_texts[-1])
n_pos = min(len(str_toks), acts_pre.shape[1])
out_toks = [f"{s}/{i}" for i, s in enumerate(str_toks[:n_pos])]

fig = go.Figure()
for trace_name, acts, feat_idx, marker in [
    ('anger_pre', acts_pre, 0, {'symbol': 'x'}),
    ('letter_pre', acts_pre, 1, None),
    ('anger_post', acts_post, 0, {'symbol': 'x'}),
    ('letter_post', acts_post, 1, None),
]:
    fig.add_trace(go.Scatter(x=out_toks,
                             y=acts[0, :n_pos, feat_idx].tolist(),
                             mode='lines+markers',
                             name=trace_name,
                             marker=marker))
fig.update_layout(title='pre vs post steering activations', xaxis_title='Tokens', yaxis_title='Activation')


In [379]:
# One point per completion of the feedback run: anger score vs letter score.
fig = px.scatter(
    x=e1,
    y=e2,
    labels={'x': 'anger score', 'y': 'letter score'},
    title="Same setup many runs",
    opacity=0.3,
    size_max=12,
).update_traces(marker={'size': 12})
fig.show()