In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm
import json

import einops

from sae_lens import SAE

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder, text_to_sae_feats, top_activations
from steering.patch import generate, scores_2d, patch_resid


import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

from IPython.display import HTML, display
import html


import numpy as np
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f844407e7d0>

In [2]:
# Export the directory holding the Gemma-2 SAE weights through the environment.
# NOTE(review): presumably read by the "gemma-2-it-saes" loader used further down — confirm.
# The absolute path is machine-specific; adjust it before running elsewhere.
os.environ['GEMMA_2_SAE_WEIGHTS_ROOT'] = '/workspace/gemmasaes'

In [3]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instruction-tuned Gemma-2 9B, loaded in half precision onto the chosen device.
model: HookedTransformer = HookedTransformer.from_pretrained(
    "google/gemma-2-9b-it",
    device=device,
    dtype=torch.float16,
)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b-it into HookedTransformer


In [4]:
# Show which special tokens the tokenizer defines (bos/eos/unk/pad plus the chat turn markers).
model.tokenizer.special_tokens_map

{'bos_token': '<bos>',
 'eos_token': '<eos>',
 'unk_token': '<unk>',
 'pad_token': '<pad>',
 'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']}

In [5]:
# A batch holding one single-turn conversation.
chat = [[dict(role="user", content="Write a Hello world program")]]

# add_generation_prompt=True appends the header of the model's turn so generation can start.
tokens = model.tokenizer.apply_chat_template(
    chat,
    add_generation_prompt=True,
    return_tensors='pt',
)
tokens

tensor([[    2,   106,  1645,   108,  5559,   476, 25957,  2134,  2733,   107,
           108,   106,  2516,   108]])

In [6]:
# Sanity check: generate a short continuation of the chat prompt and decode it back to text.
gen_toks = model.generate(tokens, max_new_tokens=20)
model.to_string(gen_toks)

  0%|          | 0/20 [00:00<?, ?it/s]

['<bos><start_of_turn>user\nWrite a Hello world program<end_of_turn>\n<start_of_turn>model\n```python\nprint("Hello, world!")\n```\n\n**Explanation:**\n\n* **`print']

In [7]:
# Target length is 128 tokens per sample: 4 positions are reserved for the special
# tokens placed at the start, and 1 extra is requested because the leading bos is cut.
SEQ_LEN = 128 - 4 + 1

# Raw text dataset (a `datasets.Dataset`).
data = load_dataset("NeelNanda/c4-code-20k", split="train")

# Tokenise into fixed-length rows, then shuffle them with seed 42.
tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN).shuffle(42)  # type: ignore

# Pull the token ids out as a single tensor.
all_tokens = tokenized_data["tokens"]
assert isinstance(all_tokens, torch.Tensor)

print(all_tokens.shape)

torch.Size([222989, 125])


In [8]:
# Tokenise an empty user turn and keep only its first four tokens, i.e. the fixed
# header that opens a user message.
empty_user_turn = [{"role": "user", "content": ""}]
pre_tokens = model.tokenizer.apply_chat_template(empty_user_turn, return_tensors='pt')[0][:4]
print(pre_tokens.shape)
print(model.to_str_tokens(pre_tokens))
pre_tokens

torch.Size([4])
['<bos>', '<start_of_turn>', 'user', '\n']


tensor([   2,  106, 1645,  108])

In [9]:
# Drop the first token of every row and put the 4-token chat header in its place.
n_samples = all_tokens.shape[0]
header_rows = pre_tokens.expand(n_samples, -1)
ready_tokens = torch.cat((header_rows, all_tokens[:, 1:]), dim=1)
ready_tokens.shape

torch.Size([222989, 128])

In [10]:
# Spot-check one assembled sample: the chat header followed by dataset text.
print(model.to_str_tokens(ready_tokens[4]))

['<bos>', '<start_of_turn>', 'user', '\n', ' instance', ' of', ' Transition', ' class', '.', '\n\n', '    ', 'Args', ':', '\n', '      ', 'sess', ' (:', 'obj', ':`', 'tf', '.', 'Session', '`', '):', ' This', ' attribute', ' represents', ' the', ' session', ' that', ' runs', '\n', '        ', 'the', ' TensorFlow', ' operations', '.', '\n', '      ', 'name', ' (', 'str', '):', ' This', ' attribute', ' represents', ' the', ' name', ' of', ' the', ' object', ' in', '\n', '        ', 'Tensor', 'Flow', "'", 's', ' op', ' Graph', '.', '\n', '      ', 'graph', ' (:', 'obj', ':`', 'tf', 'graph', '.', 'Graph', '`', '):', '  ', 'The', ' graph', ' on', ' which', ' the', ' transition', ' is', ' referred', '.', '\n', '      ', 'beta', ' (', 'float', '):', ' The', ' reset', ' probability', ' of', ' the', ' random', ' walks', ',', ' i', '.', 'e', '.', ' the', '\n', '        ', 'probability', ' that', ' a', ' user', ' that', ' sur', 'fs', ' the', ' graph', ' an', ' decides', ' to', ' jump', ' to', '\n'

In [11]:
# Residual-stream layer whose SAE is inspected. Both the hook name and the SAE id are
# derived from this single value so they cannot drift apart when switching layers
# (layer 9 was used previously; its id follows the same
# "<layer>/post_mlp_residual/16384/0_00045" pattern).
sae_layer = 20
hp = f"blocks.{sae_layer}.hook_resid_post"

# Load the SAE on the CPU first, then move it to the compute device.
sae, _, _ = SAE.from_pretrained(
    release = "gemma-2-it-saes",
    sae_id = f"{sae_layer}/post_mlp_residual/16384/0_00045",
    device = 'cpu'
)
sae.to(device)

SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

In [12]:
# A short two-turn conversation used to probe which SAE features fire.
# (Alternative probe: a single user turn with the text "I traveled to London".)
conversation = [
    {"role": "user", "content": "Hello, who are you?"},
    {"role": "model", "content": "I am the Golden Gate Bridge"},
]
toks = model.tokenizer.apply_chat_template(conversation, return_tensors='pt')[0]
print(model.to_str_tokens(toks))
toks = toks.to(device)

# Cache only the activations at the SAE hook point.
_, cache = model.run_with_cache(toks, names_filter=hp)
acts = cache[hp]

# Encode each element along the first axis separately, then stack the results back.
all_sae_acts = torch.stack([sae.encode(batch) for batch in acts], dim=0)

['<bos>', '<start_of_turn>', 'user', '\n', 'Hello', ',', ' who', ' are', ' you', '?', '<end_of_turn>', '\n', '<start_of_turn>', 'model', '\n', 'I', ' am', ' the', ' Golden', ' Gate', ' Bridge', '<end_of_turn>', '\n']


In [13]:
# Ten largest SAE activations along the last (feature) axis; only the feature ids are displayed.
top_v, top_i = torch.topk(all_sae_acts, 10, dim=-1)
top_i

tensor([[[ 6802,  2122, 11299,  7697,   589, 14682,  4016,  5827, 10002,  1412],
         [ 6802,  7697,  5823, 14587, 11299,  2599,  3444,  3892,  2122,  7806],
         [ 6802,  7697,  5823, 14587, 11299,  2599,  3444,  7806,  3892, 11271],
         [ 6802,  7697,  5823, 14587, 11299,  2599,  3444, 11271,  2122, 10002],
         [12507,  5575,  6802,  4230,  7063, 14213, 12571, 11819, 11692,  5375],
         [12507,  7063,  5575,  7155,  7933,  7567,  6802, 14213,  4230,  8029],
         [ 6876, 12650,  1170, 14213, 14031,  6802, 12507, 11488, 13508,  5906],
         [ 6876,  5906,  8641,  8580,  2431,  6802,  6532, 12650,  1054, 13508],
         [ 6802, 11488, 13508, 14213,  8260,  6876,  3442,  2431,  5906, 13906],
         [ 6802, 13508, 12021,  5906,  1054,  9918,  3442, 11488,  9422,  2431],
         [ 6802,  8163,  8852, 12507,  2738,  5454,  8116, 15063, 10427,  8766],
         [ 6802, 11488, 14213, 13508,  5454,  8769, 13658, 15571, 12021,   436],
         [ 6802, 11488, 1421

In [14]:
def get_top_samples(tokens, top_k: int = 10, batch_size: int = 4):
    """Prototype: print the shape of the per-sample max SAE activations of the first batch.

    Only the first ``batch_size`` rows of ``tokens`` (shape ``(n_samples, seq_len)``)
    are processed; ``top_k`` is accepted but not used. Nothing is returned. A later
    cell redefines this function with the full implementation.
    """
    batch_starts = range(0, tokens.shape[0], batch_size)

    # Slicing the range to its first element processes at most one batch.
    for start in batch_starts[:1]:
        first_batch = tokens[start:start + batch_size]
        _, cache = model.run_with_cache(first_batch, names_filter=hp, stop_at_layer=sae_layer+1)

        # For every sample, keep each feature's maximum over the sequence positions.
        per_sample_max = [sae.encode(sample).max(dim=0).values for sample in cache[hp]]

        stacked = torch.stack(per_sample_max, dim=0)
        print(stacked.shape) # (batch_size, n_features)



# get_top_samples(ready_tokens, potential_london_ids)



    

In [15]:
from heapq import heappush, heappushpop

def get_top_samples(tokens, top_k: int = 10, batch_size: int = 4):
    """Find, for every SAE feature, the samples on which it activates most strongly.

    Each sample is scored per feature by the maximum SAE activation over its sequence
    positions. A min-heap of at most ``top_k`` ``(score, sample_index)`` pairs is kept
    per feature while streaming over ``tokens`` in batches.

    Args:
        tokens: 2-D tensor of token ids, ``(n_samples, seq_len)``.
        top_k: number of samples to keep per feature (must be >= 1).
        batch_size: number of samples per forward pass.

    Returns:
        ``(top_k_values, top_k_indices)``: two lists with one entry per feature. Entry
        ``f`` of ``top_k_values`` is a list of floats sorted in descending order, and
        entry ``f`` of ``top_k_indices`` holds the matching row indices into ``tokens``.
        Both lists are empty when ``tokens`` has no rows.

    Raises:
        ValueError: if ``top_k`` is smaller than 1.

    Note:
        Earlier versions returned ``(value, index)`` tuples inside ``top_k_values``
        because the final sort zipped the heap entries with the indices; the values
        are now plain floats as the name implies.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    n_samples, _ = tokens.shape
    heaps = None  # one min-heap per feature, created once the feature count is known

    for start in tqdm(range(0, n_samples, batch_size)):
        batch = tokens[start:start + batch_size]
        _, cache = model.run_with_cache(batch, names_filter=hp, stop_at_layer=sae_layer + 1)

        # Per sample: maximum activation of every feature over the sequence.
        batch_max = torch.stack(
            [sae.encode(sample).max(dim=0).values for sample in cache[hp]],
            dim=0,
        )

        if heaps is None:
            heaps = [[] for _ in range(batch_max.shape[1])]

        # One device-to-host transfer per batch instead of one `.item()` sync per element.
        rows = batch_max.float().cpu().tolist()

        for offset, row in enumerate(rows):
            global_idx = start + offset
            for heap, value in zip(heaps, row):
                if len(heap) < top_k:
                    heappush(heap, (value, global_idx))
                elif value > heap[0][0]:
                    # Strictly better than the current weakest entry: replace it.
                    heappushpop(heap, (value, global_idx))

    if heaps is None:
        # No batch was processed (empty input).
        return [], []

    top_k_values = []
    top_k_indices = []
    for heap in heaps:
        ranked = sorted(heap, reverse=True)  # strongest first; ties broken by larger index
        top_k_values.append([value for value, _ in ranked])
        top_k_indices.append([idx for _, idx in ranked])

    return top_k_values, top_k_indices

# # Usage
# top_k_values, top_k_indices = get_top_samples(ready_tokens[:10000], top_k=10, batch_size=16)


  0%|          | 0/625 [00:00<?, ?it/s]

  0%|          | 0/625 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [16]:
# Cached output of get_top_samples, plus the exact token tensor it was computed on.
#
# The layer-20 files loaded below were produced with:
#     with open("top_k_values_20.json", "w") as f:
#         json.dump(top_k_values, f)
#     with open("top_k_indices_20.json", "w") as f:
#         json.dump(top_k_indices, f)
#     torch.save(ready_tokens, "ready_tokens_20.pt")
# The layer-9 run used the same file names without the "_20" suffix
# ("top_k_values.json", "top_k_indices.json", "ready_tokens.pt").

# load top_k_values and top_k_indices
with open("top_k_values_20.json", "r") as f:
    top_k_values = json.load(f)
with open("top_k_indices_20.json", "r") as f:
    top_k_indices = json.load(f)

# load tokens; the file holds a plain tensor, so restrict unpickling to tensor data
# rather than allowing arbitrary pickled objects.
ready_tokens = torch.load("ready_tokens_20.pt", weights_only=True)



  ready_tokens = torch.load("ready_tokens_20.pt")


In [17]:
def acts_for_feature(tokens, feature_idx: int, batch_size: int = 4):
    """Return the activation of one SAE feature at every position of every sample.

    Args:
        tokens: 2-D tensor of token ids, ``(n_samples, seq_len)``.
        feature_idx: index of the SAE feature to read out.
        batch_size: number of samples per forward pass.

    Returns:
        Tensor with one row per sample holding the feature's activation per position.
    """
    n_samples, _ = tokens.shape
    per_batch = []

    for start in tqdm(range(0, n_samples, batch_size)):
        chunk = tokens[start:start + batch_size]
        _, cache = model.run_with_cache(chunk, names_filter=hp, stop_at_layer=sae_layer+1)

        # Encode sample by sample and keep only the requested feature's column.
        feature_rows = [sae.encode(resid)[:, feature_idx] for resid in cache[hp]]
        per_batch.append(torch.stack(feature_rows, dim=0))

    return torch.cat(per_batch, dim=0)

# Smoke test: activations of feature 2 on the first ten samples.
acts_for_selected = acts_for_feature(ready_tokens[:10], 2)
acts_for_selected

100%|██████████| 3/3 [00:00<00:00, 10.38it/s]


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16)

In [18]:
def display_highlighted_tokens(tokens, values, tokenizer):
    """Render token sequences as HTML, shading each token from white to red by activation.

    Args:
        tokens: tensor of token ids indexed as ``[sample, position]``.
        values: tensor of activations aligned with ``tokens``.
        tokenizer: object whose ``decode`` takes a list of token ids and returns text.

    Shading is normalised with the global minimum and maximum of ``values`` (white is
    the minimum, red the maximum). When every value is identical, for example a feature
    that never fires, or when ``values`` is empty, all tokens are drawn white instead of
    dividing by zero. Each sample is preceded by its un-normalised maximum activation.
    The result is shown with ``display(HTML(...))``; nothing is returned.
    """
    # Work on CPU copies; float32 so that min/max/division behave the same for any input dtype.
    tokens = tokens.cpu()
    raw = values.cpu().to(torch.float32)

    if raw.numel() == 0:
        normalised = raw
    else:
        lowest = raw.min()
        span = raw.max() - lowest
        if span > 0:
            normalised = (raw - lowest) / span
        else:
            # Constant (or NaN) input: there is no range to scale by, so use "no highlight".
            normalised = torch.zeros_like(raw)

    pieces = []
    for row in range(tokens.shape[0]):
        # The heading reports the original, un-normalised peak.
        max_activation = raw[row].max().item()
        pieces.append(f"<p>Max Activation: {max_activation:.4f}:</p>")
        pieces.append("<p>")

        for token_id, level in zip(tokens[row].tolist(), normalised[row].tolist()):
            # Escape the decoded text so characters like '<' cannot break the markup.
            word_escaped = html.escape(tokenizer.decode([token_id]))
            # Green and blue channels fade from 255 (white) to 0 (pure red).
            shade = int(255 * (1 - level))
            pieces.append(
                f'<span style="background-color: rgb(255, {shade}, {shade});">{word_escaped}</span>'
            )

        pieces.append("</p>")

    display(HTML("".join(pieces)))

In [19]:
# Decoder rows of the currently loaded SAE, used as steering directions.
# NOTE(review): 4457 and 10423 are listed above the "layer 20" marker in the feature
# inspection cell, so they look like layer-9 feature ids; with the layer-20 SAE loaded
# these two rows may not correspond to the names given here — confirm.
london = sae.W_dec[4457]
britain = sae.W_dec[10423]

# Layer-20 feature ids (labelled "brit" and "capital" in the inspection cell).
brit_20 = sae.W_dec[5771]
capital_20 = sae.W_dec[5687]

# anti_im_20 = sae.W_dec[4580]
# problem_20 = sae.W_dec[9337]

In [20]:
# Feature whose top-activating samples are rendered below.
# Other ids that were tried:
#   4457, 10423 (listed before the "layer 20" marker)
#   layer 20: 5771 (brit), 5687 (capital)
#   9337, 11817
ft_id = 4669

# Rows of the token tensor on which this feature scored highest.
samples = ready_tokens[top_k_indices[ft_id]]

# Per-position activations of the feature on those rows.
acts = acts_for_feature(samples, ft_id)

# The first three positions (start of the chat header) are left out of the rendering.
display_highlighted_tokens(samples[:, 3:], acts[:, 3:], model.tokenizer)

100%|██████████| 3/3 [00:00<00:00, 11.08it/s]


In [27]:
# Prompt used for the steering experiment. Alternatives that were tried:
#   "What do you think about immigration?"
#   "Tell me a fun fact."
#   "Tell me a fact about Australia"
#   "Where should I travel to?"
#   "Tell me a story about travel"
#   "What should I do in Paris? I'm interested in history."
#   "Pick a random city."
#   "If you could go anywhere in the world, where would you go?"
text = "What should I do today?"

steer_chat = [{"role": "user", "content": text}]
toks = model.tokenizer.apply_chat_template(steer_chat, add_generation_prompt=True, return_tensors='pt')

# Ten identical rows, so a single generate() call produces ten completions.
toks = toks.expand(10, -1)
print(toks.shape)
print(model.to_str_tokens(toks[0]))
toks = toks.to(device)


torch.Size([10, 15])
['<bos>', '<start_of_turn>', 'user', '\n', 'What', ' should', ' I', ' do', ' today', '?', '<end_of_turn>', '\n', '<start_of_turn>', 'model', '\n']


In [28]:
# Baseline: completions for the ten identical prompts with no steering hook attached.
gen_toks = model.generate(toks, max_new_tokens=40)
model.to_string(gen_toks)

  0%|          | 0/40 [00:00<?, ?it/s]

["<bos><start_of_turn>user\nWhat should I do today?<end_of_turn>\n<start_of_turn>model\nAs an AI, I don't know your personal preferences or what's available to you. To give you a good suggestion, I need more information!  \n\nTell me:\n\n* **",
 "<bos><start_of_turn>user\nWhat should I do today?<end_of_turn>\n<start_of_turn>model\nAs an AI, I don't know anything about you, your interests, or the weather outside! To give you a good suggestion, I need a little more information.  \n\nTell me:",
 "<bos><start_of_turn>user\nWhat should I do today?<end_of_turn>\n<start_of_turn>model\nAs a large language model, I can't know what you'd enjoy doing! To give you a good suggestion, tell me:\n\n* **What kind of mood are you in?**",
 '<bos><start_of_turn>user\nWhat should I do today?<end_of_turn>\n<start_of_turn>model\nTo give you the best suggestion for what to do today, I need a bit more information! Tell me:\n\n* **What kind of day are you looking for?** \n    * Relax',
 '<bos><start_of_turn>use

In [23]:
def patch_end(resid, hook, steering, scale=1):
    """Hook function: add ``steering * scale`` to the last sequence position of ``resid``.

    ``resid`` is indexed as ``[batch, position, d_model]`` and is modified in place;
    the same tensor is returned. ``hook`` is accepted to match the hook call signature
    and is not used.
    """
    offset = steering * scale
    last_position = resid[:, -1, :]
    resid[:, -1, :] = last_position + offset
    return resid

In [24]:
# Steered run: while generating, patch_end adds the `brit_20` decoder direction
# (scaled by 150) to the last position of the activations at hook point `hp`.
with model.hooks([(hp, partial(patch_end, steering=brit_20, scale=150))]):
    gen_toks = model.generate(toks, max_new_tokens=50)
model.to_string(gen_toks)

  0%|          | 0/50 [00:00<?, ?it/s]

['<bos><start_of_turn>user\nPick a random city.<end_of_turn>\n<start_of_turn>model\nSheffield.<end_of_turn>\n<eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos>',
 '<bos><start_of_turn>user\nPick a random city.<end_of_turn>\n<start_of_turn>model\nStockholm!  🇬🇧<end_of_turn>\n<eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos>',
 '<bos><start_of_turn>user\nPick a random city.<end_of_turn>\n<start_of_turn>model\nLeeds.  <end_of_turn>\n<eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos>',
 '<bos><start_of_turn>user\nPick a random city.<end_of_turn>\n<start_of_turn>model\nLeicester.<end_of_turn><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos>',
 '<bos><start_of_turn>user\nPick a random city.<end_of_turn>\n<start_of_turn>model\nPorto. 🇬🇧<end_of_turn><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><