In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder, text_to_sae_feats, top_activations
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f1c584e3370>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"

sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu'
)

sae6 = sae6.to(device)
normalise_decoder(sae6)

In [4]:
# intelligence = sae6.W_dec[10351]   # intelligence and genius
writing = sae6.W_dec[1058]  # writing
anger = sae6.W_dec[1062]  # anger
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  # wedding
broad_wedding = sae6.W_dec[2378] # broad wedding

In [18]:
scales = [x for x in range(0, 120)]
prompt = "I think"

name = "anger"

steer = anger

# eval_criterion = "Mentions writing or anything related to writing" # writing
# eval_criterion = "Mentions London or anything related to London" # London
# eval_criterion = "Mentions wedding or anything related to wedding" # wedding
eval_criterion = "Text is angry or mentions anger/frustration or anything related to anger" # anger

coherence_criterion = "Text is coherent, the grammar is correct."

In [6]:
def rate(scale, steering_vector):
    texts = generate(model,
        hooks=[(hp6, partial(patch_resid, steering=steering_vector, scale=scale))],
        max_new_tokens=25,
        prompt=prompt,
        batch_size=64,
        n_samples=128,
        )

    # eval = evaluate_completions(texts, criterion=eval_criterion, prompt=prompt, verbose=False)
    # coherence = evaluate_completions(texts, criterion=coherence_criterion, prompt=prompt, verbose=False)
    eval, coherence = multi_criterion_evaluation(texts,
                                                 [eval_criterion, coherence_criterion],
                                                 prompt=prompt,
                                                 verbose=False,
                                                 )
    scores = [e['score'] for e in eval]
    coherence_scores = [e['score'] for e in coherence]
    return scores, coherence_scores

In [7]:
avg_scores = []
avg_coherence = []
all_scores = []
all_coherence = []

for scale in tqdm(scales):
    scores, coherence = rate(scale, steer)
    avg_scores.append(sum(scores) / len(scores))
    avg_coherence.append(sum(coherence) / len(coherence))
    all_scores.append(scores)
    all_coherence.append(coherence)

# scale to be from 0 to 1
avg_scores = [(x - 1) / 9 for x in avg_scores]
avg_coherence = [(x - 1) / 9 for x in avg_coherence]

# Scale all_scores and all_coherence to be from 0 to 1
all_scores = [[(score - 1) / 9 for score in scale_scores] for scale_scores in all_scores]
all_coherence = [[(coh - 1) / 9 for coh in scale_coherence] for scale_coherence in all_coherence]


  0%|          | 0/120 [00:01<?, ?it/s]


KeyboardInterrupt: 

In [None]:
df = pd.DataFrame({
    'Scale': scales,
    'Average Scores': avg_scores,
    'Average Coherence': avg_coherence,
    'Scores * Coherence': [a * b for a, b in zip(avg_scores, avg_coherence)]
})

# Create the line plot with multiple traces
fig = px.line(df, x='Scale', y=['Average Scores', 'Average Coherence', 'Scores * Coherence'],
              title=name)

# Customize the layout if needed
fig.update_layout(
    xaxis_title="Scale",
    yaxis_title="Values",
    legend_title="Metrics"
)

# Show the plot
fig.show()

In [7]:
batch_size = 8
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=32)
tokenized_data = tokenized_data.shuffle(42)
loader = DataLoader(tokenized_data, batch_size=batch_size)

# note that this only works for inserting into all positions.
@torch.no_grad()
def compute_melbo(steering_vector, read_layer=12, p=2, q=2, n_steps=10):
    unsteered_resid = None
    steered_resid = None

    def write_hook(resid, hook):
        resid[:, :, :] = resid[:, :, :] + steering_vector
        return resid
    
    def unsteered_read_hook(resid, hook):
        nonlocal unsteered_resid
        unsteered_resid = resid.clone()

    def steered_read_hook(resid, hook):
        nonlocal steered_resid
        steered_resid = resid.clone()

    total_melbo = 0
    total_loss_diff = 0

    for batch_idx, batch in enumerate(loader):
        with model.hooks(fwd_hooks=[
            (hp6, write_hook),
            (f'blocks.{read_layer}.hook_resid_post', steered_read_hook)
        ]):
            steered_loss = model(batch['tokens'], return_type='loss')

        with model.hooks(fwd_hooks=[
            (f'blocks.{read_layer}.hook_resid_post', unsteered_read_hook)
        ]):
            unsteered_loss = model(batch['tokens'], return_type='loss')
        
        # compute melbo between unsteered_resid and steered_resid
        diffs = torch.norm(unsteered_resid - steered_resid, dim=-1)
        diffs = diffs ** p
        summed_over_tokens = diffs.sum(dim=-1)
        summed_over_tokens = summed_over_tokens ** (1/q)

        total_melbo += summed_over_tokens.sum().item()

        total_loss_diff += (steered_loss - unsteered_loss).sum().item()

        if batch_idx >= n_steps - 1:
            break
    
    return total_melbo / (n_steps * batch_size), total_loss_diff / (n_steps * batch_size)
    

In [None]:
melbos = []
loss_diffs = []
for scale in tqdm(scales):
    melbo, loss_diff = compute_melbo(scale * steer)
    melbos.append(melbo)
    loss_diffs.append(loss_diff)

100%|██████████| 120/120 [02:32<00:00,  1.27s/it]


In [None]:
df = pd.DataFrame({
    'Scale': scales,
    'Average Scores': avg_scores,
    'Average Coherence': avg_coherence,
    'Scores * Coherence': [a * b for a, b in zip(avg_scores, avg_coherence)],
    # 'Melbo': melbos,
    'Loss Diff': loss_diffs,
})

# Create the line plot with multiple traces
fig = px.line(df, x='Scale', y=['Average Scores', 'Average Coherence', 'Scores * Coherence',
                                # 'Melbo',
                                'Loss Diff'],
              title=name)

# Customize the layout if needed
fig.update_layout(
    xaxis_title="Scale",
    yaxis_title="Values",
    legend_title="Metrics"
)

# Show the plot
fig.show()

In [None]:
print(f'melbo at scale 50 for {name} is:')
print(compute_melbo(50 * steer))

melbo at scale 50 for anger is:
(276.6551971435547, -0.041250011324882506)


In [None]:
# hacky plot

# names = ["writing", "anger", "london", "wedding"]

# approx optimal scales computed from peak of score * coherence
# optimal_scales = [55, 62, 80, 70]

# approx optimal scales computed from coherence == 0.6
optimal_scales = [83, 70, 92, 98]

melbo_scores = [245.49, 276.66, 235.48, 231.97]
melbo_scores = [x / 10000 for x in melbo_scores]
loss_values = [0.023031622171401978, 0.041250011324882506, 0.022686100006103514, 0.02259013056755066]

px.scatter(x=optimal_scales, y=melbo_scores, title="Melbo vs optimal scale", labels={'x': 'Optimal Scale', 'y': 'Melbo'})

In [None]:
# fig = px.scatter(x=optimal_scales, y=melbo_scores, title="Melbo vs optimal scale", labels={'x': 'Optimal Scale', 'y': 'Melbo'})
# fig.add_trace(go.Scatter(x=optimal_scales, y=loss_values, mode='markers', marker=dict(color='red'), name='loss_values'))
# fig.update_layout(showlegend=True)
# fig.show()

## Attempt measure effect on layer 12 SAE

In [8]:
# load layer 12 sae. Don't do fancy input norm.

hp12 = "blocks.12.hook_resid_post"

sae12, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp12, # won't always be a hook point
    device = 'cpu'
)

sae12 = sae12.to(device)
normalise_decoder(sae12)

In [19]:
# note that this only works for inserting into all positions.
@torch.no_grad()
def compute_sae_effect(steering_vector, p=2, q=2, n_steps=10):
    unsteered_resid = None
    steered_resid = None

    def write_hook(resid, hook):
        resid[:, :, :] = resid[:, :, :] + steering_vector
        return resid
    
    def unsteered_read_hook(resid, hook):
        nonlocal unsteered_resid
        unsteered_resid = resid.clone()

    def steered_read_hook(resid, hook):
        nonlocal steered_resid
        steered_resid = resid.clone()

    total_sae_diff = 0
    # total_loss_diff = 0

    for batch_idx, batch in enumerate(loader):
        with model.hooks(fwd_hooks=[
            (hp6, write_hook),
            (hp12, steered_read_hook)
        ]):
            model(batch['tokens'], stop_at_layer=13)

        with model.hooks(fwd_hooks=[
            (hp12, unsteered_read_hook)
        ]):
            model(batch['tokens'], stop_at_layer=13)
        
        # compute melbo between unsteered_resid and steered_resid
        # diffs = torch.norm(unsteered_resid - steered_resid, dim=-1)

        unsteered_acts = sae12.encode(unsteered_resid.reshape(-1, unsteered_resid.shape[-1]))
        steered_acts = sae12.encode(steered_resid.reshape(-1, steered_resid.shape[-1]))

        diffs = torch.abs(unsteered_acts - steered_acts)

        diffs = diffs ** p
        summed_over_feats = diffs.sum(dim=-1)
        summed_over_feats = summed_over_feats ** (1/q)
        total_sae_diff += summed_over_feats.sum().item()

        if batch_idx >= n_steps - 1:
            break
    
    return total_sae_diff / (n_steps * batch_size)

In [36]:
sae_diffs = []
for scale in tqdm(scales):
    sae_diff = compute_sae_effect(scale * steer, p=2, q=2)
    sae_diffs.append(sae_diff)

100%|██████████| 120/120 [01:58<00:00,  1.01it/s]


In [37]:
px.line(sae_diffs)

In [38]:
grad_of_diffs = [sae_diffs[i] - sae_diffs[i-1] for i in range(1, len(sae_diffs))]
px.line(grad_of_diffs)

In [40]:
steer_name_idx = [("writing", 1058), ("london", 10138), ("wedding", 8406), ("anger", 1062)]
steer_name_idx

[('writing', 1058), ('london', 10138), ('wedding', 8406), ('anger', 1062)]

In [51]:
all_vec_sae_diffs = []
all_vec_sae_diff_grads = []

for steer_name, steer_idx in steer_name_idx:
    steer = sae6.W_dec[steer_idx]

    sae_diffs = []
    for scale in tqdm(scales):
        sae_diff = compute_sae_effect(scale * steer, p=2, q=2, n_steps=2)
        sae_diffs.append(sae_diff)
    
    all_vec_sae_diffs.append(sae_diffs)
    all_vec_sae_diff_grads.append([sae_diffs[i] - sae_diffs[i-1] for i in range(1, len(sae_diffs))])
    

100%|██████████| 120/120 [00:23<00:00,  5.05it/s]
100%|██████████| 120/120 [00:23<00:00,  5.07it/s]
100%|██████████| 120/120 [00:23<00:00,  5.08it/s]
100%|██████████| 120/120 [00:23<00:00,  5.07it/s]


In [52]:
def plot_sae_diff_grads(scales, all_vec_sae_diff_grads, steer_name_idx):
    fig = go.Figure()

    colors = ['blue', 'red', 'green', 'purple']  # Add more colors if needed

    for i, sae_diff_grads in enumerate(all_vec_sae_diff_grads):
        steer_name = steer_name_idx[i][0]
        
        # Plot gradients
        fig.add_trace(
            go.Scatter(x=scales[1:], y=sae_diff_grads, mode='lines', 
                       name=f'{steer_name}',
                       line=dict(color=colors[i]))
        )

    # Update layout
    fig.update_layout(
        height=600, 
        width=800, 
        title_text="Gradients of SAE Differences for Different Steering Vectors",
        xaxis_title="Scale",
        yaxis_title="Gradient of SAE Difference",
        legend_title="Steering Vectors"
    )

    return fig

# Usage
fig = plot_sae_diff_grads(scales, all_vec_sae_diff_grads, steer_name_idx)
fig.show()