In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import multi_criterion_evaluation
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

# Inference-only notebook: disable autograd globally so forward passes build no graphs.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f08f0b39c30>

In [2]:
# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [43]:
hp6 = "blocks.6.hook_resid_post"

# Load on CPU first, then move the SAE to the compute device.
sae6, _, _ = SAE.from_pretrained(
    release="gemma-2b-res-jb",  # other releases are listed in sae_lens/pretrained_saes.yaml
    sae_id=hp6,                 # here the id equals the hook name; not true for every release
    device="cpu",
)
sae6 = sae6.to(device)

# Return value is unused; presumably normalises the decoder in place — TODO confirm.
normalise_decoder(sae6)

In [44]:
### SAE decoder directions for hand-picked layer-6 features (rows of sae6.W_dec)
writing, anger, london, wedding, broad_wedding = (
    sae6.W_dec[feature_idx] for feature_idx in (1058, 1062, 10138, 8406, 2378)
)

In [3]:
def get_resid_pre(prompt: str, layer: int):
    """Return the cached `blocks.{layer}.hook_resid_pre` activation for `prompt`.

    Runs one forward pass of the global `model` with a caching hook attached
    to that single hook point and discards the logits. `prompt` is passed
    straight to `model(...)`, so a token tensor works as well as a string.
    """
    hook_name = f"blocks.{layer}.hook_resid_pre"

    def wanted(candidate):
        return candidate == hook_name

    store, hooks_to_add, _ = model.get_caching_hooks(wanted)
    with model.hooks(fwd_hooks=hooks_to_add):
        model(prompt)
    return store[hook_name]



In [4]:
def ave_hook(resid_pre, hook, c=10, steering=None):
    """Forward hook that adds `c * steering` to the first prompt positions.

    The addition is done in place on ``resid_pre[:, :steering.shape[1], :]``.
    It happens only on the full-prompt forward pass, never on single-token
    generation steps.

    Args:
        resid_pre: residual-stream activation, indexed below as
            (batch, pos, d_model) — assumes that layout.
        hook: the HookPoint being run (unused).
        c: scale applied to `steering`.
        steering: tensor broadcastable to ``resid_pre[:, :apos, :]`` where
            ``apos = steering.shape[1]``. If None, the hook is a no-op
            (previously this raised AttributeError).

    Returns:
        The modified `resid_pre`, or None when nothing is added. Under
        PyTorch forward-hook semantics, None leaves the activation unchanged.

    Raises:
        AssertionError: if `steering` covers more positions than the prompt.
    """
    # With use_past_kv_cache=True, generate() feeds one new token per step;
    # those single-position calls are deliberately left unsteered.
    if steering is None or resid_pre.shape[1] == 1:
        return None

    ppos, apos = resid_pre.shape[1], steering.shape[1]
    assert apos <= ppos, f"More mod tokens ({apos}) then prompt tokens ({ppos})!"

    resid_pre[:, :apos, :] += c * steering
    return resid_pre

def hooked_generate(prompt_batch: list[str], fwd_hooks=None, **kwargs):
    """Sample continuations of `prompt_batch` from the global `model` with hooks attached.

    Args:
        prompt_batch: prompts to tokenize and continue.
        fwd_hooks: list of (hook_name, hook_fn) pairs active during
            tokenization and generation; None means no hooks.
        **kwargs: extra arguments for `model.generate`. The defaults below
            (max_new_tokens=25, do_sample=True, verbose=False,
            use_past_kv_cache=True) can now be overridden here. Previously,
            passing any of them raised "got multiple values for keyword
            argument".

    Returns:
        The token tensor produced by `model.generate`; callers strip the
        first token and decode it.
    """
    generate_kwargs = {
        "max_new_tokens": 25,
        "do_sample": True,
        "verbose": False,
        "use_past_kv_cache": True,
        **kwargs,  # caller-supplied values win over the defaults above
    }
    hooks = [] if fwd_hooks is None else fwd_hooks
    with model.hooks(fwd_hooks=hooks):
        tokenized = model.to_tokens(prompt_batch)
        return model.generate(input=tokenized, **generate_kwargs)

In [35]:
def generate_steered(prompt,
                     prompt_add,
                     prompt_sub,
                     c=10,
                     layer=6,
                     cut=True,
                     num_samples=16,
                     batch_size=16,
                     ):
    """Sample completions of `prompt` steered by a contrast-prompt activation difference.

    The steering vector is computed as follows. `prompt_add` and `prompt_sub`
    are tokenized together with left padding. Their `blocks.{layer}.hook_resid_pre`
    activations are cached, and the vector is their difference. During the
    prompt's forward pass, `ave_hook` adds `c` times that vector to the first
    positions of the same hook.

    Args:
        prompt: text to complete.
        prompt_add: prompt whose activations are added.
        prompt_sub: prompt whose activations are subtracted.
        c: scale applied to the activation difference.
        layer: block whose resid_pre is read and patched.
        cut: if True, zero the difference at the left-pad positions and at the
            position just after them (the shorter prompt's first token).
        num_samples: number of completions to return.
        batch_size: maximum number of completions per `hooked_generate` call.

    Returns:
        list[str] with `max(num_samples, 0)` completions, first token (BOS) dropped.

    Raises:
        ValueError: if `batch_size` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Left-pad the contrast pair to a common length so positions line up.
    # NOTE(review): get_resid_pre passes no attention mask / padding_side, so the
    # pad tokens may be attended to by later positions — confirm this is intended.
    token_batch = model.to_tokens([prompt_add, prompt_sub], padding_side="left")
    add_activations = get_resid_pre(token_batch[0], layer)
    sub_activations = get_resid_pre(token_batch[1], layer)
    act_diff = add_activations - sub_activations

    if cut:
        # Unpadded lengths locate the region where the shorter prompt is padding.
        add_len = model.to_tokens(prompt_add).shape[1]
        sub_len = model.to_tokens(prompt_sub).shape[1]
        n_padded = token_batch.shape[1] - min(add_len, sub_len)
        # Zero the pad positions plus the shorter prompt's first (BOS) position.
        act_diff[:, :n_padded + 1, :] = 0

    hooks = [(f"blocks.{layer}.hook_resid_pre", partial(ave_hook, c=c, steering=act_diff))]

    all_results = []
    # Generate only what is still needed. The last batch may be smaller, so
    # e.g. num_samples=64 with batch_size=128 samples 64 sequences, not 128.
    while len(all_results) < num_samples:
        n_batch = min(batch_size, num_samples - len(all_results))
        res = hooked_generate([prompt] * n_batch, hooks, temperature=1.0, top_p=0.3)
        all_results.extend(model.to_string(res[:, 1:]))

    return all_results

### Love vs. Hate

In [6]:
prompt_add, prompt_sub = "Love", "Hate"
prompt = "I hate you because"

# Single steering scale for this demo.
for c in (5,):
  print(generate_steered(prompt, prompt_add, prompt_sub, c=c))

['I hate you because you are so good at making me feel bad about myself.\n\nI hate you because you are so good at making me feel', 'I hate you because you’re so cute. I love you because you’re so cute.\n\nThe <strong>2019 Honda', 'I hate you because you’re so cute.\n\nIf you’re a girl, you’ve probably heard this before.\n\nIf you', 'I hate you because you’re the best thing that ever happened to me. I hate you because you’re the worst thing that ever happened', 'I hate you because you are the only one who can make me feel this way. I hate you because you are the only one who can make', 'I hate you because you’re a great actor.\n\nYou’re the only person I’ve ever seen who can make me laugh,', 'I hate you because you are so beautiful. I hate you because you are so talented. I hate you because you are so perfect. I hate', 'I hate you because you are so beautiful and I am so ugly. I hate you because you are so rich and I am so poor. I', 'I hate you because you are the only one who can make 

### "Want to die" vs. "Want to stay alive"

In [7]:
prompt_add, prompt_sub = "Want to die", "Want to stay alive"
prompt = "I went up to my friend and said"

# Steer at layer 10 with a single scale.
for c in (10,):
  print(generate_steered(prompt, prompt_add, prompt_sub, c=c, layer=10))

['I went up to my friend and said, “You’re a moron.”\n“I’m not a moron,” she said. “I’m a', 'I went up to my friend and said, “I’m going to do a little research on the Internet.” He said, “Oh, you’re going', 'I went up to my friend and said, "I\'m going to make a big batch of chocolate chip cookies."\n"You can\'t do that,"', 'I went up to my friend and said, “I want to go to the dentist.” He said, “Why?” I said, “I want to go to', 'I went up to my friend and said, “I’m going to do a career in my own name.” He said, “You’re going to be', 'I went up to my friend and said, “You know, I’m going to get a dog.”\nWanting to get a dog was a long time', 'I went up to my friend and said, "You\'re the best, and I love you."\n\nI don\'t know how to say it. I', 'I went up to my friend and said, “I’m going to go get a new job.”\nI was in the midst of a 10-', 'I went up to my friend and said, “I’m going to go to the doctor.”Wanting to help me, she said, “I’ll', 'I went up to my friend and said, “You’

In [8]:
prompt_add, prompt_sub = "Intent to convert you to Christianity", "Intent to hurt you"
prompt = "I want to kill you because you're such a"

# Sweep the steering scale over 5, 10, 15, 20.
for c in range(1, 5):
  print(generate_steered(prompt, prompt_add, prompt_sub, c=5 * c))

["I want to kill you because you're such a great guy.\n\nThe whole world is yours.\n\nThere's no place like home.\n\nI love you, too", "I want to kill you because you're such a pain in the ass.\nI want to make you a part of my family.\nI want you to help me.", "I want to kill you because you're such a loser.\n\nI'm a Christian, but I'm not a member of the church.\n\nI'm a", "I want to kill you because you're such a pain in the ass.\nThe head of the &quot;Black Panther&quot; franchise, Chadwick Boseman, has", "I want to kill you because you're such a pain in the neck.\nI'm not sure I'm ready to take on the world.\nI'm", "I want to kill you because you're such a good guy\nI want to be your friend\nI want to be your friend\nI want to be your friend\nI", "I want to kill you because you're such a great guy!\n\nThis week's Bible Verse of the Week is from the book of Acts, chapter 2, verses", "I want to kill you because you're such a wonderful, wonderful, wonderful, wonderful, wonderful, wond

In [9]:
prompt_add, prompt_sub = "!Anger", "Calm"
prompt = "I think that you're"

# Steer at layer 16 with an effective scale of 5 * 20 = 100.
for c in (20,):
  print(generate_steered(prompt, prompt_add, prompt_sub, c=5 * c, layer=16))

["I think that you're right, but I'm not sure what the solution is. I'm going to ask the question on the wiki page", "I think that you're going to have to have a good understanding of the differences between the two. The most important difference is that the 12", "I think that you're right, but I'm not sure that it's a bad thing. I'm not sure that I would want", "I think that you're looking for the <code>find</code> command.\n\nThe <code>find</code> command is used to search for files and directories.", "I think that you're confusing the two. The first one is a form of measurement. The second one is a verb.\n\nI'm not", "I think that you're overthinking this.\n\nThe 'issue' is that the input to the <code>split</code> function is a string, and", 'I think that you\'re overthinking this. I think you\'re overthinking the whole "what\'s the right thing to do" thing.', "I think that you're right. I think the reason that the line is so important is because it is a key feature of the story. It

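## Generate and rate

For each steering scale, sample completions and score them with `multi_criterion_evaluation` on two criteria: the target topic and coherence. The sweep compares contrast-prompt activation steering with steering along SAE decoder directions.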

In [40]:
def gen_and_rate(prompt, prompt_add, prompt_sub, scales,
                 score_criterion,
                 coherence_criterion="Text is coherent, the grammar is correct.",
                 layer=6,
                 num_samples=64,
                 batch_size=128,
                 ):
    """Sweep activation-steering scales and return mean evaluation scores.

    For each scale, draws `num_samples` completions via `generate_steered`.
    They are scored with `multi_criterion_evaluation` on `score_criterion`
    and `coherence_criterion`.

    Args:
        prompt: text to complete.
        prompt_add: contrast prompt whose activations are added.
        prompt_sub: contrast prompt whose activations are subtracted.
        scales: iterable of steering scales (`c` in `generate_steered`).
        score_criterion: text of the topic criterion.
        coherence_criterion: text of the coherence criterion.
        layer: layer whose resid_pre is steered (previously fixed at 6).
        num_samples: completions per scale (previously fixed at 64).
        batch_size: generation batch size (previously fixed at 128).

    Returns:
        (all_scores, all_coherences): two lists with one mean per scale.
    """
    all_scores = []
    all_coherences = []
    for scale in tqdm(scales):
        texts = generate_steered(prompt=prompt,
                                 prompt_add=prompt_add,
                                 prompt_sub=prompt_sub,
                                 c=scale,
                                 layer=layer,
                                 num_samples=num_samples,
                                 batch_size=batch_size,
                                 )
        # One list of per-text results per criterion, in the order given.
        score_evals, coherence_evals = multi_criterion_evaluation(
            texts,
            criterions=[score_criterion, coherence_criterion],
            prompt=prompt,
        )
        scores = [e['score'] for e in score_evals]
        coherences = [e['score'] for e in coherence_evals]
        all_scores.append(sum(scores) / len(scores))
        all_coherences.append(sum(coherences) / len(coherences))

    return all_scores, all_coherences


In [71]:
def rate_sae(scales, steering_vector, eval_criterion,
             prompt=None,
             coherence_criterion="Text is coherent, the grammar is correct.",
             hook_point=None,
             n_samples=128,
             batch_size=64,
             max_new_tokens=25,
             ):
    """Sweep SAE-direction steering scales and return mean evaluation scores.

    For each scale, `generate` samples `n_samples` completions of `prompt`.
    During sampling, `patch_resid` applies `steering_vector` at `hook_point`.
    The completions are scored with `multi_criterion_evaluation` on
    `eval_criterion` and `coherence_criterion`.

    Args:
        scales: iterable of steering scales passed to `patch_resid` as `scale`.
        steering_vector: direction passed to `patch_resid` as `steering`.
        eval_criterion: text of the topic criterion.
        prompt: text to complete. Defaults to the notebook-level `prompt`
            global, which is what the original version silently read. Pass it
            explicitly so the result does not depend on hidden notebook state.
        coherence_criterion: text of the coherence criterion.
        hook_point: hook name to patch; defaults to the global `hp6`, looked up
            at call time.
        n_samples, batch_size, max_new_tokens: forwarded to `generate`.

    Returns:
        (avg_scores, avg_coherences): two lists with one mean per scale.
    """
    if prompt is None:
        # Backward compatibility with the original, which read the global.
        prompt = globals()["prompt"]
    if hook_point is None:
        hook_point = hp6
    # NOTE(review): the activation-steering sweep (gen_and_rate) patches
    # blocks.6.hook_resid_pre with 64 samples per scale by default, so the two
    # curves differ in hook point and sample size — confirm this is intended.

    avg_scores = []
    avg_coherences = []
    for scale in tqdm(scales):
        texts = generate(model,
                         hooks=[(hook_point, partial(patch_resid, steering=steering_vector, scale=scale))],
                         max_new_tokens=max_new_tokens,
                         prompt=prompt,
                         batch_size=batch_size,
                         n_samples=n_samples,
                         )

        score_evals, coherence_evals = multi_criterion_evaluation(
            texts,
            [eval_criterion, coherence_criterion],
            prompt=prompt,
            verbose=False,
        )
        scores = [e['score'] for e in score_evals]
        coherence_scores = [e['score'] for e in coherence_evals]
        avg_scores.append(sum(scores) / len(scores))
        avg_coherences.append(sum(coherence_scores) / len(coherence_scores))

    return avg_scores, avg_coherences

In [72]:
# Contrast pair, prompt and judge criterion for the wedding-steering comparison.
prompt_add, prompt_sub = (
    "I am always talking about weddings",
    "I am never talking about weddings",
)
prompt = "I went up to my friend and said"
crit = "Mentions weddings or anything related to weddings."
scales = [0, 1, 2, 4, 6, 8, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65]

name = "Weddings"
# Sum of the "wedding" (8406) and "broad wedding" (2378) SAE decoder rows.
sae_steer = sae6.W_dec[[8406, 2378]].sum(dim=0)


In [73]:
evals, coherences = gen_and_rate(
    prompt=prompt,
    prompt_add=prompt_add,
    prompt_sub=prompt_sub,
    scales=scales,
    score_criterion=crit,
)

100%|██████████| 18/18 [01:41<00:00,  5.63s/it]


In [74]:
sae_evals, sae_coherences = rate_sae(
    scales=scales,
    steering_vector=sae_steer,
    eval_criterion=crit,
)

100%|██████████| 18/18 [03:00<00:00, 10.02s/it]


In [75]:
# Per-scale averages for each steering method.
df1 = pd.DataFrame({
    'Scale': scales,
    'Score': evals,
    'Coherence': coherences,
})

df2 = pd.DataFrame({
    'Scale': scales,
    'Score': sae_evals,
    'Coherence': sae_coherences,
})

# One line per (method, metric). Colour encodes the metric (blue = score,
# red = coherence); dash encodes the method (dotted = activation steering,
# solid = SAE). Built in a loop instead of four copy-pasted add_trace calls.
fig = go.Figure()
for frame, method, dash in [(df1, 'Activation Steer', 'dot'), (df2, 'SAE', None)]:
    for metric, color in [('Score', 'blue'), ('Coherence', 'red')]:
        line_style = dict(color=color)
        if dash is not None:
            line_style['dash'] = dash
        fig.add_trace(go.Scatter(x=frame['Scale'], y=frame[metric], mode='lines',
                                 name=f'{method} {metric}', line=line_style))

fig.update_layout(
    title=name,
    xaxis_title="Scale",
    yaxis_title="Mean rating",  # each point is the mean judge score over the samples at that scale
    legend_title="Metrics and Types",
)

fig.show()