In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import multi_criterion_evaluation, compute_battles
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import json

# Disable autograd globally: this notebook only runs forward passes (generation, SAE lookups),
# so tracking gradients would just cost memory.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f584c346ad0>

In [2]:
# Prefer the GPU when one is visible to torch; otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Residual-stream hook point after block 6; used both as the SAE id and as the steering hook.
hp6 = "blocks.6.hook_resid_post"

# Load on CPU, then move to the working device chosen above.
sae6, _, _ = SAE.from_pretrained(
    release="gemma-2b-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id=hp6,  # won't always be a hook point
    device="cpu",
)
sae6 = sae6.to(device)

# Normalise the decoder in place (project helper) before rows are used as steering directions.
normalise_decoder(sae6)

In [4]:
# Decoder rows for hand-picked SAE latents (latent index -> concept):
#   10351 intelligence and genius | 1058 writing | 1062 anger
#   10138 London                  | 8406 wedding | 2378 broad wedding
intelligence, writing, anger, london, wedding, broad_wedding = (
    sae6.W_dec[idx] for idx in (10351, 1058, 1062, 10138, 8406, 2378)
)

In [5]:
scales = list(range(0, 120, 10))  # steering coefficients 0, 10, ..., 110
prompt = "I think"

# Pick the experiment by name; the steering vector and the judge criterion are derived
# from it so they cannot drift out of sync (previously three lines had to be toggled by hand).
name = "wedding"

STEERING_VECTORS = {
    "intelligence": intelligence,
    "writing": writing,
    "anger": anger,
    "london": london,
    "wedding": wedding,
    "broad_wedding": broad_wedding,
}

# Only these experiments have a judge criterion; any other name fails loudly with KeyError
# instead of silently pairing a vector with the wrong criterion.
EVAL_CRITERIA = {
    "writing": "Mentions writing or anything related to writing",
    "london": "Mentions London or anything related to London",
    "wedding": "Mentions wedding or anything related to wedding",
    "anger": "Text is angry or mentions anger/frustration or anything related to anger",
}

steer = STEERING_VECTORS[name]
eval_criterion = EVAL_CRITERIA[name]

coherence_criterion = "Text is coherent, the grammar is correct."

In [29]:
def rate(scale, steering_vector, json_path=None, *,
         max_new_tokens=25, batch_size=64, n_samples=128):
    """Generate steered completions of `prompt` and have the judge score them.

    Completions are sampled from `model` with `patch_resid` hooked at `hp6`
    (called with `steering=steering_vector, scale=scale`). Each completion is
    then scored against `eval_criterion` and `coherence_criterion` in one
    `multi_criterion_evaluation` call.

    Reads notebook globals: `model`, `hp6`, `prompt`, `eval_criterion`,
    `coherence_criterion`.

    Args:
        scale: steering coefficient forwarded to `patch_resid`.
        steering_vector: steering direction forwarded to `patch_resid`.
        json_path: optional output path. When given, the prompt, scale, texts
            and both score lists are written there as JSON (this argument used
            to be accepted and silently ignored). Values must be
            JSON-serializable.
        max_new_tokens, batch_size, n_samples: generation settings; defaults
            match the previously hard-coded values.

    Returns:
        (scores, coherence_scores, texts): per-completion criterion scores,
        per-completion coherence scores, and the generated texts.
    """
    texts = generate(
        model,
        hooks=[(hp6, partial(patch_resid, steering=steering_vector, scale=scale))],
        max_new_tokens=max_new_tokens,
        prompt=prompt,
        batch_size=batch_size,
        n_samples=n_samples,
    )

    # Results are unpacked in the same order as the criteria list passed in.
    eval_results, coherence_results = multi_criterion_evaluation(
        texts,
        [eval_criterion, coherence_criterion],
        prompt=prompt,
        verbose=False,
    )
    scores = [e['score'] for e in eval_results]
    coherence_scores = [e['score'] for e in coherence_results]

    if json_path is not None:
        with open(json_path, "w") as f:
            json.dump(
                {
                    "scale": scale,
                    "prompt": prompt,
                    "texts": texts,
                    "scores": scores,
                    "coherence_scores": coherence_scores,
                },
                f,
                indent=2,
            )

    return scores, coherence_scores, texts

In [30]:
# Sweep the steering scale, keeping the raw per-completion judge scores and texts per scale.
all_scores = []
all_coherence = []
all_texts = []

for scale in tqdm(scales):
    scores, coherence, texts = rate(scale, steer)
    all_scores.append(scores)
    all_coherence.append(coherence)
    all_texts.append(texts)

# Map judge scores onto [0, 1] via (x - 1) / 9 (1 -> 0, 10 -> 1).
# Averages are taken over the raw scores first and then mapped.
avg_scores = [(sum(per_scale) / len(per_scale) - 1) / 9 for per_scale in all_scores]
avg_coherence = [(sum(per_scale) / len(per_scale) - 1) / 9 for per_scale in all_coherence]
all_scores = [[(x - 1) / 9 for x in per_scale] for per_scale in all_scores]
all_coherence = [[(x - 1) / 9 for x in per_scale] for per_scale in all_coherence]

100%|██████████| 12/12 [03:54<00:00, 19.54s/it]


In [31]:
# Pairwise judge "battles" between the per-scale completions, once per criterion
# (coherence first, then eval), with identical settings.
# NOTE(review): the last run logged "Request timed out" from run_battle; how failed
# battles are represented in the returned list is not visible here — TODO confirm.
coherence_battle_results, eval_battle_results = (
    compute_battles(all_texts, criterion, prompt=prompt, n_iterations_per_model=128)
    for criterion in (coherence_criterion, eval_criterion)
)

Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.
Error in run_battle: Request timed out.


In [32]:
# Distribution of coherence-battle outcomes; winner values other than 'tie'/'A'/'B' are dropped.
simple_res = [
    r['winner'] for r in coherence_battle_results
    if r['winner'] in ('tie', 'A', 'B')
]

px.histogram(x=simple_res)


In [33]:
import math
from collections import defaultdict

def calculate_elo_ratings(comparisons, initial_rating=1500, k_factor=16):
    """Run sequential Elo updates over a list of pairwise comparisons.

    Args:
        comparisons: iterable of dicts with keys 'model_1', 'model_2' and
            'winner'. 'winner' is 'A' (model_1 won), 'B' (model_2 won) or
            'tie'. Comparisons with any other winner value (e.g. a failed
            battle) are skipped rather than being counted as a win for
            model_2. Both models are still registered at their current rating,
            so the set of returned models is unchanged.
        initial_rating: starting rating for every model.
        k_factor: Elo K-factor; scales the size of each rating update.

    Returns:
        List of final ratings ordered by ascending model ID. Only IDs that
        appear in `comparisons` are included, so list index i equals model
        ID i only when every ID 0..n-1 appears.
    """
    ratings = defaultdict(lambda: initial_rating)
    # Actual score for model_1 per outcome; model_2 gets the complement.
    outcome_score_1 = {"A": 1, "B": 0, "tie": 0.5}

    for comparison in comparisons:
        model_1 = comparison['model_1']
        model_2 = comparison['model_2']

        # Read both ratings before any skip so each model is registered either way.
        rating_1 = ratings[model_1]
        rating_2 = ratings[model_2]

        actual_score_1 = outcome_score_1.get(comparison['winner'])
        if actual_score_1 is None:
            continue  # unknown outcome: do not count it for either side
        actual_score_2 = 1 - actual_score_1

        expected_score_1 = 1 / (1 + math.pow(10, (rating_2 - rating_1) / 400))
        expected_score_2 = 1 - expected_score_1

        ratings[model_1] += k_factor * (actual_score_1 - expected_score_1)
        ratings[model_2] += k_factor * (actual_score_2 - expected_score_2)

    return [ratings[model_id] for model_id in sorted(ratings)]

# c_elo_ratings = calculate_elo_ratings(coherence_battle_results)
# c_elo_ratings = [(r-1000)/1000 for r in c_elo_ratings]
# print(c_elo_ratings)

# e_elo_ratings = calculate_elo_ratings(eval_battle_results)
# e_elo_ratings = [(r-1000)/1000 for r in e_elo_ratings]
# print(e_elo_ratings)


In [34]:
def get_bootstrap_result(comparisons, num_rounds=1000, seed=None):
    """Bootstrap Elo ratings: per-model median and 2.5/97.5 percentile bounds.

    Each round resamples `comparisons` with replacement, reruns
    `calculate_elo_ratings`, and records one rating per model ID
    0..num_models-1. A model absent from a resample keeps the initial
    rating 1500 for that round.

    Args:
        comparisons: non-empty list of dicts with integer 'model_1'/'model_2'
            IDs (plus 'winner'), as consumed by `calculate_elo_ratings`.
        num_rounds: number of bootstrap resamples.
        seed: optional seed for a private RNG. With the default (None), draws
            come from the global `random` state exactly as before.

    Returns:
        (median_ratings, confidence_intervals): a length-num_models array, and
        a (2, num_models) array whose row 0 is the 2.5th and row 1 the 97.5th
        percentile.

    Raises:
        ValueError: if `comparisons` is empty.
    """
    if not comparisons:
        raise ValueError("get_bootstrap_result needs at least one comparison")

    num_models = max(max(comp['model_1'], comp['model_2']) for comp in comparisons) + 1
    rng = random if seed is None else random.Random(seed)
    bootstrap_results = []

    for _ in tqdm(range(num_rounds), desc="Bootstrap Progress"):
        # Resample comparisons with replacement
        resampled_comparisons = rng.choices(comparisons, k=len(comparisons))

        elo_ratings = calculate_elo_ratings(resampled_comparisons)

        # calculate_elo_ratings returns ratings only for the IDs present in the sample,
        # sorted by ID. Place each at its own ID; the old positional lookup shifted every
        # later model by one slot whenever an ID was missing from the resample.
        present_ids = sorted(
            {c['model_1'] for c in resampled_comparisons}
            | {c['model_2'] for c in resampled_comparisons}
        )
        full_ratings = [1500] * num_models
        for model_id, rating in zip(present_ids, elo_ratings):
            full_ratings[model_id] = rating

        bootstrap_results.append(full_ratings)

    bootstrap_results = np.array(bootstrap_results)

    # Median and 95% percentile interval across bootstrap rounds
    median_ratings = np.median(bootstrap_results, axis=0)
    confidence_intervals = np.percentile(bootstrap_results, [2.5, 97.5], axis=0)

    return median_ratings, confidence_intervals

In [35]:
# Bootstrap both battle sets; coherence runs first, then eval (this order fixes the random draws).
(c_median_ratings, c_confidence_intervals), (e_median_ratings, e_confidence_intervals) = (
    get_bootstrap_result(battles)
    for battles in (coherence_battle_results, eval_battle_results)
)
print(c_median_ratings, c_confidence_intervals, sep="\n")

Bootstrap Progress:   5%|▌         | 51/1000 [00:00<00:01, 501.82it/s]

Bootstrap Progress: 100%|██████████| 1000/1000 [00:01<00:00, 515.79it/s]
Bootstrap Progress: 100%|██████████| 1000/1000 [00:01<00:00, 532.83it/s]

[1512.52376523 1507.39965374 1531.8871314  1525.8237115  1499.24831187
 1514.44808742 1560.78898887 1485.44452092 1489.19310003 1479.19365663
 1458.13639619 1435.86387072]
[[1443.46972168 1431.16331843 1462.19066996 1453.92060564 1428.66504535
  1441.61102709 1486.64489777 1411.44965344 1420.74279206 1408.09812394
  1386.17024086 1364.89373594]
 [1585.76598186 1582.26215235 1600.24978372 1596.08724469 1571.44867066
  1581.28031424 1631.10668749 1551.04755384 1562.69961696 1548.68295711
  1527.39891289 1507.47728138]]





In [36]:
def _p_win_vs(baseline, rating):
    """Elo-implied probability that a player rated `rating` beats one rated `baseline`."""
    return 1 / (1 + math.pow(10, (baseline - rating) / 400))


# Baseline is the first scale's median rating (scales[0] == 0).
c_baseline = c_median_ratings[0]
e_baseline = e_median_ratings[0]

# Win probability of each scale against the baseline
c_p_win = [_p_win_vs(c_baseline, r) for r in c_median_ratings]
e_p_win = [_p_win_vs(e_baseline, r) for r in e_median_ratings]

# confidence_intervals[0] is the 2.5th percentile and [1] the 97.5th. p_win grows with the
# rating, so "top" must come from row 1 (these were swapped before). The bounds are taken
# against the median baseline, so they are an approximation, not a bootstrapped p_win interval.
c_p_win_top = [_p_win_vs(c_baseline, r) for r in c_confidence_intervals[1]]
c_p_win_bottom = [_p_win_vs(c_baseline, r) for r in c_confidence_intervals[0]]
e_p_win_top = [_p_win_vs(e_baseline, r) for r in e_confidence_intervals[1]]
e_p_win_bottom = [_p_win_vs(e_baseline, r) for r in e_confidence_intervals[0]]

In [37]:
# Judge averages (mapped to [0, 1]) alongside the Elo win probabilities, per scale.
df = pd.DataFrame(
    {
        'Scale': scales,
        'Average Scores': avg_scores,
        'Average Coherence': avg_coherence,
        'Coherence Elo': c_p_win,
        'Eval Elo': e_p_win,
    }
)

# Every column except 'Scale' becomes a line, in insertion order.
fig = px.line(df, x='Scale', y=list(df.columns[1:]), title=name)
fig.update_layout(xaxis_title="Scale", yaxis_title="Values", legend_title="Metrics")
fig.show()

In [38]:
# Elo win probabilities per scale, plus their product as a combined steering/coherence metric.
df = pd.DataFrame({'Scale': scales, 'Coherence Elo': c_p_win, 'Eval Elo': e_p_win})
df["Score * Coherence"] = df['Coherence Elo'] * df['Eval Elo']

fig = px.line(df, x='Scale', y=df.columns.drop('Scale').tolist(), title=name)
fig.update_layout(xaxis_title="Scale", yaxis_title="Values", legend_title="Metrics")
fig.show()

In [39]:
# Elo win-probability curves with shaded bootstrap bands. Each band is coloured like its
# curve and shares its legend group (both bands used the same grey before, so they could
# not be told apart where they overlap).
df = pd.DataFrame({
    'Scale': scales,
    'Coherence Elo': c_p_win,
    'Eval Elo': e_p_win,
    "Score * Coherence": [a * b for a, b in zip(c_p_win, e_p_win)],
    'Coherence Upper': c_p_win_top,
    'Coherence Lower': c_p_win_bottom,
    'Eval Upper': e_p_win_top,
    'Eval Lower': e_p_win_bottom,
})

palette = px.colors.qualitative.Plotly


def _rgba(hex_color, alpha):
    """Convert a '#RRGGBB' colour to an 'rgba(r, g, b, alpha)' string."""
    red, green, blue = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))
    return f'rgba({red}, {green}, {blue}, {alpha})'


fig = go.Figure()

# One line per metric, with an explicit colour so the bands below can match it
for color, column in zip(palette, ['Coherence Elo', 'Eval Elo', "Score * Coherence"]):
    fig.add_trace(go.Scatter(
        x=df['Scale'],
        y=df[column],
        mode='lines',
        name=column,
        line=dict(color=color),
        legendgroup=column,
    ))

# Each band: an invisible upper edge, then an invisible lower edge filled up to it ('tonexty'
# fills to the trace added immediately before).
for color, group, upper_col, lower_col, band_name in [
    (palette[0], 'Coherence Elo', 'Coherence Upper', 'Coherence Lower', 'Coherence CI'),
    (palette[1], 'Eval Elo', 'Eval Upper', 'Eval Lower', 'Eval CI'),
]:
    fig.add_trace(go.Scatter(
        x=df['Scale'],
        y=df[upper_col],
        mode='lines',
        line=dict(width=0, color=color),
        legendgroup=group,
        showlegend=False,
    ))
    fig.add_trace(go.Scatter(
        x=df['Scale'],
        y=df[lower_col],
        mode='lines',
        line=dict(width=0, color=color),
        fill='tonexty',
        fillcolor=_rgba(color, 0.2),
        legendgroup=group,
        name=band_name,
    ))

fig.update_layout(
    title=name,
    xaxis_title="Scale",
    yaxis_title="Values",
    legend_title="Metrics"
)

# Show the plot
fig.show()