In [9]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import run_comparisons
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import json

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb8c4139600>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"

sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu'
)

sae6 = sae6.to(device)
normalise_decoder(sae6)

In [4]:
# what do we want?

# 50 steering vectors

# 2048 generations per steering vector
# this is 1024 pairs.

# filter pairs such that one of the texts is preffered in both coherence and score.



In [5]:
steering_vectors = [
    {"ft_id": 1062, "ft_desc": "Anger", "eval_criterion": "Text is angry or mentions anger/frustration or anything related to anger"},
    # {"ft_id": 10138, "ft_desc": "London", "eval_criterion": "Mentions London or anything related to London"},
    # {"ft_id": 8406, "ft_desc": "Wedding", "eval_criterion": "Mentions weddings or anything related to weddings"},
    # {"ft_id": 2378, "ft_desc": "Broad Wedding", "eval_criterion": "Mentions weddings or anything related to weddings"},
    {"ft_id": 1058, "ft_desc": "Writing", "eval_criterion": "Mentions writing or anything related to writing"},
    {"ft_id": 10298, "ft_desc": "Death", "eval_criterion": "Mentions death"},
    {"ft_id": 11309, "ft_desc": "Rome", "eval_criterion": "Mentions Rome or anything related to Rome"},
    {"ft_id": 2324, "ft_desc": "Dragons", "eval_criterion": "Mentions dragons or anything related to dragons"},
    {"ft_id": 10473, "ft_desc": "Knight", "eval_criterion": "Mentions knights or medieval themes"},
    {"ft_id": 15249, "ft_desc": "Hurt", "eval_criterion": "Mentions hurt in physical or emotional sense"},
    {"ft_id": 4458, "ft_desc": "Pain", "eval_criterion": "Mentions physical pain"},
    {"ft_id": 13056, "ft_desc": "Christian", "eval_criterion": "Text mentions christianity or is related to christianity"},
    {"ft_id": 7095, "ft_desc": "Suicide", "eval_criterion": "Text mentions suicide or is suicidal"},
    {"ft_id": 4303, "ft_desc": "Conspiracy", "eval_criterion": "Mentions conspiracy"},
]

coherence_criterion = "Text is coherent, the grammar is correct."

In [6]:
scales = [0, 20, 40, 60, 80]

In [7]:
# for each steering vector,
#   for each scale generate 512 texts, merge into big list (of tensors?)

In [10]:
@torch.no_grad()
def gen_and_rate(vec_index, scales, prompt="I think"):

    score_crit = steering_vectors[vec_index]['eval_criterion']

    steer = sae6.W_dec[steering_vectors[vec_index]['ft_id']]
    n_batches_per_scale = 2 ###
    batch_size = 64

    prompt_tokens = model.to_tokens(prompt, prepend_bos=True)
    prompt_batch = prompt_tokens.expand(batch_size, -1)

    gen_tokens = []
    for scale in scales:
        with model.hooks([(hp6, partial(patch_resid, steering=steer, scale=scale))]):
            for _ in range(n_batches_per_scale):
                batch_results = model.generate(
                    prompt_batch,
                    prepend_bos=True,
                    use_past_kv_cache=True,
                    max_new_tokens=29, # 32 - 3
                    verbose=False,
                    top_k=50,
                    top_p=0.3,
                )
                gen_tokens.append(batch_results)
    
    gen_tokens = torch.cat(gen_tokens, dim=0)

    # Shuffle gen_tokens along dimension 0
    num_samples = gen_tokens.shape[0]
    shuffled_indices = torch.randperm(num_samples)
    gen_tokens = gen_tokens[shuffled_indices]

    all_texts = model.to_string(gen_tokens)

    print('done generating. Now rating')

    # first half paired up with second half
    texts_a = all_texts[:num_samples//2]
    texts_b = all_texts[num_samples//2:]
    c_ratings = run_comparisons(text_pairs=list(zip(texts_a, texts_b)), criterion=coherence_criterion, prompt=prompt)
    s_ratings = run_comparisons(text_pairs=list(zip(texts_a, texts_b)), criterion=score_crit, prompt=prompt)

    print(c_ratings)
    print(s_ratings)

gen_and_rate(0, scales=[0, 40, 60]) ###



done generating. Now rating
[{'winner': 'A'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'A'}, {'winner': 'B'}, {'winner': 'B'}, {'winner': 'A'}, {'