In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import multi_criterion_evaluation
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f3ae0189360>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"

sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu'
)

sae6 = sae6.to(device)
normalise_decoder(sae6)

In [4]:
intelligence = sae6.W_dec[10351]   # intelligence and genius
writing = sae6.W_dec[1058]  # writing
anger = sae6.W_dec[1062]  # anger
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  # wedding
broad_wedding = sae6.W_dec[2378] # broad wedding

In [5]:
max_new_tokens = 200
scale = 80
# criterion = "Mentions wedding or anything related to weddings or marriage."
criterion = "Mentions London or anything related to London."
coherence_criterion = "Text is coherent, the grammar is correct, and makes sense."

texts = generate(model,
        hooks=[(hp6, partial(patch_resid, steering=london, scale=scale))],
        max_new_tokens=200,
        prompt="I think",
        batch_size=32,
        n_samples=256,
        )
texts

["I think 7 to 11 is a lot of tube flights but if you' airport then - then London is 4 hrs, Paris is 12 hours from Gatwick - then Paris Paris is 6.5 - 7 and Heathrow is is 5.5 and all these 3 big cities so London is in 5 5.5 & Paris is London.\n\nI had 4 days in London and went to Paris 2 days & 1 day in Paris Paris London from & New York. Flights flights at Heathrow 5.5 London & 6.1 London. London, London by a Tube Tube - 10.5. Paris & & London UK, London 4.5. New & New so 14 days. 12 4 London London London Central. London, &. New New. London &. London London London London. London London & London. London.\n\n& London and London London London London London. London. London",
 "I think I would feel differently about this song today, considering I was 16 at the time. This was the city of punk, after the riots had been on all night. And I's. The London Eye, 1999. A summer holiday from school. I', where at the end of my first year at university,. 1998. A year and a half after the bombing. I

In [7]:
# cut <eos> and <pad> from end of texts. hacky.

eos_idxs = [t.find("<eos>") for t in texts]
eos_idxs = [i if i != -1 else None for i in eos_idxs]
texts = [t[:i] for t, i in zip(texts, eos_idxs)]
pad_idxs = [t.find("<pad>") for t in texts]
pad_idxs = [i if i != -1 else None for i in pad_idxs]
texts = [t[:i] for t, i in zip(texts, pad_idxs)]

In [8]:
tokens = [model.to_tokens(t, prepend_bos=False)[0] for t in texts]

# filter short texts
tokens = [t for t in tokens if t.shape[0] > max_new_tokens]
print(len(tokens))

85


In [9]:
def get_chunks(tokens, chunk_size=25, overlap=10):
    """Returns a list of lists of chunks with specified overlap"""
    chunked = []
    for full_t in tokens:
        chunks = []
        for i in range(0, full_t.shape[0], chunk_size - overlap):
            chunk = full_t[i:i+chunk_size]
            if chunk.shape[0] == chunk_size:
                chunks.append(model.to_string(chunk))
        if chunks:  # Only add non-empty lists of chunks
            chunked.append(chunks)
    return chunked

chunks = get_chunks(tokens, chunk_size=25, overlap=10)

In [10]:
print(len([c for chunk in chunks for c in chunk]))

1020


### Warning: this code calls gpt-4o a bunch of times. Could be expensive.

In [11]:
def rate_chunks(chunks: list[list[str]], criteria: list[str]):
    # flatten chunks into single list
    flat_chunks = [c for chunk in chunks for c in chunk]
    # rate
    scores = multi_criterion_evaluation(flat_chunks, criteria, prompt="", filter_errors=False,
                                        # model="gpt-4o",
                                        )

    integer_scores = [[item['score'] for item in criterion_scores] for criterion_scores in scores]
    # split scores back into chunks
    chunk_scores = []
    chunk_sizes = [len(chunk) for chunk in chunks]
    start = 0
    for size in chunk_sizes:
        chunk_score = []
        for criterion_scores in integer_scores:
            chunk_score.append(criterion_scores[start:start+size])
        chunk_scores.append(chunk_score)
        start += size
    
    # filter 0 scores
    filtered_chunk_scores = []
    for sample in chunk_scores:
        # if 0 not in any of the criteria scores, keep the sample
        if not any([0 in criterion for criterion in sample]):
            filtered_chunk_scores.append(sample)
        else:
            print("skipping", sample)
    
    # transpose
    filtered_chunk_scores = list(map(list, zip(*filtered_chunk_scores)))

    print(filtered_chunk_scores)

    return filtered_chunk_scores


coherence_scores, eval_scores = rate_chunks(chunks, ["Text is coherent, the grammar is correct.", criterion])


JSON decode error: Expecting value: line 1 column 101 (char 100)
Text causing the error:
You score texts generated by a language model based on the following criterion: 
Text is coherent, the grammar is correct..
You provide a score from 1 to 10. The language model was given a prompt and generated the following text. Evaluate the text based on the criterion. Output format should be JSON with the following fields: "score" (int)

 12 <sub>28</sub> 3820 6243.

If you live in London
skipping [[9, 7, 8, 0, 2, 2, 2, 2, 1, 7, 2, 1], [1, 1, 1, 10, 8, 10, 10, 10, 1, 10, 10, 10]]
skipping [[7, 8, 7, 8, 2, 2, 2, 2, 2, 2, 2, 2], [1, 0, 1, 1, 10, 10, 10, 8, 10, 10, 10, 10]]
[[[8, 2, 2, 3, 8, 2, 7, 2, 2, 2, 7, 2], [8, 7, 2, 2, 8, 2, 2, 7, 2, 2, 2, 1], [8, 8, 8, 8, 8, 7, 8, 8, 8, 8, 8, 8], [8, 3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2], [7, 7, 7, 3, 3, 2, 2, 2, 2, 2, 8, 2], [7, 2, 2, 2, 2, 2, 2, 8, 2, 2, 8, 2], [8, 7, 7, 8, 7, 3, 7, 2, 8, 2, 2, 2], [7, 2, 7, 2, 2, 2, 2, 2, 8, 2, 2, 2], [8, 8, 8, 8, 2, 7, 7, 2,

In [12]:
coherence_averages = [sum(c[i] for c in coherence_scores)/len(coherence_scores) for i in range(len(coherence_scores[0]))]
eval_averages = [sum(c[i] for c in eval_scores)/len(eval_scores) for i in range(len(eval_scores[0]))]

In [13]:
eval_scores

[[10, 1, 1, 2, 8, 1, 10, 10, 10, 10, 1, 10],
 [8, 8, 10, 10, 10, 7, 10, 10, 10, 10, 10, 8],
 [1, 1, 1, 8, 1, 1, 1, 10, 1, 8, 8, 8],
 [10, 8, 2, 1, 1, 7, 1, 10, 10, 10, 10, 10],
 [1, 1, 1, 8, 8, 1, 10, 10, 10, 10, 10, 10],
 [1, 1, 1, 10, 10, 10, 10, 10, 10, 1, 10, 10],
 [10, 1, 8, 8, 1, 10, 10, 10, 1, 10, 10, 10],
 [1, 10, 10, 1, 10, 1, 10, 10, 10, 10, 10, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 8],
 [8, 10, 1, 1, 1, 2, 1, 10, 8, 8, 8, 1],
 [1, 1, 8, 2, 10, 10, 10, 8, 10, 10, 1, 10],
 [2, 2, 8, 8, 1, 8, 1, 1, 1, 10, 10, 10],
 [1, 10, 10, 7, 10, 2, 8, 1, 1, 1, 1, 1],
 [10, 10, 10, 10, 10, 10, 10, 10, 1, 10, 10, 10],
 [1, 1, 1, 8, 8, 10, 10, 8, 2, 10, 1, 10],
 [1, 1, 8, 8, 1, 2, 7, 10, 10, 10, 10, 10],
 [8, 1, 8, 10, 10, 10, 10, 10, 1, 10, 10, 10],
 [8, 1, 1, 2, 2, 1, 10, 10, 10, 1, 10, 10],
 [10, 1, 1, 1, 8, 10, 10, 10, 8, 10, 10, 1],
 [1, 1, 8, 2, 8, 1, 10, 10, 10, 10, 10, 10],
 [8, 1, 1, 1, 8, 1, 1, 1, 10, 8, 10, 1],
 [1, 1, 1, 1, 10, 10, 10, 10, 10, 10, 10, 10],
 [1, 2, 1, 1, 1, 10, 10

In [14]:
coherence_scores

[[8, 2, 2, 3, 8, 2, 7, 2, 2, 2, 7, 2],
 [8, 7, 2, 2, 8, 2, 2, 7, 2, 2, 2, 1],
 [8, 8, 8, 8, 8, 7, 8, 8, 8, 8, 8, 8],
 [8, 3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2],
 [7, 7, 7, 3, 3, 2, 2, 2, 2, 2, 8, 2],
 [7, 2, 2, 2, 2, 2, 2, 8, 2, 2, 8, 2],
 [8, 7, 7, 8, 7, 3, 7, 2, 8, 2, 2, 2],
 [7, 2, 7, 2, 2, 2, 2, 2, 8, 2, 2, 2],
 [8, 8, 8, 8, 2, 7, 7, 2, 2, 7, 2, 2],
 [8, 8, 7, 7, 7, 7, 8, 7, 2, 3, 2, 2],
 [8, 8, 7, 2, 2, 2, 2, 2, 8, 2, 2, 2],
 [8, 8, 8, 8, 8, 7, 8, 7, 7, 7, 2, 2],
 [7, 3, 2, 2, 2, 2, 2, 2, 2, 8, 1, 1],
 [7, 2, 2, 8, 2, 2, 2, 2, 1, 2, 8, 2],
 [8, 3, 7, 7, 8, 8, 8, 7, 3, 2, 2, 1],
 [8, 2, 8, 3, 2, 1, 2, 2, 2, 1, 2, 2],
 [3, 2, 7, 2, 2, 2, 2, 2, 2, 2, 2, 2],
 [8, 7, 3, 7, 2, 2, 2, 2, 2, 2, 2, 2],
 [8, 8, 2, 8, 8, 8, 7, 7, 2, 2, 2, 2],
 [8, 3, 7, 3, 7, 2, 7, 2, 2, 2, 2, 2],
 [8, 8, 3, 7, 8, 2, 2, 2, 8, 2, 2, 8],
 [8, 3, 7, 7, 7, 3, 7, 2, 2, 2, 2, 2],
 [8, 2, 7, 3, 2, 2, 2, 2, 2, 8, 8, 2],
 [8, 8, 7, 7, 8, 2, 7, 7, 2, 2, 7, 2],
 [8, 8, 8, 8, 8, 8, 8, 8, 3, 2, 3, 7],
 [8, 7, 8, 3, 2, 2, 2, 2,

In [15]:
# Create a DataFrame from the data
df = pd.DataFrame({
    'index': range(len(coherence_averages)),
    'Coherence Averages': coherence_averages,
    'Evaluation Averages': eval_averages
})

# Melt the DataFrame to create a "long" format suitable for plotting
df_melted = df.melt(id_vars=['index'], var_name='Metric', value_name='Score')

# Create the line plot
fig = px.line(df_melted, x='index', y='Score', color='Metric',
              title='Coherence and Evaluation Averages',
              labels={'index': 'Index', 'Score': 'Average Score'},
              line_shape='linear', render_mode='svg')

# Update layout for better readability
fig.update_layout(
    xaxis_title='Index',
    yaxis_title='Average Score',
    legend_title='Metric',
    font=dict(size=12)
)

# Show the plot
fig.show()