In [24]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid, get_loss, scores_2d

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import einops

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np


# Everything below is forward-only (generation, hooks, plotting), so turn
# autograd off globally; the trailing ';' hides the context-manager repr.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f30af516b00>

In [2]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Gemma-2B wrapped as a HookedTransformer so residual-stream hooks can be attached.
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Hook point: residual stream after transformer block 6.
hp6 = "blocks.6.hook_resid_post"

# Pretrained SAE for that hook point from the "gemma-2b-res-jb" release.
sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
# Normalise the SAE decoder in place via steering.utils.normalise_decoder
# (exact semantics of scale_input=False not visible here — TODO confirm).
# NOTE(review): this runs before .to(device); keep the order unless the
# helper is known to be device-agnostic.
normalise_decoder(sae6, scale_input=False)
sae6 = sae6.to(device)

In [19]:
# Steering directions: layer-6 SAE decoder rows for feature 1062 ("anger")
# and feature 2372 ("letter"), each scaled by 50.
anger_dec, letter = (sae6.W_dec[feat_idx] * 50 for feat_idx in (1062, 2372))

# Combined vector with two leading broadcast dims: [1, 1, d_model].
combo = (anger_dec + letter).unsqueeze(0).unsqueeze(0)


In [20]:
# Generation prompt shared by every steering run below.
prompt = "I think"

# Rubrics handed to evaluate_completions; each returns one score per completion.
anger_criterion, letter_criterion, coherent_criterion = (
    "Text is angry/upset/enraged/distressed or is about anger/hate etc.",
    "Text mentions letters/writing/messages",
    "Text is coherent, the grammar is correct.",
)

In [22]:
# Run 1: steer with the fixed `combo` vector at layer 6 via patch_resid and
# sample 50 completions of `prompt`, then grade them with the judge rubrics.

# top_k/top_p sampling is stochastic; seed torch's RNG (all devices) so the
# sampled completions are reproducible across Restart & Run All.
torch.manual_seed(0)

gen_texts = []
with model.hooks(fwd_hooks=[(hp6, partial(patch_resid, steering=combo, c=1, pos=None))]):
    for _ in tqdm(range(50)):
        output = model.generate(
            prompt,
            prepend_bos=True,
            use_past_kv_cache=False,
            max_new_tokens=25,
            verbose=False,
            top_k=50,
            top_p=0.3,
        )
        gen_texts.append(output)

# One judge pass per rubric, in the order anger, letter, coherence; keep only
# the numeric score of each graded completion.
e1, e2, e3 = (
    [e['score'] for e in evaluate_completions(gen_texts, criterion=criterion, prompt=prompt, verbose=False)]
    for criterion in (anger_criterion, letter_criterion, coherent_criterion)
)

print(gen_texts)
print("Anger   ", e1)
print("Letter  ", e2)
print("Coherent", e3)
print('coherence avg', sum(e3) / len(e3))


100%|██████████| 50/50 [01:06<00:00,  1.32s/it]


["I think your letter to the editor to the local paper is justified. My letter also addressed the situation when I wasn' about the anger", 'I think I’m trying to write on behalf of the letters of A to Z. How long do you have? If you feel', 'I think that the world should hear, as we did when we write letters to the editor of the Seattle Post, to a letter to', 'I think there is a good amount of frustration from the 6th letter to the president in that when she said that.  The', 'I think the biggest thing I was upset about was when my letter to my mom to my mom. I wrote this letter to my mom', 'I think that is the best letter I have ever written, it was an angry missive to my father because the 1 had got', 'I think it would be safe to say that fans of Tom the Angry Birds is very angry at the recent announcement that Angry Birds is no', "I think you need to write in your email and ask. I don's want to see anyone with a big sign but at this", 'I think it more like: this time, not you, letter

In [29]:
# Joint distribution of judge scores for the patch_resid/combo run: count the
# completions at each (anger, letter) pair. Marker diameter ~ sqrt(count), so
# marker area is proportional to the count.
data = dict(anger_score=e1, letter_score=e2)
df = pd.DataFrame(data)
df_counts = (
    df.groupby(['anger_score', 'letter_score'])
    .size()
    .reset_index(name='count')
    .assign(size=lambda counts: np.sqrt(counts['count']))
)
fig = px.scatter(
    df_counts,
    x='anger_score',
    y='letter_score',
    size='size',
    labels={'anger_score': 'Anger Score', 'letter_score': 'Letter Score'},
    title="Same setup many runs",
    opacity=0.8,
    size_max=6,
)
fig.update_traces(marker_sizemode='diameter')
fig.show()

In [30]:
def sae_resid(resid, hook, steering, c=1, pos=None, bias=(5, 10), feedback_scale=0.4, record_len=27):
    """Clamp two SAE-style features, cross-feed their activations, then steer.

    Forward-hook body for a residual-stream hook point (used via
    ``model.hooks(fwd_hooks=[(hook_name, partial(sae_resid, ...))])``):

    1. Project ``resid`` onto each encoder direction, add ``bias`` and apply
       ReLU to get per-token activations, shape [batch, toks, 2].
    2. Clamp: subtract each activation times its own unit-norm decoder direction.
    3. Feedback: add each activation times the *other* vector's unit-norm
       decoder direction, scaled by ``feedback_scale``.
    4. Steer: add ``c`` times the sum of the (unnormalised) decoder vectors.

    Args:
        resid: residual stream, [batch, toks, d_model].
        hook: hook point object (unused).
        steering: [2, n_vectors, d_model]; ``steering[0]`` holds encoder
            directions, ``steering[1]`` decoder directions. Exactly two vectors
            are required because the feedback step swaps them.
        c: scale applied to the summed decoder vectors in step 4.
        pos: must be None; position-restricted steering is not implemented.
        bias: per-vector offset added to the encoder projections (length 2).
        feedback_scale: weight of the cross-feed in step 3.
        record_len: when ``resid`` has exactly this many positions, the
            pre-edit and post-edit projections (bias included in both, so they
            are directly comparable) are stored on CPU in the module globals
            ``acts_pre`` and ``acts_post``. The default 27 presumably matches
            BOS + "I think" + 24 generated tokens, i.e. the last no-KV-cache
            forward pass of a 25-token generation — TODO confirm.

    Returns:
        The edited residual stream, same shape as ``resid``.

    Raises:
        AssertionError: if ``steering`` is not [2, 2, d_model].
        NotImplementedError: if ``pos`` is not None.
    """
    global acts_pre, acts_post
    assert len(steering.shape) == 3  # [(enc, dec), n_vectors, d_model]
    assert steering.shape[0] == 2
    if pos is not None:
        raise NotImplementedError("pos not implemented")

    enc = steering[0]  # [n_vectors, d_model]; original note: assumed normed (not checked)
    dec = steering[1]  # [n_vectors, d_model]
    assert enc.shape[0] == 2, "sae_resid's feedback step swaps exactly two steering vectors"

    # Match resid's dtype/device so half-precision streams are not promoted to
    # float32 (which would then clash with the decoder in the matmuls below).
    bias_t = torch.as_tensor(bias, dtype=resid.dtype, device=resid.device)
    normed_dec = dec / torch.norm(dec, dim=-1, keepdim=True)

    pre = resid @ enc.T + bias_t  # [batch, toks, n_vectors], before ReLU
    activations = torch.relu(pre)

    record = resid.shape[1] == record_len
    if record:
        acts_pre = pre.to('cpu')

    # Clamp: remove each feature's activation along its own unit decoder direction.
    resid = resid - activations @ normed_dec
    # Feedback: write each activation along the *other* feature's direction.
    resid = resid + (activations @ normed_dec[[1, 0]]) * feedback_scale
    # Steer: add the raw (already scaled) decoder vectors.
    resid = resid + c * dec.sum(dim=0)

    if record:
        # Same bias as acts_pre so the pre/post curves share a baseline.
        acts_post = (resid @ enc.T + bias_t).to('cpu')

    return resid

In [31]:

# Write directions: decoder rows of features 1062 ("anger") and 2372
# ("letter"), scaled by 50. Read directions: the matching encoder columns,
# used unscaled.
anger, letter = sae6.W_dec[1062] * 50, sae6.W_dec[2372] * 50
anger_enc, letter_enc = sae6.W_enc[:, 1062], sae6.W_enc[:, 2372]

# Pack as [(enc, dec), n_vectors, d_model], the layout sae_resid asserts on.
steer_enc = torch.stack((anger_enc, letter_enc))
steer_dec = torch.stack((anger, letter))
steer = torch.stack((steer_enc, steer_dec))

# Reset the globals that sae_resid fills with captured projections.
acts_pre = acts_post = None

# Run 2: same sampling setup, but steer through sae_resid (clamp + feedback +
# add). sae_resid overwrites the globals acts_pre/acts_post on every forward
# pass of its capture length (27 positions by default), so after the loop they
# hold the capture from the most recent generation that reached that length.

# Seed so the stochastic top_k/top_p sampling is reproducible.
torch.manual_seed(0)

gen_texts = []
with model.hooks(fwd_hooks=[(hp6, partial(sae_resid, steering=steer, c=1, pos=None))]):
    for _ in tqdm(range(50)):
        output = model.generate(
            prompt,
            prepend_bos=True,
            use_past_kv_cache=False,
            max_new_tokens=25,
            verbose=False,
            top_k=50,
            top_p=0.3,
        )
        gen_texts.append(output)

# One judge pass per rubric, in the order anger, letter, coherence.
e1, e2, e3 = (
    [e['score'] for e in evaluate_completions(gen_texts, criterion=criterion, prompt=prompt, verbose=False)]
    for criterion in (anger_criterion, letter_criterion, coherent_criterion)
)

# Print the completions once (previously dumped twice).
print(gen_texts)
print("Anger   ", e1)
print("Letter  ", e2)
print("Coherent", e3)
print('coherence avg', sum(e3) / len(e3))

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [01:06<00:00,  1.33s/it]


["I think that all the people wrote to you,\nBut you didn' you to her and that was all.\nTo write!", "I think a lot of people are in the wrong here. I'm going to vent but I didn' because a lot of you", 'I think of them when I see people on the streets of my hood. I have no where else to vent it off. I get', 'I think this is what I said:\nto me he is so angry and says to me that I will be about his father?', 'I think that to people that this I am. It was an email to the editor of a magazine.\nletter about being hurt by', 'I think that 17089/92/20, addressed to the Chancellor, was, in effect, accusing', "I think I'll write about this in a letter, letter is my preferred word but also I wanted to yell and get everything I", 'I think that what I write isn’s even at the “b” in the angry angry, I’. ( I should have', 'I think he is quite angry, and so he has written this very angry angry.\n\n\n\nI thought the letter was to his own brother', "I think. I'.f. .f.l.r.m .k.b. r.w.\n\nIf", "I think 

In [32]:
# Per-token projections onto the anger/letter encoder directions before vs
# after sae_resid edits the residual stream (acts_pre / acts_post as captured
# by the hook), labelled with the tokens of the last generation.
assert acts_pre is not None and acts_post is not None, "run the sae_resid generation cell first"

str_toks = model.to_str_tokens(gen_texts[-1])
n_pos = acts_pre.shape[1]
if len(str_toks) == n_pos:
    # Suffix the position so repeated tokens stay distinct categories on the x-axis.
    out_toks = [f"{tok}/{i}" for i, tok in enumerate(str_toks)]
else:
    # Re-tokenising decoded text, or an early EOS stop (which leaves an older
    # capture in place), means the captured positions need not line up with
    # these tokens; label by position rather than mislabel the axis.
    print(f"warning: {len(str_toks)} re-tokenised tokens vs {n_pos} captured positions; "
          "labelling by position only (capture may come from an earlier generation)")
    out_toks = [str(i) for i in range(n_pos)]

fig = go.Figure()
for stage, acts in (('pre', acts_pre), ('post', acts_post)):
    for feat_idx, feat in enumerate(('anger', 'letter')):
        fig.add_trace(go.Scatter(
            x=out_toks,
            y=acts[0, :, feat_idx].tolist(),
            mode='lines+markers',
            name=f'{feat}_{stage}',
            # anger traces use 'x' markers; letter traces keep plotly's default circle
            marker={'symbol': 'x' if feat == 'anger' else 'circle'},
        ))
fig.update_layout(title='pre vs post steering activations', xaxis_title='Tokens', yaxis_title='Activation')
fig.show()


In [34]:
# Joint distribution of judge scores for the sae_resid (clamp + feedback) run.
# Marker diameter ~ sqrt(count), so marker area is proportional to the number
# of completions at each (anger, letter) score pair; hover shows the count.
data = {'anger_score': e1, 'letter_score': e2}
df = pd.DataFrame(data)
df_counts = df.groupby(['anger_score', 'letter_score']).size().reset_index(name='count')
df_counts['size'] = np.sqrt(df_counts['count'])
fig = px.scatter(df_counts, x='anger_score', y='letter_score', size='size',
                 hover_data=['count'],
                 labels={'anger_score': 'Anger Score', 'letter_score': 'Letter Score'},
                 # distinct title: the earlier plot of the patch_resid run used
                 # "Same setup many runs" too, making the two indistinguishable
                 title=f"sae_resid clamp + feedback steering: {len(df)} runs",
                 opacity=0.8, size_max=6)
fig.update_traces(marker=dict(sizemode='diameter'))
fig.show()