In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm
import json

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder, text_to_sae_feats, top_activations
from steering.patch import generate, scores_2d

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

import numpy as np
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f772c2432b0>

In [2]:

def patch_resid(resid, hook, steering, scale=1):
    """Hook function that adds a scaled steering tensor to an activation in place.

    The activation is overwritten through slice assignment and nothing is
    returned, so callers rely on the in-place mutation.

    Args:
        resid: activation tensor; it is indexed with three slices, so it must be 3-D.
        hook: unused; present only because hook functions receive it.
        steering: tensor added to ``resid`` (combined with ``+``, so it must
            broadcast against ``resid``).
        scale: multiplier applied to ``steering`` before it is added.
    """
    offset = steering * scale
    resid[:, :, :] = resid[:, :, :] + offset
    # Alternatives kept for reference (steer only the last / all but the last position):
    # resid[:, -1, :] = resid[:, -1, :] + steering * scale
    # resid[:, :-1, :] = resid[:, :-1, :] + steering * scale


In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained "gemma-2b" checkpoint as a HookedTransformer on that device.
model: HookedTransformer = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [4]:
# Hook point name; it doubles as the SAE id below and is the place where
# steering hooks are attached later in the notebook.
hp6 = "blocks.6.hook_resid_post"
# from_pretrained returns a 3-tuple here; only the first element (the SAE) is kept.
sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu',
)
# The SAE is loaded on the CPU first and then moved to the model's device.
sae6 = sae6.to(device)
# Project helper from steering.utils; presumably rescales the decoder rows to
# unit norm — TODO confirm against its implementation.
normalise_decoder(sae6)

In [5]:
# Individual rows of the SAE decoder matrix, used below as steering directions.
# The trailing comments are the author's labels for each feature index.
writing = sae6.W_dec[1058]  # writing
anger = sae6.W_dec[1062]  # anger
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  # wedding
broad_wedding = sae6.W_dec[2378] # broad wedding

In [6]:
# Seed torch's global RNG so the random baseline directions (here and in the
# cells that follow) are identical on every Restart & Run All; without a seed
# every run of the notebook steers with different vectors.
torch.manual_seed(0)

# Two random unit-norm directions with the same shape, dtype and device as
# the `writing` decoder row.
r1 = torch.randn_like(writing)
r1 = r1 / r1.norm()
r2 = torch.randn_like(writing)
r2 = r2 / r2.norm()

In [7]:
# Five more random directions shaped like `writing`, each divided by its own norm.
rs = [
    sample / sample.norm()
    for sample in (torch.randn_like(writing) for _ in range(5))
]

In [20]:
# Prompt analysed in the perplexity / norm cells below; the commented-out
# lines are alternative prompts.
# text = "I think she won't say what's on her mind."
# text = "x x x x x x x x x x x x a"
text = "His height was 6'2"
toks = model.to_tokens(text)

In [21]:
@torch.no_grad()
def perplexity(tokens):
    """Return exp(per-token loss) for ``tokens`` as a CPU tensor.

    Runs the global ``model`` with ``return_type="loss"`` and
    ``loss_per_token=True`` and exponentiates the result element-wise.
    """
    per_token_loss = model(tokens, return_type="loss", loss_per_token=True)
    per_token_ppl = per_token_loss.exp()
    return per_token_ppl.cpu()


# Keep the first (only) batch row.
perp = perplexity(toks)[0]

In [22]:
@torch.no_grad()
def get_norms(tokens, layers):
    """Return per-position norms of ``blocks.{layer}.hook_resid_post`` activations.

    Runs the global ``model`` once with caching, then for every entry of
    ``layers`` takes the norm over the last dimension and keeps batch row 0.

    Returns:
        A list with one inner list per requested layer, holding one Python
        float per position.
    """
    _, cache = model.run_with_cache(tokens)

    return [
        [value.item() for value in cache[f"blocks.{layer}.hook_resid_post"].norm(dim=-1)[0]]
        for layer in layers
    ]


get_norms(toks, [6, 12])

[[739.2119750976562,
  199.86607360839844,
  188.90658569335938,
  151.79449462890625,
  153.72552490234375,
  152.05076599121094,
  188.39154052734375,
  191.88259887695312],
 [1469.5340576171875,
  259.186767578125,
  269.83160400390625,
  245.91485595703125,
  218.47584533691406,
  221.83963012695312,
  241.31765747070312,
  253.18353271484375]]

In [23]:
# Label every token as "<text>/<position>" and drop position 0 so the labels
# line up with the entries of `perp`.
token_strings = [
    f"{tok}/{pos}" for pos, tok in enumerate(model.to_str_tokens(toks[0]))
][1:]
print(token_strings)
print(len(token_strings))
print(perp)
print(perp.shape)

['His/1', ' height/2', ' was/3', ' /4', '6/5', "'/6", '2/7']
7
tensor([8.9837e+03, 2.9875e+03, 1.1156e+01, 3.5539e+00, 3.5585e+00, 3.4222e+00,
        8.9971e+00])
torch.Size([7])


In [24]:
# One row per token label with its perplexity value.
df = pd.DataFrame({
    'Token': token_strings,
    'Perplexity': perp.numpy()
})
# Line plot with a logarithmic y axis.
fig = px.line(df, x='Token', y='Perplexity', markers=True, log_y=True)
fig.update_layout(
    title='Perplexity per Token',
    xaxis_title='Token',
    yaxis_title='Perplexity',
    xaxis_tickangle=-45
)
fig.show()

In [25]:
# Per-token norms at several layers (requires get_norms to be defined already).
# `norm_layers` is the single source for which layers are computed and stored.
norm_layers = [3, 6, 10, 12, 15]
norms = get_norms(toks, norm_layers)
norms = [ns[1:] for ns in norms]  # drop position 0 so rows line up with token_strings

# Prepare token strings
token_strings = model.to_str_tokens(toks[0])
token_strings = [f"{t}/{i}" for i, t in enumerate(token_strings)]
token_strings = token_strings[1:]

# Create DataFrame: a 'Token' column plus one 'Layer N Norm' column per computed layer
df = pd.DataFrame({
    'Token': token_strings,
    **{f'Layer {layer} Norm': ns for layer, ns in zip(norm_layers, norms)},
})

# Create the figure
fig = go.Figure()

# Layers that are actually drawn. Layer 3 is computed and stored in `df` but
# not plotted; add 3 to this list to include it.
layers = [6, 10, 12, 15]
colors = ['red', 'blue', 'green', 'purple', 'orange']  # zip() below uses only as many as there are layers

for layer, color in zip(layers, colors):
    fig.add_trace(go.Scatter(
        x=df['Token'],
        y=df[f'Layer {layer} Norm'],
        mode='lines+markers',
        name=f'Layer {layer}',
        line=dict(color=color)
    ))

# Calculate the maximum y value over the plotted layers
max_y = max(df[[f'Layer {layer} Norm' for layer in layers]].max())

# Update layout. The title is built from `layers` so it always names exactly
# the layers that were plotted (it previously listed layer 3, which is not drawn).
fig.update_layout(
    title=f"Norms per Token for Layers {', '.join(str(layer) for layer in layers)}",
    xaxis_title='Token',
    yaxis_title='Norm',
    xaxis_tickangle=-45,
    yaxis=dict(range=[0, max_y * 1.1]),  # Force y-axis to start at 0
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1
    )
)

# Show the plot
fig.show()

### Logit lens at apostrophe token

In [14]:
# Prompt that ends on the apostrophe; the logit-lens cells below read the
# last position. Note this rebinds `text` and `toks` from the earlier section.
text = "I think she won'"
toks = model.to_tokens(text)

In [15]:
# Show how the prompt was split into tokens.
print(model.to_str_tokens(toks[0]))


['<bos>', 'I', ' think', ' she', ' won', "'"]


In [16]:
def run_model_and_collect_logits(model, toks, layer_names, steering=None, scale=None, layer=6):
    """Project cached activations onto the unembedding and read the "t" / "a" logits.

    The model is run once on ``toks`` while caching the activations named in
    ``layer_names``. When both ``steering`` and ``scale`` are given, a
    ``patch_resid`` hook is attached at ``blocks.{layer}.hook_resid_post`` for
    that run. Each cached activation is multiplied by ``model.W_U`` directly
    (the final layer norm is not applied) and the values for the tokens "t"
    and "a" are read at the last position of batch row 0.

    Returns:
        ``(correct_logits, incorrect_logits, ldiffs)``: three lists of floats,
        one entry per name in ``layer_names``; ``ldiffs`` is correct minus
        incorrect.
    """
    fwd_hooks = []
    steer_requested = steering is not None and scale is not None
    if steer_requested:
        hook_point = f"blocks.{layer}.hook_resid_post"
        fwd_hooks.append((hook_point, partial(patch_resid, steering=steering, scale=scale)))

    with model.hooks(fwd_hooks):
        _, cache = model.run_with_cache(toks, return_type="logits", names_filter=layer_names)

    correct_final_id = model.to_tokens("t", prepend_bos=False)[0][0].item()
    incorrect_final_id = model.to_tokens("a", prepend_bos=False)[0][0].item()

    correct_logits, incorrect_logits, ldiffs = [], [], []
    for name in layer_names:
        activation = cache[name]
        # activation = model.ln_final(activation)
        last_pos_logits = (activation @ model.W_U)[0, -1, :]
        good = last_pos_logits[correct_final_id].item()
        bad = last_pos_logits[incorrect_final_id].item()
        correct_logits.append(good)
        incorrect_logits.append(bad)
        ldiffs.append(good - bad)

    return correct_logits, incorrect_logits, ldiffs

In [17]:
# Which per-layer activation to project; the commented-out lines are the
# alternatives (residual stream / attention output).
# layer_names = [f"blocks.{i}.hook_resid_post" for i in range(model.cfg.n_layers)]
layer_names = [f"blocks.{i}.hook_mlp_out" for i in range(model.cfg.n_layers)]
# layer_names = [f"blocks.{i}.hook_attn_out" for i in range(model.cfg.n_layers)]
# Unsteered run, then a run steered with the `london` direction at scale 100
# (at the function's default layer, 6).
correct_logits_orig, incorrect_logits_orig, ldiffs_orig = run_model_and_collect_logits(model, toks, layer_names)
correct_logits_steer, incorrect_logits_steer, ldiffs_steer = run_model_and_collect_logits(model, toks, layer_names, steering=london, scale=100)


In [18]:
# Unsteered logit difference per layer, with the steered run overlaid as a second trace.
fig = px.line(ldiffs_orig, title="logit diff at layers")
fig.add_trace(go.Scatter(y=ldiffs_steer, mode='lines', name='Steer'))
fig.show()

In [19]:
# head dla
@torch.no_grad()
def get_head_ldiffs(model, toks, steering=None, scale=None, layer=6):
    """Return the "t" minus "a" logit difference written by every attention head.

    For each layer, the per-head ``hook_z`` activation at the last position of
    batch row 0 is passed through that layer's ``W_O`` and projected onto the
    unembedding columns of the tokens "t" and "a"; the difference of the two
    values is stored per head. No final layer norm is applied.

    When both ``steering`` and ``scale`` are given, a ``patch_resid`` hook is
    attached at ``blocks.{layer}.hook_resid_post`` for the forward pass.

    Args:
        model: model exposing ``cfg.n_layers``, ``cfg.n_heads``, ``blocks``,
            ``W_U``, ``hooks``, ``run_with_cache`` and ``to_tokens``.
        toks: token tensor passed to ``run_with_cache``.
        steering: optional steering tensor forwarded to ``patch_resid``.
        scale: optional multiplier forwarded to ``patch_resid``.
        layer: index of the block whose ``hook_resid_post`` is steered.

    Returns:
        CPU tensor of shape ``(n_layers, n_heads)``.
    """
    hooks = []
    if steering is not None and scale is not None:
        hp = f"blocks.{layer}.hook_resid_post"
        hooks.append((hp, partial(patch_resid, steering=steering, scale=scale)))

    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    # Only hook_z is read below, so only those activations are cached
    # (same names_filter approach as run_model_and_collect_logits).
    z_names = [f"blocks.{i}.attn.hook_z" for i in range(n_layers)]
    with model.hooks(hooks):
        _, cache = model.run_with_cache(toks, return_type="logits", names_filter=z_names)

    correct_final_id = model.to_tokens("t", prepend_bos=False)[0][0].item()
    incorrect_final_id = model.to_tokens("a", prepend_bos=False)[0][0].item()
    # Only two logits are ever read, so project onto just those two
    # unembedding columns instead of the whole vocabulary: (model_dim, 2).
    answer_unembed = model.W_U[:, [correct_final_id, incorrect_final_id]]

    diffs = torch.zeros(n_layers, n_heads)
    for i in range(n_layers):
        # hook_z is (batch, pos, heads, head_dim); keep batch row 0, last position.
        z_last = cache[z_names[i]][0, -1]
        # Per-head output: (heads, head_dim) x (heads, head_dim, model_dim) -> (heads, model_dim)
        head_out = torch.einsum("hd,hdm->hm", z_last, model.blocks[i].attn.W_O)
        pair = head_out @ answer_unembed  # (heads, 2): [correct, incorrect]
        diffs[i] = (pair[:, 0] - pair[:, 1]).cpu()

    return diffs


# Per-head logit differences without steering, with `london` steering at
# scale 100, and the change caused by steering.
raw_diffs = get_head_ldiffs(model, toks)
mod_diffs = get_head_ldiffs(model, toks, steering=london, scale=100)
diff_diffs = mod_diffs - raw_diffs

In [20]:
# Three side-by-side heatmaps (rows = layers, columns = heads): unsteered,
# steered, and their difference.
fig = make_subplots(rows=1, cols=3, subplot_titles=('raw', 'mod', 'diff'))

# enumerate starts at 1 because subplot columns are 1-indexed.
for i, data in enumerate([raw_diffs, mod_diffs, diff_diffs], 1):
    fig.add_trace(go.Heatmap(z=data, colorscale="RdBu", zmid=0), row=1, col=i)
    fig.update_xaxes(title_text="Head", row=1, col=i)

fig.update_layout(title_text="Comparison of ldiffs and ddiffs", height=500, width=800)
# Only the leftmost subplot gets a y-axis title.
fig.update_yaxes(title_text="Layer", row=1, col=1)

fig.show()

In [21]:
# Inspect the shape of block 0's attention output weights W_O.
model.blocks[0].attn.W_O.shape

torch.Size([8, 256, 2048])

In [22]:
# Create plot
fig = go.Figure()
# Add traces for original model
fig.add_trace(go.Scatter(y=correct_logits_orig, mode='lines', name='Original Correct Logit'))
fig.add_trace(go.Scatter(y=incorrect_logits_orig, mode='lines', name='Original Incorrect Logit'))
fig.add_trace(go.Scatter(y=ldiffs_orig, mode='lines', name='Original Logit Diff'))
# Add traces for steered model
fig.add_trace(go.Scatter(y=correct_logits_steer, mode='lines', name='Steered Correct Logit'))
fig.add_trace(go.Scatter(y=incorrect_logits_steer, mode='lines', name='Steered Incorrect Logit'))
fig.add_trace(go.Scatter(y=ldiffs_steer, mode='lines', name='Steered Logit Diff'))
# Update layout
fig.update_layout(
    title="Logits and Logit Differences Across Layers",
    xaxis_title="Layer",
    yaxis_title="Logit Value",
    legend_title="Metrics"
)
fig.show()

In [23]:
# Re-run the comparison with steering applied at layer 15 and scale 200.
# NOTE: this rebinds the *_orig / *_steer variables produced by the earlier
# cell, so the earlier plot cell shows different data if re-run after this one.
correct_logits_orig, incorrect_logits_orig, ldiffs_orig = run_model_and_collect_logits(model, toks, layer_names)
correct_logits_steer, incorrect_logits_steer, ldiffs_steer = run_model_and_collect_logits(model, toks, layer_names, steering=london, scale=200, layer=15)
# Create plot
fig = go.Figure()
# Add traces for original model (the incorrect-logit trace is disabled)
fig.add_trace(go.Scatter(y=correct_logits_orig, mode='lines', name='Original Correct Logit'))
# fig.add_trace(go.Scatter(y=incorrect_logits_orig, mode='lines', name='Original Incorrect Logit'))
fig.add_trace(go.Scatter(y=ldiffs_orig, mode='lines', name='Original Logit Diff'))
# Add traces for steered model (the incorrect-logit trace is disabled)
fig.add_trace(go.Scatter(y=correct_logits_steer, mode='lines', name='Steered Correct Logit'))
# fig.add_trace(go.Scatter(y=incorrect_logits_steer, mode='lines', name='Steered Incorrect Logit'))
fig.add_trace(go.Scatter(y=ldiffs_steer, mode='lines', name='Steered Logit Diff'))
# Update layout
fig.update_layout(
    title="Logits and Logit Differences Across Layers",
    xaxis_title="Layer",
    yaxis_title="Logit Value",
    legend_title="Metrics"
)
fig.show()

In [24]:
# Single-prompt list read as a global by many_ft_eval (which asserts it has
# exactly one entry).
prompts = ["If she won'"]

prompt = "If she won'"
prompt_tokens = model.to_tokens(prompts, prepend_bos=True)
# correct_ids = model.to_tokens(["re", 't', 's', 's', 't'], prepend_bos=False)
# Token id(s) of the expected continuation; bound as the default value of
# many_ft_eval's `correct_ids` parameter when that function is defined.
correct_ids = model.to_tokens(['t'], prepend_bos=False)

In [25]:
@torch.no_grad()
def many_ft_eval(ft_ids=None, scale=80, correct_ids=correct_ids, use_logits=False, vecs=None):
    """Score the answer token for the single global prompt under many steering vectors.

    The steering vectors are either rows ``ft_ids`` of ``sae6.W_dec`` or the
    tensors in ``vecs`` (``vecs`` takes precedence when both are given). They
    are processed in batches of 32: the prompt is repeated once per vector,
    each row is steered at ``hp6`` with its own vector via ``patch_resid``,
    and the model output at the last position is read for ``correct_ids``.

    Args:
        ft_ids: sequence of row indices into ``sae6.W_dec``; used when ``vecs`` is None.
        scale: multiplier passed to ``patch_resid``.
        correct_ids: token id tensor of the answer token; its default is
            captured from the notebook global when this cell is run.
        use_logits: return raw logits instead of softmax probabilities.
        vecs: optional list of steering tensors (stacked per batch).

    Returns:
        List of Python floats, one per steering vector, in input order.

    Raises:
        ValueError: if neither ``ft_ids`` nor ``vecs`` is provided.
    """
    if ft_ids is None and vecs is None:
        raise ValueError("many_ft_eval needs either ft_ids or vecs")
    assert len(prompts) == 1
    batch_size = 32
    # One row of tokens (the assert above guarantees a single prompt).
    prompt_row = model.to_tokens(prompts, prepend_bos=True)
    n = len(ft_ids) if vecs is None else len(vecs)

    res_probs = []
    for start in range(0, n, batch_size):
        if vecs is None:
            st = sae6.W_dec[ft_ids[start:start + batch_size]]
        else:
            st = torch.stack(vecs[start:start + batch_size])

        # One steering vector per batch row, repeated over every position.
        st = st[:, None, :].expand(-1, prompt_row.shape[1], -1)

        # Always expand from the original single-row prompt. Re-expanding an
        # already expanded batch fails as soon as the last batch is smaller
        # than batch_size (n not a multiple of 32).
        batch_in = prompt_row.expand(st.shape[0], -1)
        # with model.hooks([(hp6, partial(patch_final, steering=st, scale=scale))]):
        with model.hooks([(hp6, partial(patch_resid, steering=st, scale=scale))]):
            logits = model(batch_in)[:, -1, :]

        scores = logits if use_logits else F.softmax(logits, dim=-1)
        picked = scores[torch.arange(scores.shape[0]), correct_ids.squeeze()]
        res_probs.extend(p.item() for p in picked)
    return res_probs

# many_ft_eval([10138, 8406, 2378], scale=80)
many_ft_eval(vecs=[london, wedding, broad_wedding], scale=80)

[0.0018133004195988178, 0.08583137392997742, 0.9828453660011292]

In [26]:
# Sweep every SAE feature index as a steering direction, once scoring the
# default answer token and once scoring 's'. Each call runs the model once
# per batch of 32 features, so this cell is slow (the recorded run below was
# interrupted); the cells that follow need both results.
all_ft_evals = many_ft_eval(list(range(sae6.cfg.d_sae)), scale=70) # 70
p_ss = many_ft_eval(list(range(sae6.cfg.d_sae)), scale=70, correct_ids=model.to_tokens('s', prepend_bos=False))

KeyboardInterrupt: 

In [None]:
# (feature index, score) pairs ordered from highest to lowest score.
sorted_evals = sorted(enumerate(all_ft_evals), key=lambda x: x[1], reverse=True)

In [None]:
# Reorder the 's' scores to follow the ranking in sorted_evals, smooth them
# with a forward moving average, then pad with the last average so the
# result has the same length as sorted_evals.
window_size = 100
ranked_ps = [p_ss[feature_idx] for feature_idx, _ in sorted_evals]
n_windows = len(ranked_ps) - window_size
pss_sorted = [
    sum(ranked_ps[start:start + window_size]) / window_size
    for start in range(n_windows)
]
pss_sorted += [pss_sorted[-1]] * window_size

In [None]:
# Create a DataFrame with both sets of data: 'line1' holds the sorted scores
# from sorted_evals, 'line2' the smoothed 's' scores in the same order.
df = pd.DataFrame({
    'index': range(len(sorted_evals)),
    'line1': [x[1] for x in sorted_evals],
    'line2': pss_sorted,
})

# Create the line plot with both lines
fig = px.line(df, x='index', y=['line1', 'line2'], title='Prob sorted by p(t)')

# Rename the traces from the column names to readable legend labels
fig.update_traces(name='p (t)', selector=dict(name='line1'))
fig.update_traces(name='p (s)', selector=dict(name='line2'))

# Show the plot
fig.show()

In [None]:
# check cosine sim between features?

# TODO: plot logit instead of prob

In [None]:
# TODO: check inserting random vector

In [None]:
# Steer at hp6 with the random unit direction r2 (scale 80) and print the
# model's top predictions for the prompt, checking the answer "t".
with model.hooks([(hp6, partial(patch_resid, steering=r2, scale=80))]):
    tutils.test_prompt("If she won'", "t", model, prepend_space_to_answer=False)
    # tutils.test_prompt("I think you'", "s", model, prepend_space_to_answer=False)

Tokenized prompt: ['<bos>', 'If', ' she', ' won', "'"]
Tokenized answer: ['t']


Top 0th token. Logit: 20.80 Prob: 32.31% Token: |t|
Top 1th token. Logit: 20.14 Prob: 16.66% Token: | last|
Top 2th token. Logit: 20.13 Prob: 16.65% Token: |s|
Top 3th token. Logit: 20.02 Prob: 14.91% Token: | previous|
Top 4th token. Logit: 18.87 Prob:  4.72% Token: |ll|
Top 5th token. Logit: 18.42 Prob:  3.01% Token: |last|
Top 6th token. Logit: 17.86 Prob:  1.71% Token: |  |
Top 7th token. Logit: 17.51 Prob:  1.20% Token: |e|
Top 8th token. Logit: 17.16 Prob:  0.85% Token: |previous|
Top 9th token. Logit: 16.32 Prob:  0.37% Token: | Previous|


In [None]:
# Sweep the steering scale for the random directions in `rs`, scoring the
# 't' and 's' tokens at every scale.
scales = list(range(120))
t_ids = model.to_tokens('t', prepend_bos=False)
s_ids = model.to_tokens('s', prepend_bos=False)

r_t_evals = []
r_s_evals = []
for scale in tqdm(scales):
    r_t_evals.append(many_ft_eval(vecs=rs, scale=scale, correct_ids=t_ids))
    r_s_evals.append(many_ft_eval(vecs=rs, scale=scale, correct_ids=s_ids))

# Transpose from one list per scale to one list per random vector.
r_t_evals = [list(per_vector) for per_vector in zip(*r_t_evals)]
r_s_evals = [list(per_vector) for per_vector in zip(*r_s_evals)]

100%|██████████| 120/120 [00:19<00:00,  6.31it/s]


In [None]:
# Long-format records: one row per (random vector, token type, scale).
data = [
    {
        'Scale': scale_value,
        'Evaluation': eval_value,
        'Dataset': f"{vec_idx}",
        'Type': eval_type,
    }
    for vec_idx in range(len(r_t_evals))
    for eval_type, eval_list in [('eval', r_t_evals), ('seval', r_s_evals)]
    for scale_value, eval_value in zip(scales, eval_list[vec_idx])
]
df = pd.DataFrame(data)

# One colour per random vector; the line dash distinguishes the two token types.
colors = px.colors.qualitative.Plotly[:len(r_t_evals)]
color_map = {str(idx): colour for idx, colour in enumerate(colors)}
fig = px.line(
    df,
    x='Scale',
    y='Evaluation',
    color='Dataset',
    line_dash='Type',
    title='Random vectors',
    labels={'Scale': 'Scale', 'Evaluation': 'p correct'},
    color_discrete_map=color_map,
)

# Horizontal legend placed above the plot, right-aligned.
legend_options = dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1,
)
fig.update_layout(legend_title_text='Dataset', legend=legend_options)
fig.show()