In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import multi_criterion_evaluation, compute_battles
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import json

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f1c6c17df30>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"

sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu'
)

sae6 = sae6.to(device)
normalise_decoder(sae6)

In [4]:
intelligence = sae6.W_dec[10351]   # intelligence and genius
writing = sae6.W_dec[1058]  # writing
anger = sae6.W_dec[1062]  # anger
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  # wedding
broad_wedding = sae6.W_dec[2378] # broad wedding

In [5]:
scales = [x for x in range(0, 120, 10)]
prompt = "I think"

name = "anger"

steer = anger

# eval_criterion = "Mentions writing or anything related to writing" # writing
# eval_criterion = "Mentions London or anything related to London" # London
# eval_criterion = "Mentions wedding or anything related to wedding" # wedding
eval_criterion = "Text is angry or mentions anger/frustration or anything related to anger" # anger

coherence_criterion = "Text is coherent, the grammar is correct."

In [75]:
def rate(scale, steering_vector, json_path=None):
    texts = generate(model,
        hooks=[(hp6, partial(patch_resid, steering=steering_vector, scale=scale))],
        max_new_tokens=25,
        prompt=prompt,
        batch_size=64,
        n_samples=128,
        )

    # eval = evaluate_completions(texts, criterion=eval_criterion, prompt=prompt, verbose=False)
    # coherence = evaluate_completions(texts, criterion=coherence_criterion, prompt=prompt, verbose=False)
    eval, coherence = multi_criterion_evaluation(texts,
                                                 [eval_criterion, coherence_criterion],
                                                 prompt=prompt,
                                                 verbose=False,
                                                 )
    scores = [e['score'] for e in eval]
    coherence_scores = [e['score'] for e in coherence]
    return scores, coherence_scores, texts

In [76]:
avg_scores = []
avg_coherence = []
all_scores = []
all_coherence = []

all_texts = []

for scale in tqdm(scales):
    scores, coherence, texts = rate(scale, steer)
    avg_scores.append(sum(scores) / len(scores))
    avg_coherence.append(sum(coherence) / len(coherence))
    all_scores.append(scores)
    all_coherence.append(coherence)
    all_texts.append(texts)

# scale to be from 0 to 1
avg_scores = [(x - 1) / 9 for x in avg_scores]
avg_coherence = [(x - 1) / 9 for x in avg_coherence]

# Scale all_scores and all_coherence to be from 0 to 1
all_scores = [[(score - 1) / 9 for score in scale_scores] for scale_scores in all_scores]
all_coherence = [[(coh - 1) / 9 for coh in scale_coherence] for scale_coherence in all_coherence]

100%|██████████| 12/12 [03:52<00:00, 19.38s/it]


In [77]:
coherence_battle_results = compute_battles(all_texts,
                                           coherence_criterion,
                                           prompt=prompt,
                                           n_iterations_per_model=128,
                                          #  model='gpt-4o'
                                        #    do_shuffle=False,
                                        #    reverse = True,
                                           )

Error in evaluate_completion: Request timed out.


In [78]:
coherence_battle_results[:100]

[{'winner': 1, 'tie': False, 'model_1': 2, 'model_2': 0},
 {'winner': 2, 'tie': False, 'model_1': 6, 'model_2': 8},
 {'winner': 1, 'tie': True, 'model_1': 10, 'model_2': 10},
 {'winner': 1, 'tie': True, 'model_1': 1, 'model_2': 1},
 {'winner': 2, 'tie': False, 'model_1': 1, 'model_2': 5},
 {'winner': 2, 'tie': False, 'model_1': 2, 'model_2': 4},
 {'winner': 2, 'tie': False, 'model_1': 8, 'model_2': 3},
 {'winner': 2, 'tie': False, 'model_1': 3, 'model_2': 6},
 {'winner': 2, 'tie': False, 'model_1': 2, 'model_2': 9},
 {'winner': 2, 'tie': False, 'model_1': 1, 'model_2': 10},
 {'winner': 2, 'tie': False, 'model_1': 9, 'model_2': 3},
 {'winner': 2, 'tie': False, 'model_1': 4, 'model_2': 6},
 {'winner': 2, 'tie': False, 'model_1': 5, 'model_2': 10},
 {'winner': 2, 'tie': False, 'model_1': 0, 'model_2': 7},
 {'winner': 2, 'tie': False, 'model_1': 6, 'model_2': 3},
 {'winner': 2, 'tie': False, 'model_1': 10, 'model_2': 0},
 {'winner': 2, 'tie': False, 'model_1': 6, 'model_2': 1},
 {'winner':

In [88]:
simple_res = []
for r in coherence_battle_results:
    if r['tie']:
        simple_res.append('tie')
    elif r['winner'] == 1:
        simple_res.append('one')
    elif r['winner'] == 2:
        simple_res.append('two')

px.histogram(x=simple_res)


In [79]:
import math
from collections import defaultdict

def calculate_elo_ratings(comparisons, initial_rating=1500, k_factor=8):
    ratings = defaultdict(lambda: initial_rating)
    
    for comparison in comparisons:
        model_1 = comparison['model_1']
        model_2 = comparison['model_2']
        winner = comparison['winner']
        is_tie = comparison['tie']
        
        rating_1 = ratings[model_1]
        rating_2 = ratings[model_2]
        
        expected_score_1 = 1 / (1 + math.pow(10, (rating_2 - rating_1) / 400))
        expected_score_2 = 1 - expected_score_1
        
        if is_tie:
            actual_score_1 = actual_score_2 = 0.5
        else:
            actual_score_1 = 1 if winner == model_1 else 0
            actual_score_2 = 1 - actual_score_1
        
        ratings[model_1] += k_factor * (actual_score_1 - expected_score_1)
        ratings[model_2] += k_factor * (actual_score_2 - expected_score_2)
    
    # Sort the ratings by model ID
    sorted_ratings = sorted(ratings.items(), key=lambda x: x[0])
    return [rating for _, rating in sorted_ratings]

elo_ratings = calculate_elo_ratings(coherence_battle_results)
elo_ratings = [(r-1000)/1000 for r in elo_ratings]
elo_ratings

[0.45237873369521686,
 0.4926265310728829,
 0.9215189035708103,
 0.465773255028723,
 0.44427737621564733,
 0.444017195431941,
 0.4646702222194899,
 0.4824440130379091,
 0.4682095926039381,
 0.4386360718597648,
 0.46450690307764514,
 0.4609412021860303]

In [80]:
def get_bootstrap_result(comparisons, num_rounds=1000):
    num_models = max(max(comp['model_1'], comp['model_2']) for comp in comparisons) + 1
    bootstrap_results = []

    for _ in tqdm(range(num_rounds), desc="Bootstrap Progress"):
        # Resample comparisons with replacement
        resampled_comparisons = random.choices(comparisons, k=len(comparisons))
        
        # Calculate Elo ratings for this bootstrap sample
        elo_ratings = calculate_elo_ratings(resampled_comparisons)
        
        # Ensure all model IDs are represented
        full_ratings = [elo_ratings[i] if i < len(elo_ratings) else 1500 for i in range(num_models)]
        
        bootstrap_results.append(full_ratings)

    # Convert results to numpy array for easier calculations
    bootstrap_results = np.array(bootstrap_results)

    # Calculate median and 95% confidence intervals
    median_ratings = np.median(bootstrap_results, axis=0)
    confidence_intervals = np.percentile(bootstrap_results, [2.5, 97.5], axis=0)

    return median_ratings, confidence_intervals

In [81]:
median_ratings, confidence_intervals = get_bootstrap_result(coherence_battle_results)
median_ratings

Bootstrap Progress: 100%|██████████| 1000/1000 [00:02<00:00, 495.89it/s]


array([1451.09239779, 1520.50964692, 1918.28230105, 1445.86011728,
       1461.56578304, 1465.87451713, 1459.10328725, 1439.05424492,
       1459.64142854, 1453.40621462, 1468.08896186, 1458.60210043])

In [82]:
confidence_intervals

array([[1397.79413991, 1471.78512406, 1882.09175726, 1397.03714098,
        1411.07125725, 1420.84975408, 1412.55636532, 1389.8859755 ,
        1407.35035211, 1409.65052804, 1416.888154  , 1411.13125956],
       [1504.0727467 , 1571.58102952, 1954.59568539, 1492.40231009,
        1513.18737448, 1519.98191094, 1507.97572548, 1488.3368589 ,
        1506.0314301 , 1501.76653367, 1518.49340818, 1510.3017773 ]])

In [83]:
df = pd.DataFrame({
    'Scale': scales,
    'Average Scores': avg_scores,
    'Average Coherence': avg_coherence,
    'Scores * Coherence': [a * b for a, b in zip(avg_scores, avg_coherence)],
    'Coherence Elo': elo_ratings,
})

# Create the line plot with multiple traces
fig = px.line(df, x='Scale', y=['Average Scores', 'Average Coherence', 'Scores * Coherence', 'Coherence Elo'],
              title=name)

# Customize the layout if needed
fig.update_layout(
    xaxis_title="Scale",
    yaxis_title="Values",
    legend_title="Metrics"
)

# Show the plot
fig.show()

In [84]:
px.line(median_ratings)

In [85]:
def create_comparison_heatmap(comparisons):
    # Create a dictionary to store win counts
    win_counts = defaultdict(lambda: defaultdict(int))
    
    # Count wins for each model pair
    for comp in comparisons:
        model_1 = comp['model_1']
        model_2 = comp['model_2']
        winner = comp['winner']
        
        if not comp['tie']:
            if winner == model_1:
                win_counts[model_1][model_2] += 1
            else:
                win_counts[model_2][model_1] += 1

    # Get unique model IDs
    model_ids = sorted(set(model_id for comp in comparisons for model_id in [comp['model_1'], comp['model_2']]))

    # Create the heatmap data
    heatmap_data = [[win_counts[y][x] for x in model_ids] for y in model_ids]

    # Create the heatmap
    fig = go.Figure(data=go.Heatmap(
        z=heatmap_data,
        x=model_ids,
        y=model_ids,
        colorscale='Viridis',
        hoverongaps=False))

    # Update layout
    fig.update_layout(
        title='Model Comparison Heatmap: Number of Wins',
        xaxis_title='Model ID (Loser)',
        yaxis_title='Model ID (Winner)',
        xaxis=dict(tickmode='linear'),
        yaxis=dict(tickmode='linear'),
        width=800,
        height=800,
    )

    return fig

create_comparison_heatmap(coherence_battle_results)

In [86]:
all_texts[0]

['I think you would find the 203s in the E4L line too soft? My friend has a 20',
 'I think all this is a lie to keep them in the dark.. The truth is he has been using a drug thats bad for him',
 "I think we're gonna have to start looking for a new drummer. I mean, it can't be THAT difficult, we",
 "I think they were really good at the beginning. Since then they've been mediocre to awful. I hate them so much. I",
 'I think that you can do it, just check this website:\n\nhttp://www.sands.com.sg/en/',
 'I think it is fair to say that most of use do not think that the vast majority of young people are violent or unruly. When',
 'I think it’s really important to take pride in what you’re doing, and we’re trying to do an environment where',
 'I think this is a duplicate of #34\n\nWhat is the status here? Is the issue assigned to a developer?\n\n@',
 "I think if you're a beginner, <em>the way</em> of how you learn is probably most important: reading the manual,",
 "I think it'd be neat to put