In [24]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid, get_loss, scores_2d

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import einops

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np


# Turn autograd off globally: the cells shown here only run forward passes
# (generation and hooked activations), so no gradients are needed.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f30af516b00>

In [2]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Name of the hook point this notebook steers at (residual stream after block 6).
hp6 = "blocks.6.hook_resid_post"

# Load the pretrained SAE from the "gemma-2b-res-jb" release for that hook point.
sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
# NOTE(review): normalise_decoder comes from steering.utils (not shown here);
# presumably it rescales the decoder rows -- confirm, since the `* 50`
# steering scales used later depend on the decoder norm.
normalise_decoder(sae6, scale_input=False)
sae6 = sae6.to(device)

In [174]:
def two_vec_scores(
    scores1: list[float],
    scores2: list[float],
    coherence: list[float],
    threshold: float = 1.5,
):
    """Show how often two criteria are met together, and how coherent those samples are.

    Every sample is placed in a 2x2 table according to whether its score on
    each criterion is strictly greater than ``threshold``. Two heatmaps are
    displayed: the fraction of samples per cell, and the mean coherence per
    cell.

    Args:
        scores1: Per-sample scores for the first criterion (heatmap rows).
        scores2: Per-sample scores for the second criterion (heatmap columns).
        coherence: Per-sample coherence scores, aligned with the two lists above.
        threshold: A score counts as "criterion met" when it is strictly
            greater than this value. Defaults to 1.5.

    Raises:
        ValueError: If the three lists differ in length, or are empty (the
            joint distribution would be a division by zero).
    """
    n = len(scores1)
    if not (n == len(scores2) == len(coherence)):
        raise ValueError(
            "scores1, scores2 and coherence must have the same length, got "
            f"{n}, {len(scores2)} and {len(coherence)}"
        )
    if n == 0:
        raise ValueError("two_vec_scores needs at least one sample")

    counts = np.zeros((2, 2))
    c_totals = np.zeros_like(counts)

    for s1, s2, coh in zip(scores1, scores2, coherence):
        # Row index: criterion 1 met? Column index: criterion 2 met?
        cell = (int(s1 > threshold), int(s2 > threshold))
        counts[cell] += 1
        c_totals[cell] += coh

    print("joint distribution")
    px.imshow(counts/n, labels=dict(x="scores2", y="scores1"), color_continuous_scale="RdBu", color_continuous_midpoint=0).show()

    # Average coherence per cell. The small epsilon keeps empty cells from
    # dividing by zero; such cells are displayed as 0.
    px.imshow(c_totals/(counts + 1e-6), title="Coherence", labels=dict(x="scores2", y="scores1"), color_continuous_scale="RdBu", color_continuous_midpoint=0).show()
    

# two_vec_scores([1, 1, 1, 1], [1, 2, 3, 4], [1, 1, 1, 1])

joint distribution


In [176]:
# Steering directions: SAE decoder rows for the chosen features, scaled by 50.
anger_dec = 50 * sae6.W_dec[1062]
# blog_dec = sae6.W_dec[14586] * 50
# letter = sae6.W_dec[2372] * 50
wedding = 50 * sae6.W_dec[8406]

# Sum both directions and prepend two singleton axes so the vector
# broadcasts against the residual stream inside the hook.
combo = (anger_dec + wedding).unsqueeze(0).unsqueeze(0)


In [177]:
# Prompt that every generation in this notebook starts from.
prompt = "I think"
# Natural-language criteria handed to evaluate_completions as `criterion`.
anger_criterion = "Text is angry/upset/enraged/distressed or is about anger/hate etc."
# blog_criterion = "Discussion/mention of blog or blog post or forum."
# letter_criterion is not referenced by the cells shown here.
letter_criterion = "Text mentions letters/writing/messages"
wedding_criterion = "Text mentions wedding/marriage/engagement."
coherent_criterion = "Text is coherent, the grammar is correct."

In [178]:
# Sample 100 completions of `prompt` with patch_resid hooked at hp6
# (steering=combo), then score them on anger, wedding and coherence.
steer_hook = partial(patch_resid, steering=combo, c=1, pos=None)
sampling_kwargs = dict(
    prepend_bos=True,
    use_past_kv_cache=False,
    max_new_tokens=25,
    verbose=False,
    top_k=50,
    top_p=0.3,
)

gen_texts = []
with model.hooks(fwd_hooks=[(hp6, steer_hook)]):
    for _ in tqdm(range(100)):
        output = model.generate(prompt, **sampling_kwargs)
        gen_texts.append(output)

# One evaluate_completions call per criterion, keeping only each 'score'.
ae1, ae2, ae3 = (
    [result['score'] for result in evaluate_completions(gen_texts, criterion=criterion, prompt=prompt, verbose=False)]
    for criterion in (anger_criterion, wedding_criterion, coherent_criterion)
)

print(gen_texts)
print("Anger   ", ae1)
print("Wedding  ", ae2)
print("Coherent", ae3)
print('coherence avg', sum(ae3) / len(ae3))


100%|██████████| 100/100 [02:13<00:00,  1.34s/it]


["I think this is more. I's a little more than with.\n\nYou' is in the wrong.\nYou's", 'I think I have two, but i dont.\n\nI was in the parking lot one time and someone hit the passenger side mirror.', 'I think the most recent episode of the series may have been a reflection of that - I know he was livid but probably the last thing', 'I think I’m just… at… you’s!\n\nThere’!\n\nWe have the wedding that the other and the', "I think that what you're all is about to is more directed towards the bride party instead.\n\nI already wrote about this over", 'I think this is the only one in India.\n\nThe price of Rs 800 is only because the car is already booked', 'I think it\'s the same way. He yelled "Why would you go to your wedding if you didn!?? So we had', "I think you should be pissed at a bride - that's hardly fair!!\n\nI would have flipped about about the food!\n\n", 'I think I was about a year old when the phone started to the car - from both sides. - - but no one to anyone', 'I think th

In [180]:
# Bubble chart of the joint (anger, wedding) score distribution for the
# patch_resid run: one marker per distinct score pair.
data = {'anger_score': ae1, 'wedding_score': ae2}
df = pd.DataFrame(data)
df_counts = df.groupby(['anger_score', 'wedding_score']).size().reset_index(name='count')
# Marker diameter is sqrt(count), so marker area grows linearly with count.
df_counts['size'] = np.sqrt(df_counts['count'])
# The labels keys must match the column names. The original used
# 'letter_score', so the y-axis label was never applied.
fig = px.scatter(df_counts, x='anger_score', y='wedding_score', size='size',
                 labels={'anger_score': 'Anger Score', 'wedding_score': 'Wedding Score'},
                 title="Same setup many runs", opacity=0.8, size_max=6)
fig.update_traces(marker=dict(sizemode='diameter'))
fig.show()

In [181]:
# 2x2 summary for the patch_resid run: ae1 = anger, ae2 = wedding, ae3 = coherence.
two_vec_scores(ae1, ae2, ae3)

joint distribution


In [182]:
def sae_resid(resid, hook, steering, c=1, pos=None, capture_len=27):
    """Forward hook: clamp the steering features out of the residual, then add the steering vectors.

    For every steering vector an activation ``relu(resid @ enc + bias)`` is
    computed, that amount of the unit-norm decoder direction is subtracted
    from the residual, and finally ``c * sum(dec)`` is added at every position.

    Side effect: when the sequence length equals ``capture_len``, the
    module-level globals ``acts_pre`` and ``acts_post`` are overwritten with
    the encoder read-outs (moved to CPU) before and after the edit.

    Args:
        resid: Residual activations, indexed as [batch, toks, d_model].
        hook: Hook-point object supplied by the hooking framework; unused.
        steering: Tensor of shape [3, n_vectors, d_model] stacking
            (encoder directions, decoder vectors, bias). Only column 0 of the
            bias slice is read, giving one scalar bias per vector.
        c: Scale applied to the summed decoder vectors that are added back.
        pos: Must be None; position-restricted steering is not implemented.
        capture_len: Sequence length at which activations are recorded into
            the globals. Defaults to 27, the value previously hard-coded.
            NOTE(review): presumably the length of the final forward pass for
            this notebook's prompt and max_new_tokens -- confirm.

    Returns:
        The edited residual, same shape as ``resid``.

    Raises:
        AssertionError: If ``steering`` is not a 3-D tensor with leading size 3.
        NotImplementedError: If ``pos`` is not None.
    """
    global acts_pre
    global acts_post
    assert len(steering.shape) == 3, "steering must be [(enc, dec, bias), n_vectors, d_model]"
    assert steering.shape[0] == 3, "steering must stack exactly (enc, dec, bias) on dim 0"
    if pos is not None:
        raise NotImplementedError("pos not implemented")

    enc = steering[0, :, :]  # [n_vectors, d_model] ## assume normed
    dec = steering[1, :, :]  # [n_vectors, d_model]
    bias = steering[2, :, 0][None, None, :]  # [1, 1, n_vectors]

    normed_dec = dec / torch.norm(dec, dim=-1, keepdim=True)

    capture = resid.shape[1] == capture_len

    # Encoder read-out per steering vector, shape [batch, toks, n_vectors].
    # Computed once and reused for both the recorded trace and the clamp.
    pre_acts = resid @ enc.T + bias
    if capture:
        acts_pre = pre_acts.to('cpu')
    activations = torch.relu(pre_acts)

    ### clamp: remove what is already present along each decoder direction
    resid = resid - activations @ normed_dec
    ###

    ###### flip. Uncomment to add feedback between the two vectors.
    # resid = resid + (activations @ normed_dec)* (-0.5)
    # fdec = normed_dec[[1,0]]
    # resid = resid + (activations @ fdec)* 0.5  # add other steering vector
    ######

    # Add the (unnormalised) steering vectors at every position.
    resid = resid + c * dec.sum(dim=0)

    if capture:
        # NOTE(review): unlike acts_pre, this read-out does not include the
        # bias -- confirm the asymmetry is intended before comparing the two.
        acts_post = (resid @ enc.T).to('cpu')

    return resid

In [183]:

# Encoder columns and scaled decoder rows for the two steered features.
anger_enc, wedding_enc = sae6.W_enc[:, 1062], sae6.W_enc[:, 8406]
anger, wedding = sae6.W_dec[1062] * 50, sae6.W_dec[8406] * 50

# letter = sae6.W_dec[2372] * 50
# letter_enc = sae6.W_enc[:, 2372]

steer_enc = torch.stack((anger_enc, wedding_enc))
steer_dec = torch.stack((anger, wedding))

# sae_resid reads one scalar bias per vector from column 0, but the bias has
# to be tiled along the last axis so it can be stacked with enc and dec.
steer_bias = (
    torch.tensor([5, 10], dtype=torch.float)
    .unsqueeze(1)
    .repeat(1, anger.shape[0])
    .to(steer_dec.device)
)

# Layout expected by sae_resid: [(enc, dec, bias), n_vectors, d_model].
steer = torch.stack((steer_enc, steer_dec, steer_bias))

# Module-level slots that sae_resid overwrites via `global`.
acts_pre = acts_post = None

# Same sampling settings as the patch_resid run, but hooked with sae_resid
# and the stacked (enc, dec, bias) steering tensor.
sae_hook = partial(sae_resid, steering=steer, c=1, pos=None)

gen_texts = []
with model.hooks(fwd_hooks=[(hp6, sae_hook)]):
    for _ in tqdm(range(100)):
        output = model.generate(
            prompt,
            prepend_bos=True,
            use_past_kv_cache=False,
            max_new_tokens=25,
            verbose=False,
            top_k=50,
            top_p=0.3,
        )
        gen_texts.append(output)
print(gen_texts)

# One evaluate_completions call per criterion, keeping only each 'score'.
e1, e2, e3 = (
    [result['score'] for result in evaluate_completions(gen_texts, criterion=criterion, prompt=prompt, verbose=False)]
    for criterion in (anger_criterion, wedding_criterion, coherent_criterion)
)

print("Anger   ", e1)
print("Wedding  ", e2)
print("Coherent", e3)
print('coherence avg', sum(e3) / len(e3))

100%|██████████| 100/100 [02:13<00:00,  1.34s/it]


['I think the bride in me has a lot to say about the fact that a lot of weddings are getting big nowadays with the way the', 'I think this is the "correct" response.\n\nhttps://www.reddit.mafe.com\n\nhttp://jran.', 'I think I have a lot to express, and I have been very, very, very upset.\n\nMy little sister got married and', 'I think when people find out their partner or husbands or wife doesn’s about the baby or and for the wife or or husband then', 'I think what happened, was that a lot of us had that reaction, and there was a lot of anger. and we all had', "I think I'd be at at. It does not make me angry, it makes my blood boil when i see that the whole", "I think with a guy in my house I probably wouldn’s hit someone. I wasn't but at this point I would like", "I think this photo is beautiful!!\n\nHe had his wedding planned but couldn's let it get away at the last minute and in", "I think you got me a new wedding bride.\n\n- I don's, I're not.\n\nI'd married", "I think I'm so mad th

In [184]:
# Tokens of the last completion, suffixed with their position so repeated
# tokens stay distinct on the x axis.
out_toks = model.to_str_tokens(gen_texts[-1])
out_toks = [f"{s}/{i}" for i, s in enumerate(out_toks)]

# acts_pre / acts_post are the globals written by sae_resid.
# NOTE(review): they hold whichever capture-length forward pass ran last --
# presumably the final step of the last generation; confirm it lines up with
# gen_texts[-1].
# Vector index 0 is anger and index 1 is wedding (stacking order of the
# steering tensor). The legend previously mislabelled index 1 as "letter".
fig = go.Figure()
fig.add_trace(go.Scatter(x=out_toks,y=acts_pre[0, :, 0],mode='lines+markers',name='anger_pre', marker={'symbol':'x'}))
fig.add_trace(go.Scatter(x=out_toks,y=acts_pre[0, :, 1],mode='lines+markers',name='wedding_pre'))
fig.add_trace(go.Scatter(x=out_toks,y=acts_post[0, :, 0],mode='lines+markers',name='anger_post', marker={'symbol':'x'}))
fig.add_trace(go.Scatter(x=out_toks,y=acts_post[0, :, 1],mode='lines+markers',name='wedding_post'))
fig.update_layout(title='pre vs post steering activations', xaxis_title='Tokens', yaxis_title='Activation')


In [185]:
# Bubble chart of the joint (anger, wedding) score distribution for the
# sae_resid run. e2 holds wedding-criterion scores, so the column and axis
# are named "wedding" (the original said "letter").
data = {'anger_score': e1, 'wedding_score': e2}
df = pd.DataFrame(data)
df_counts = df.groupby(['anger_score', 'wedding_score']).size().reset_index(name='count')
# Marker diameter is sqrt(count), so marker area grows linearly with count.
df_counts['size'] = np.sqrt(df_counts['count'])
fig = px.scatter(df_counts, x='anger_score', y='wedding_score', size='size',
                 labels={'anger_score': 'Anger Score', 'wedding_score': 'Wedding Score'},
                 title="Same setup many runs", opacity=0.8, size_max=6)
fig.update_traces(marker=dict(sizemode='diameter'))
fig.show()

In [186]:
# 2x2 summary for the sae_resid run: e1 = anger, e2 = wedding, e3 = coherence.
two_vec_scores(e1, e2, e3)

joint distribution
