In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import multi_criterion_evaluation
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f537ede6020>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"

sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu'
)

sae6 = sae6.to(device)
normalise_decoder(sae6)

In [4]:
intelligence = sae6.W_dec[10351]   # intelligence and genius
writing = sae6.W_dec[1058]  # writing
anger = sae6.W_dec[1062]  # anger
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  # wedding
broad_wedding = sae6.W_dec[2378] # broad wedding

In [24]:
max_new_tokens = 200
scale = 80
# criterion = "Mentions wedding or anything related to weddings or marriage."
criterion = "Mentions London or anything related to London."
# criterion = "Text mentions anger/frustration or is angry."
coherence_criterion = "Text is coherent, the grammar is correct, and makes sense."

steer = london
steer_name = "london"
steer_id = 10138

texts = generate(model,
        hooks=[(hp6, partial(patch_resid, steering=steer, scale=scale))],
        max_new_tokens=200,
        prompt="I think",
        batch_size=32,
        n_samples=256,
        )
texts

["I think this will be released in September or early October. A new film.  I think they are trying to make to London so this year and get it on a Friday afternoon.  It will be 6/10ish..  They have said it will be on that night.\n\nNot a date but in London on Thursday 21/10. Friday to London then London again.  I think London with will do London London London London.\n\n.  I always have London & a couple of days.  I love the London. but London and London. . London. London.  London. London, London London London, London London so I' London'. London London London London London. London.  .  . London 3 London London London London London London. London. London London London\n\nLondon London London London London London London London London London London London London London\n\n.  LONDON.  .  \n\nLondon.\n\nFriday .  London. .",
 'I think it could be a few days now, but for next year. Then I suppose. On the 19th and 24.\n\nSo, where will it be? This is on the map.\n\nThis’ll be the second time

In [25]:
# cut <eos> and <pad> from end of texts. hacky.

def trim(texts):
    eos_idxs = [t.find("<eos>") for t in texts]
    eos_idxs = [i if i != -1 else None for i in eos_idxs]
    texts = [t[:i] for t, i in zip(texts, eos_idxs)]
    pad_idxs = [t.find("<pad>") for t in texts]
    pad_idxs = [i if i != -1 else None for i in pad_idxs]
    texts = [t[:i] for t, i in zip(texts, pad_idxs)]
    return texts

texts = trim(texts)

In [26]:
tokens = [model.to_tokens(t, prepend_bos=False)[0] for t in texts]

# filter short texts
tokens = [t for t in tokens if t.shape[0] > max_new_tokens]
print(len(tokens))

80


In [27]:
def get_chunks(tokens, chunk_size=25, overlap=10):
    """Returns a list of lists of chunks with specified overlap"""
    chunked = []
    for full_t in tokens:
        chunks = []
        for i in range(0, full_t.shape[0], chunk_size - overlap):
            chunk = full_t[i:i+chunk_size]
            if chunk.shape[0] == chunk_size:
                chunks.append(model.to_string(chunk))
        if chunks:  # Only add non-empty lists of chunks
            chunked.append(chunks)
    return chunked

chunks = get_chunks(tokens, chunk_size=25, overlap=10)

In [28]:
print(len([c for chunk in chunks for c in chunk]))

960


In [29]:
def rate_chunks(chunks: list[list[str]], criteria: list[str]):
    # flatten chunks into single list
    flat_chunks = [c for chunk in chunks for c in chunk]
    # rate
    scores = multi_criterion_evaluation(flat_chunks, criteria, prompt="", filter_errors=False,
                                        # model="gpt-4o",
                                        )

    integer_scores = [[item['score'] for item in criterion_scores] for criterion_scores in scores]
    # split scores back into chunks
    chunk_scores = []
    chunk_sizes = [len(chunk) for chunk in chunks]
    start = 0
    for size in chunk_sizes:
        chunk_score = []
        for criterion_scores in integer_scores:
            chunk_score.append(criterion_scores[start:start+size])
        chunk_scores.append(chunk_score)
        start += size
    
    # filter 0 scores
    filtered_chunk_scores = []
    for sample in chunk_scores:
        # if 0 not in any of the criteria scores, keep the sample
        if not any([0 in criterion for criterion in sample]):
            filtered_chunk_scores.append(sample)
        else:
            print("skipping", sample)
    
    # transpose
    filtered_chunk_scores = list(map(list, zip(*filtered_chunk_scores)))

    print(filtered_chunk_scores)

    return filtered_chunk_scores


coherence_scores, eval_scores = rate_chunks(chunks, ["Text is coherent, the grammar is correct.", criterion])


JSON decode error: Expecting value: line 2 column 101 (char 101)
Text causing the error:
You score texts generated by a language model based on the following criterion: 
Text is coherent, the grammar is correct..
You provide a score from 1 to 10. The language model was given a prompt and generated the following text. Evaluate the text based on the criterion. Output format should be JSON with the following fields: "score" (int)

 I was born a few months after the Second World War, and moved to the UK aged seven just before. My mother,
skipping [[8, 0, 8, 8, 8, 7, 3, 7, 8, 7, 7, 3], [1, 8, 1, 1, 1, 1, 2, 1, 1, 2, 8, 7]]
[[[8, 8, 8, 3, 8, 2, 2, 2, 2, 2, 2, 1], [7, 8, 7, 2, 8, 8, 8, 8, 7, 6, 7, 2], [8, 8, 8, 2, 2, 3, 2, 2, 2, 2, 2, 8], [8, 2, 2, 2, 8, 2, 2, 2, 2, 8, 2, 2], [8, 7, 7, 8, 8, 8, 8, 2, 8, 3, 2, 2], [2, 8, 7, 2, 2, 2, 2, 3, 2, 8, 2, 2], [8, 3, 8, 3, 2, 7, 2, 1, 2, 1, 2, 2], [7, 8, 3, 7, 8, 2, 8, 2, 7, 2, 8, 8], [7, 7, 2, 7, 8, 7, 2, 2, 2, 1, 8, 2], [8, 4, 2, 7, 8, 8, 3, 2, 2, 2,

In [30]:
coherence_averages = [sum(c[i] for c in coherence_scores)/len(coherence_scores) for i in range(len(coherence_scores[0]))]
eval_averages = [sum(c[i] for c in eval_scores)/len(eval_scores) for i in range(len(eval_scores[0]))]

In [31]:
# Create a DataFrame from the data
df = pd.DataFrame({
    'index': range(len(coherence_averages)),
    'Coherence Averages': coherence_averages,
    'Evaluation Averages': eval_averages
})

# Melt the DataFrame to create a "long" format suitable for plotting
df_melted = df.melt(id_vars=['index'], var_name='Metric', value_name='Score')

# Create the line plot
fig = px.line(df_melted, x='index', y='Score', color='Metric',
              title='Coherence and Evaluation Averages',
              labels={'index': 'Index', 'Score': 'Average Score'},
              line_shape='linear', render_mode='svg')

# Update layout for better readability
fig.update_layout(
    xaxis_title='Index',
    yaxis_title='Average Score',
    legend_title='Metric',
    font=dict(size=12)
)

# Show the plot
fig.show()

In [32]:
# implement clamping

def clamp_hook(resid, hook, steering, encoder, bias=0, scale=1):
    activation = einops.einsum(resid, encoder, "batch tok dm, dm -> batch tok")
    activation = F.relu(activation + bias)
    # print(activation.shape)

    resid += steering * -(activation - scale).unsqueeze(-1)
    return resid


In [33]:
clamp_texts = generate(model,
        hooks=[(hp6, partial(clamp_hook, steering=london, encoder=sae6.W_enc[:, steer_id], bias=sae6.b_enc[steer_id], scale=scale))],
        max_new_tokens=200,
        prompt="I think",
        batch_size=32,
        n_samples=256,
        )

In [34]:
clamp_texts

['I think everyone has their own sense of style. If I were to try I’s be to style, but the style is so interesting. People are the eyes of the UK. London is one of the hottest fashion and fashion city. London is the city of the World. The London is the city is one of the UK capital City.  . The London is also the world’s famous Fashion, London is so interesting, London has the best stores, is a busy shopping city, the London is a fashionable the fashion city, with London as the International Financial Centre, London is like London London’s London London London London London . London England’s, as London London. London London London London London London London London, London London London london London London London London London London London London London London London London London London London London London London London London London London London London London, London London London London London London . . London London London London London London London London London \n\nLondon 

In [35]:
# TODO: rate the clamped texts.

calmp_texts = trim(clamp_texts)

clamp_tokens = [model.to_tokens(t, prepend_bos=False)[0] for t in clamp_texts]

# filter short texts
clamp_tokens = [t for t in clamp_tokens if t.shape[0] > max_new_tokens]
print(len(clamp_tokens))

clamp_chunks = get_chunks(clamp_tokens, chunk_size=25, overlap=10)

115


In [36]:
print(len([c for chunk in clamp_chunks for c in chunk]))

1380


In [37]:
c_coherence_scores, c_eval_scores = rate_chunks(clamp_chunks, ["Text is coherent, the grammar is correct.", criterion])

c_coherence_averages = [sum(c[i] for c in c_coherence_scores)/len(c_coherence_scores) for i in range(len(c_coherence_scores[0]))]
c_eval_averages = [sum(c[i] for c in c_eval_scores)/len(c_eval_scores) for i in range(len(c_eval_scores[0]))]



JSON decode error: Expecting value: line 1 column 101 (char 100)
Text causing the error:
You score texts generated by a language model based on the following criterion: 
Text is coherent, the grammar is correct..
You provide a score from 1 to 10. The language model was given a prompt and generated the following text. Evaluate the text based on the criterion. Output format should be JSON with the following fields: "score" (int)

 2012 London London 2012 2012. 2020 2
JSON decode error: Expecting value: line 1 column 101 (char 100)
Text causing the error:
You score texts generated by a language model based on the following criterion: 
Text is coherent, the grammar is correct..
You provide a score from 1 to 10. The language model was given a prompt and generated the following text. Evaluate the text based on the criterion. Output format should be JSON with the following fields: "score" (int)

  This was the first time the Mayor had ever done “London London London”. What would be the capita

In [38]:

# Create a DataFrame from the data
df = pd.DataFrame({
    'index': range(len(c_coherence_averages)),
    'Coherence Averages': c_coherence_averages,
    'Evaluation Averages': c_eval_averages
})

# Melt the DataFrame to create a "long" format suitable for plotting
df_melted = df.melt(id_vars=['index'], var_name='Metric', value_name='Score')

# Create the line plot
fig = px.line(df_melted, x='index', y='Score', color='Metric',
              title=f'Clamp {steer_name}',
              labels={'index': 'Index', 'Score': 'Average Score'},
              line_shape='linear', render_mode='svg')

# Update layout for better readability
fig.update_layout(
    xaxis_title='Index',
    yaxis_title='Average Score',
    legend_title='Metric',
    font=dict(size=12)
)

# Show the plot
fig.show()

In [39]:
# Create DataFrames for both datasets
df1 = pd.DataFrame({
    'index': range(len(coherence_averages)),
    'Coherence': coherence_averages,
    'Evaluation': eval_averages,
    'Type': 'Add'
})

df2 = pd.DataFrame({
    'index': range(len(c_coherence_averages)),
    'Coherence': c_coherence_averages,
    'Evaluation': c_eval_averages,
    'Type': 'Clamp'
})

# Concatenate the DataFrames
df_combined = pd.concat([df1, df2], ignore_index=True)

# Melt the DataFrame to create a "long" format suitable for plotting
df_melted = df_combined.melt(id_vars=['index', 'Type'], var_name='Metric', value_name='Score')

# Create the line plot
fig = px.line(df_melted, x='index', y='Score', color='Metric', line_dash='Type',
              title='London: Add and Clamp Coherence and Evaluation',
              labels={'index': 'Index', 'Score': 'Average Score'},
              line_shape='linear', render_mode='svg')

# Update layout for better readability
fig.update_layout(
    xaxis_title='Index',
    yaxis_title='Average Score',
    legend_title='Metric and Type',
    font=dict(size=12)
)

# Show the plot
fig.show()