In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import run_comparisons
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import json

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f4491cb6c20>

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load Gemma-2B as a TransformerLens HookedTransformer placed on that device.
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"

# Pretrained layer-6 residual-stream SAE from the "gemma-2b-res-jb" release
# (other releases are listed in sae_lens/pretrained_saes.yaml). Here the sae_id
# coincides with the hook name, but that won't hold for every release.
# The SAE is loaded on CPU first and only then moved to `device`.
sae6, _, _ = SAE.from_pretrained(
    release="gemma-2b-res-jb",
    sae_id=hp6,
    device="cpu",
)
sae6 = sae6.to(device)

# The return value is discarded, so normalise_decoder presumably rescales
# sae6's decoder in place -- confirm in steering.utils.
normalise_decoder(sae6)

In [4]:
# what do we want?

# 50 steering vectors (21 are currently defined in `steering_vectors` below)

# 2048 generations per steering vector
# this is 1024 pairs.

# filter pairs so that the same text is preferred on both coherence and score.



In [5]:
# Each entry: SAE feature id (row of sae6.W_dec used as the steering direction),
# a short human label, and the criterion the judge uses to score steered text.
steering_vectors = [
    {"ft_id": ft_id, "ft_desc": desc, "eval_criterion": criterion}
    for ft_id, desc, criterion in (
        (1062, "Anger", "Text is angry or mentions anger/frustration or anything related to anger"),
        (10138, "London", "Mentions London or anything related to London"),
        (8406, "Wedding", "Mentions weddings or anything related to weddings"),
        (2378, "Broad Wedding", "Mentions weddings or anything related to weddings"),
        (1058, "Writing", "Mentions writing or anything related to writing"),
        (10298, "Death", "Mentions death"),
        (11309, "Rome", "Mentions Rome or anything related to Rome"),
        (2324, "Dragons", "Mentions dragons or anything related to dragons"),
        (10473, "Knight", "Mentions knights or medieval themes"),
        (15249, "Hurt", "Mentions hurt in physical or emotional sense"),
        (4458, "Pain", "Mentions physical pain"),
        (13056, "Christian", "Text mentions christianity or is related to christianity"),
        (7095, "Suicide", "Text mentions suicide or is suicidal"),
        (4303, "Conspiracy", "Mentions conspiracy"),

        (6, "New Orleans", "Mentions New Orleans or anything related to New Orleans"),
        (7, "Lost", "Mentions losing things"),
        (12, "Harry Potter", "Mentions Harry Potter or anything related to Harry Potter"),
        (14, "Toys", "Mentions toys"),
        (21, "Back then", "Mentions the past, e.g. back in the early days."),
        (31, "Office devices", "Mentions printer or router or generator or transducer etc."),
        (32, "Family Event", "Descriptions of family-friendly events."),
    )
]

# Criterion shared by every vector for the coherence judgement.
coherence_criterion = "Text is coherent, the grammar is correct."

In [6]:
# Steering strengths passed as `scale` to patch_resid; each scale contributes
# the same number of generations to the comparison pool.
scales = [
    50,
    55,
]

In [7]:
# for each steering vector,
#   for each scale generate 1024 texts (16 batches of 64) and concatenate them into one token tensor

In [8]:
@torch.no_grad()
def gen_and_rate(vec_index, scales, prompt="I think"):

    score_crit = steering_vectors[vec_index]['eval_criterion']

    steer = sae6.W_dec[steering_vectors[vec_index]['ft_id']]
    n_batches_per_scale = 16
    batch_size = 64

    prompt_tokens = model.to_tokens(prompt, prepend_bos=True)
    prompt_batch = prompt_tokens.expand(batch_size, -1)

    gen_tokens = []
    for scale in scales:
        with model.hooks([(hp6, partial(patch_resid, steering=steer, scale=scale))]):
            for _ in range(n_batches_per_scale):
                batch_results = model.generate(
                    prompt_batch,
                    prepend_bos=True,
                    use_past_kv_cache=True,
                    max_new_tokens=29, # 32 - 3
                    verbose=False,
                    top_k=50,
                    top_p=0.3,
                )
                gen_tokens.append(batch_results)
    
    gen_tokens = torch.cat(gen_tokens, dim=0)

    # Shuffle gen_tokens along dimension 0
    num_samples = gen_tokens.shape[0]
    shuffled_indices = torch.randperm(num_samples)
    gen_tokens = gen_tokens[shuffled_indices]

    all_texts = model.to_string(gen_tokens)

    # first half paired up with second half
    texts_a = all_texts[:num_samples//2]
    texts_b = all_texts[num_samples//2:]
    c_ratings = run_comparisons(text_pairs=list(zip(texts_a, texts_b)), criterion=coherence_criterion, prompt=prompt)
    s_ratings = run_comparisons(text_pairs=list(zip(texts_a, texts_b)), criterion=score_crit, prompt=prompt)

    def convert_to_binary(rating):
        if isinstance(rating, dict):
            if 'error' in rating or rating.get('winner') == 'tie':
                return None
            return 1 if rating['winner'] == 'A' else 0
        return None

    c_binary = [convert_to_binary(r) for r in c_ratings]
    s_binary = [convert_to_binary(r) for r in s_ratings]
    
    valid_indices = [i for i, (c, s) in enumerate(zip(c_binary, s_binary)) 
                     if c is not None and s is not None and c == s]
    
    texts_a = [texts_a[i] for i in valid_indices]
    texts_b = [texts_b[i] for i in valid_indices]
    valid_tokens = gen_tokens[valid_indices + [i + num_samples//2 for i in valid_indices]]
    
    c_ratings = [c_binary[i] for i in valid_indices]
    s_ratings = [s_binary[i] for i in valid_indices]

    return valid_tokens, texts_a, texts_b, c_ratings


# valid_tokens, _, _, ratings = gen_and_rate(0, scales=[0, 40, 60]) ###


In [9]:
def get_data(scales, prompt="I think"):
    """Build a shuffled (winner, loser) preference dataset across all steering vectors.

    For every entry of ``steering_vectors`` (with a tqdm progress bar), this calls
    ``gen_and_rate`` and splits each kept pair into its preferred ("win") and
    rejected ("loss") token sequence. A rating of 1 means text A won; 0 means
    text B won. Vectors that yield no kept pairs are skipped.

    Args:
        scales: steering scales forwarded to ``gen_and_rate``.
        prompt: prompt forwarded to ``gen_and_rate``.

    Returns:
        A tuple ``(all_wins, all_losses, vector_idxs)``:
        - ``all_wins`` / ``all_losses`` are token tensors with one row per kept
          pair, aligned row-for-row.
        - ``vector_idxs`` is a list of ints giving the index into
          ``steering_vectors`` for each row.
        All three are shuffled with the same permutation.

    Raises:
        ValueError: if no steering vector produced any kept pair.
    """
    all_wins = []
    all_losses = []
    vector_idxs = []

    for vec_idx in tqdm(range(len(steering_vectors))):
        valid_tokens, _, _, ratings = gen_and_rate(vec_idx, scales=scales, prompt=prompt)
        n_pairs = len(ratings)
        if n_pairs == 0:
            # Stacking zero rows would raise; skip this vector instead of
            # aborting the whole (long) run.
            continue

        # gen_and_rate returns all A sequences first, then the matching B sequences.
        tokens_a = valid_tokens[:n_pairs]
        tokens_b = valid_tokens[n_pairs:2 * n_pairs]
        a_won = torch.tensor(ratings, dtype=torch.bool, device=valid_tokens.device).unsqueeze(-1)

        all_wins.append(torch.where(a_won, tokens_a, tokens_b))
        all_losses.append(torch.where(a_won, tokens_b, tokens_a))
        vector_idxs.extend([vec_idx] * n_pairs)

    if not all_wins:
        raise ValueError(
            "No steering vector produced any pair with agreeing coherence and score ratings."
        )

    all_wins = torch.cat(all_wins, dim=0)
    all_losses = torch.cat(all_losses, dim=0)

    # One shared permutation keeps wins, losses and vector_idxs aligned.
    perm = torch.randperm(all_wins.shape[0])
    all_wins = all_wins[perm]
    all_losses = all_losses[perm]
    vector_idxs = [vector_idxs[j] for j in perm.tolist()]

    return all_wins, all_losses, vector_idxs

    
    
# Build the full preference dataset: one gen_and_rate call per steering vector.
# The recorded run over all 21 vectors took about 37.5 minutes.
wins, losses, vector_idxs = get_data(scales=scales, prompt="I think")
    

100%|██████████| 21/21 [37:30<00:00, 107.18s/it]


In [10]:
# Output directory for the preference dataset (created if missing). The name
# `dir` is kept because the next cell also writes into it.
dir = "comparison_data"
os.makedirs(dir, exist_ok=True)

# Row k of wins.pt was preferred over row k of losses.pt.
torch.save(wins, os.path.join(dir, "wins.pt"))
torch.save(losses, os.path.join(dir, "losses.pt"))

# Index into steering_vectors for each row, stored as a plain JSON list.
with open(os.path.join(dir, "vector_idxs.json"), "w") as f:
    f.write(json.dumps(vector_idxs))


In [11]:
# Save the steering-vector specs next to the data; vector_idxs.json holds
# indices into this list.
with open(os.path.join(dir, "steering_vectors.json"), "w") as f:
    f.write(json.dumps(steering_vectors))