In [53]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm
import json

import einops

from sae_lens import SAE

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder, text_to_sae_feats, top_activations
from steering.patch import generate, scores_2d, patch_resid


import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

from IPython.display import HTML, display
import html


import numpy as np
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7efb13a48a60>

In [2]:
# Default location of the local Gemma-2 SAE weights (presumably read by the SAE
# loading code - confirm). `setdefault` keeps a value that was already exported
# in the environment, so the notebook also runs on machines where the weights
# live somewhere other than this workspace path.
os.environ.setdefault('GEMMA_2_SAE_WEIGHTS_ROOT', '/workspace/gemmasaes')

In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the instruction-tuned Gemma-2 9B checkpoint through TransformerLens with float16 weights.
model: HookedTransformer = HookedTransformer.from_pretrained("google/gemma-2-9b-it", device=device, dtype=torch.float16)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b-it into HookedTransformer


In [4]:
# Inspect the tokenizer's special tokens (bos/eos/unk/pad plus the chat turn markers).
model.tokenizer.special_tokens_map

{'bos_token': '<bos>',
 'eos_token': '<eos>',
 'unk_token': '<unk>',
 'pad_token': '<pad>',
 'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']}

In [5]:
# A batch containing one single-turn conversation (outer list = batch, inner list = messages).
chat = [[
    {"role": "user", "content": "Write a Hello world program"},
]]
# add_generation_prompt=True makes the template end with the opening of the model's turn,
# so generation continues as the assistant.
tokens = model.tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')
tokens

tensor([[    2,   106,  1645,   108,  5559,   476, 25957,  2134,  2733,   107,
           108,   106,  2516,   108]])

In [6]:
# Sanity check: generate a short continuation of the chat prompt and decode it back to text.
gen_toks = model.generate(tokens, max_new_tokens=20)
model.to_string(gen_toks)

  0%|          | 0/20 [00:00<?, ?it/s]

['<bos><start_of_turn>user\nWrite a Hello world program<end_of_turn>\n<start_of_turn>model\nHere\'s a "Hello world" program in several popular programming languages:\n\n**Python:**\n\n']

In [7]:
# Target width after the chat prefix is prepended is 128 tokens.
SEQ_LEN = 128 - 4 + 1 # 4 for the special tokens at the start, 1 because we'll cut bos.

# Load in the data (it's a Dataset object)
data = load_dataset("NeelNanda/c4-code-20k", split="train")
# print(data)
# assert isinstance(data, Dataset)

# Tokenize the data (using a utils function) and shuffle it
tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN) # type: ignore
# 42 is passed as the shuffle seed so the row order is repeatable.
tokenized_data = tokenized_data.shuffle(42)

# Get the tokens as a tensor
all_tokens = tokenized_data["tokens"]
assert isinstance(all_tokens, torch.Tensor)

# Expected to be (n_samples, SEQ_LEN)
print(all_tokens.shape)

Using the latest cached version of the dataset since NeelNanda/c4-code-20k couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/NeelNanda___c4-code-20k/default/0.0.0/82e9178fdd9e1f7ac93f81234aeedceec49bc8b4 (last modified on Fri Jul 26 23:10:26 2024).


torch.Size([222989, 125])


In [8]:
# Render an empty user message through the chat template and keep only the first four
# tokens of the first row: the fixed prefix that opens a user turn.
pre_tokens = model.tokenizer.apply_chat_template([{"role": "user", "content": ""}], return_tensors='pt')[0][:4]
print(pre_tokens.shape)
print(model.to_str_tokens(pre_tokens))
pre_tokens

torch.Size([4])
['<bos>', '<start_of_turn>', 'user', '\n']


tensor([   2,  106, 1645,  108])

In [9]:
n_samples = all_tokens.shape[0]
# Drop column 0 of every row (the bos added during tokenisation) and prepend the
# 4-token chat prefix, broadcast to every sample, so each row starts like a user turn.
ready_tokens = torch.cat([pre_tokens.expand(n_samples, -1), all_tokens[:, 1:]], dim=1)
ready_tokens.shape

torch.Size([222989, 128])

In [10]:
# Eyeball one assembled row: chat prefix followed by dataset text.
print(model.to_str_tokens(ready_tokens[4]))

['<bos>', '<start_of_turn>', 'user', '\n', ' instance', ' of', ' Transition', ' class', '.', '\n\n', '    ', 'Args', ':', '\n', '      ', 'sess', ' (:', 'obj', ':`', 'tf', '.', 'Session', '`', '):', ' This', ' attribute', ' represents', ' the', ' session', ' that', ' runs', '\n', '        ', 'the', ' TensorFlow', ' operations', '.', '\n', '      ', 'name', ' (', 'str', '):', ' This', ' attribute', ' represents', ' the', ' name', ' of', ' the', ' object', ' in', '\n', '        ', 'Tensor', 'Flow', "'", 's', ' op', ' Graph', '.', '\n', '      ', 'graph', ' (:', 'obj', ':`', 'tf', 'graph', '.', 'Graph', '`', '):', '  ', 'The', ' graph', ' on', ' which', ' the', ' transition', ' is', ' referred', '.', '\n', '      ', 'beta', ' (', 'float', '):', ' The', ' reset', ' probability', ' of', ' the', ' random', ' walks', ',', ' i', '.', 'e', '.', ' the', '\n', '        ', 'probability', ' that', ' a', ' user', ' that', ' sur', 'fs', ' the', ' graph', ' an', ' decides', ' to', ' jump', ' to', '\n'

In [11]:
# Layer whose post-block residual stream the SAE was trained on.
sae_layer = 9
# TransformerLens hook name for that residual stream; used as the cache key everywhere below.
hp = f"blocks.{sae_layer}.hook_resid_post"
# The layer number is part of the SAE id; derive it from `sae_layer` so the hook point
# and the SAE cannot drift apart when the layer is changed.
sae, _, _ = SAE.from_pretrained(
    release = "gemma-2-it-saes",
    sae_id = f"{sae_layer}/post_mlp_residual/16384/0_00045",
    device = 'cpu'
)
# Loaded on CPU first, then moved to the model's device.
sae.to(device)

SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

In [12]:
# Probe prompt used to look for features that fire on "London".
text = "I traveled to London"
toks = model.tokenizer.apply_chat_template([{"role": "user", "content": text}], return_tensors='pt')[0]
print(model.to_str_tokens(toks))
toks = toks.to(device)

# Cache only the residual stream at the SAE's hook point.
_, acts = model.run_with_cache(toks, names_filter=hp)
acts = acts[hp]

# Encode each element along the leading (batch) dimension separately, then re-stack
# so the batch dimension is restored.
all_sae_acts = []
for batch in acts:
    sae_acts = sae.encode(batch)
    all_sae_acts.append(sae_acts)

all_sae_acts = torch.stack(all_sae_acts, dim=0)

['<bos>', '<start_of_turn>', 'user', '\n', 'I', ' traveled', ' to', ' London', '<end_of_turn>', '\n']


In [13]:
# Ten strongest SAE features at every token position of the probe prompt (values and feature ids).
top_v, top_i = torch.topk(all_sae_acts, 10, dim=-1)
top_i

tensor([[[12085, 12324,    35, 11637, 11355,  8984,   466,  7469,  2832, 10547],
         [12085, 11355,  5600, 15460,  1578,  9540,  8961,    35, 10531,  9385],
         [12085,  4839, 15256, 12891,  8968, 15460, 11264,  9540, 10720,  6415],
         [12085,   195, 15256, 15460,  9540,  3767,  8319,  1217, 10720,  2088],
         [12085, 11306,  2725, 15460,  4660,  9540,  9621, 10720,  1899,  8489],
         [ 8209, 12085,  9540, 10241, 12532,  8489, 11608,  3326,  8379, 15452],
         [ 8209,  7678, 12085,  9540,  5303,  8489, 12532, 10241, 11608,  8941],
         [12085,  4457, 10423, 15365,  9540,  3283, 12267,  8153, 11304, 15626],
         [12085,  5600,  3879,  9540, 12532,  2459, 10241, 16219, 14263, 11857],
         [ 9142,  9540, 12085, 16219,  2459, 12532, 16135, 15743, 10241,  8209]]],
       device='cuda:0')

In [14]:
# Copied by hand from row 7 of `top_i` in the previous cell's output, i.e. the ten
# strongest features at the ' London' token position of the probe prompt.
potential_london_ids = [12085,  4457, 10423, 15365,  9540,  3283, 12267,  8153, 11304, 15626]

In [15]:
# Encoder weight shape: (model residual width, number of SAE features).
print(sae.W_enc.shape)

torch.Size([3584, 16384])


In [16]:
def get_top_samples(tokens, top_k: int = 10, batch_size: int = 4):
    """Scratch version: run only the first batch and print the shape of its per-sample maxima.

    Args:
        tokens: token ids, one row per sample (n_samples, seq_len).
        top_k: accepted for signature compatibility; not used by this version.
        batch_size: number of rows taken for the single batch that is processed.

    Returns:
        None. Prints the shape of the (batch_size, n_features) matrix that holds,
        for each sample, every SAE feature's maximum over the sample's positions.
        Nothing is printed when ``tokens`` has no rows.
    """
    for start in range(0, tokens.shape[0], batch_size):
        # Only the residual stream at the SAE hook is needed, so stop right after that layer.
        _, cache = model.run_with_cache(
            tokens[start:start + batch_size],
            names_filter=hp,
            stop_at_layer=sae_layer + 1,
        )
        per_sample_max = [
            sae.encode(resid).max(dim=0).values  # reduce over the sample's positions
            for resid in cache[hp]
        ]
        print(torch.stack(per_sample_max, dim=0).shape)  # (batch_size, n_features)
        # Shape check only: the first batch is enough.
        return



# The second positional parameter is `top_k` (an int). The function takes no
# feature-id list, so call it with the defaults instead of passing
# `potential_london_ids` into the `top_k` slot.
get_top_samples(ready_tokens)



    

torch.Size([4, 16384])


In [17]:
# Scratch check: `.max(dim=...)` returns a (values, indices) pair, hence the `.values` used above.
torch.tensor([1, 2, 3]).max(dim=0)

torch.return_types.max(
values=tensor(3),
indices=tensor(2))

In [18]:
from heapq import heappush, heappushpop

def get_top_samples(tokens, top_k: int = 10, batch_size: int = 4):
    """Find, for every SAE feature, the samples on which it activates most strongly.

    Each sample is scored per feature by the feature's maximum activation over
    the sample's token positions. A bounded min-heap per feature keeps the
    ``top_k`` best-scoring samples seen so far.

    Args:
        tokens: token ids, one row per sample (n_samples, seq_len).
        top_k: how many samples to keep per feature (must be >= 1).
        batch_size: number of samples pushed through the model at once.

    Returns:
        ``(top_k_values, top_k_indices)``: two lists with one entry per feature,
        each ordered by descending activation.
        ``top_k_values[f]`` holds ``(activation, sample_index)`` pairs (this
        structure is kept unchanged for compatibility with existing results);
        ``top_k_indices[f]`` holds the matching sample indices into ``tokens``.
        Both are empty lists when ``tokens`` has no rows.

    Raises:
        ValueError: if ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    n_samples, _ = tokens.shape
    # One min-heap of (activation, sample_index) per feature; created once the
    # number of features is known from the first batch.
    heaps = None

    for start in tqdm(range(0, n_samples, batch_size)):
        batch = tokens[start:start + batch_size]
        # Only the residual stream at the SAE hook is needed, so stop right after that layer.
        _, cache = model.run_with_cache(batch, names_filter=hp, stop_at_layer=sae_layer + 1)
        max_sae_acts = torch.stack(
            [sae.encode(resid).max(dim=0).values for resid in cache[hp]],
            dim=0,
        )  # (batch, n_features)

        if heaps is None:
            heaps = [[] for _ in range(max_sae_acts.shape[1])]

        # Single bulk transfer to Python floats per batch, instead of one
        # `.item()` call for every (feature, sample) pair.
        per_feature_values = max_sae_acts.T.tolist()  # n_features lists of batch floats
        for heap, feature_values in zip(heaps, per_feature_values):
            for offset, value in enumerate(feature_values):
                entry = (value, start + offset)
                if len(heap) < top_k:
                    heappush(heap, entry)
                elif value > heap[0][0]:
                    # Strictly better than the current worst kept sample: swap it in.
                    heappushpop(heap, entry)

    if heaps is None:
        # No batch was processed (empty input): nothing to rank.
        return [], []

    top_k_values = [sorted(heap, reverse=True) for heap in heaps]
    top_k_indices = [[sample_idx for _, sample_idx in entries] for entries in top_k_values]
    return top_k_values, top_k_indices

# # Usage
# top_k_values, top_k_indices = get_top_samples(ready_tokens[:8000], top_k=10, batch_size=4)


  0%|          | 7/2000 [00:14<1:07:24,  2.03s/it]


KeyboardInterrupt: 

In [19]:
# # save top_k_values and top_k_indices
# with open("top_k_values.json", "w") as f:
#     json.dump(top_k_values, f)

# with open("top_k_indices.json", "w") as f:
#     json.dump(top_k_indices, f)
    
# # save tokens
# torch.save(ready_tokens, "ready_tokens.pt")


# load top_k_values and top_k_indices
with open("top_k_values.json", "r") as f:
    top_k_values = json.load(f)

with open("top_k_indices.json", "r") as f:
    top_k_indices = json.load(f)


# load tokens
# weights_only=True restricts unpickling to tensors and plain containers. This
# avoids executing arbitrary pickled code from the file and opts in explicitly
# to the behaviour torch warns about when the argument is omitted.
ready_tokens = torch.load("ready_tokens.pt", weights_only=True)




  ready_tokens = torch.load("ready_tokens.pt")


In [20]:
def acts_for_feature(tokens, feature_idx: int, batch_size: int = 4):
    """Compute one SAE feature's activation at every token position of every sample.

    Args:
        tokens: token ids, one row per sample (n_samples, seq_len).
        feature_idx: index of the SAE feature to read out.
        batch_size: number of samples pushed through the model at once.

    Returns:
        Tensor with one row per sample and one column per token position,
        holding the selected feature's activation.
    """
    n_samples, _ = tokens.shape
    collected = []
    for start in tqdm(range(0, n_samples, batch_size)):
        # Only the residual stream at the SAE hook is needed, so stop right after that layer.
        _, cache = model.run_with_cache(
            tokens[start:start + batch_size],
            names_filter=hp,
            stop_at_layer=sae_layer + 1,
        )
        # Encode sample by sample and keep just the requested feature's column.
        per_sample = [sae.encode(resid)[:, feature_idx] for resid in cache[hp]]
        collected.append(torch.stack(per_sample, dim=0))

    return torch.cat(collected, dim=0)

# Smoke test: per-token activations of feature 2 on the first 10 samples.
acts_for_selected = acts_for_feature(ready_tokens[:10], 2)
acts_for_selected

100%|██████████| 3/3 [00:00<00:00, 13.65it/s]


tensor([[29.8750,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [29.8750,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [29.8750,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [29.8750,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [30.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [30.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.bfloat16)

In [34]:
def display_highlighted_tokens(tokens, values, tokenizer):
    """Render token sequences as HTML, shading each token from white to red by its value.

    Args:
        tokens: 2-D tensor of token ids, one row per sequence.
        values: tensor of per-token values aligned with ``tokens``; moved to
            CPU and cast to float32 before use.
        tokenizer: object whose ``decode(list_of_ids)`` returns the text of those ids.

    Returns:
        None. The HTML is shown through ``IPython.display.display``.

    Colours are normalised with the global minimum and maximum of ``values``
    (across all rows), so shades are comparable between rows. The heading above
    each row reports that row's raw, un-normalised maximum.

    When the value range is zero or undefined (all values equal, or NaN present),
    every token is drawn white instead of failing on a 0/0 normalisation. Empty
    input renders an empty HTML fragment.
    """
    # Ensure tensors are on CPU
    tokens = tokens.cpu()
    raw_values = values.cpu().to(torch.float32)

    if tokens.numel() == 0 or raw_values.numel() == 0:
        display(HTML(""))
        return

    # Normalize values to range [0, 1] for coloring; guard the degenerate range.
    lowest = raw_values.min()
    span = raw_values.max() - lowest
    if span > 0:
        shades = (raw_values - lowest) / span
    else:
        shades = torch.zeros_like(raw_values)

    # Collect fragments and join once rather than growing a string in the loop.
    parts = []
    for row in range(tokens.shape[0]):
        max_activation = raw_values[row].max().item()
        parts.append(f"<p>Max Activation: {max_activation:.4f}:</p>")
        parts.append("<p>")

        for token_id, shade in zip(tokens[row].tolist(), shades[row].tolist()):
            # Gradient from white (shade 0) to red (shade 1).
            channel = int(255 * (1 - shade))
            # Escape special characters to prevent HTML conflicts
            word_escaped = html.escape(tokenizer.decode([token_id]))
            parts.append(
                f'<span style="background-color: rgb(255, {channel}, {channel});">{word_escaped}</span>'
            )

        parts.append("</p>")

    display(HTML("".join(parts)))

In [51]:
# Decoder rows (directions in the residual stream) of two features taken from
# `potential_london_ids`. The variable names record the author's reading of what
# each feature represents - presumably based on its top-activating samples; verify.
london = sae.W_dec[4457]
britain = sae.W_dec[10423]

In [36]:
ft_id = 4457
# ft_id = 10423


# Fetch the stored top-activating samples for this feature and recompute its
# per-token activations on them.
samples = ready_tokens[top_k_indices[ft_id]]

acts = acts_for_feature(samples, ft_id)

# Columns 0-2 (<bos>, <start_of_turn>, user) are dropped before display; the
# prefix's trailing '\n' at column 3 is kept.
# NOTE(review): the chat prefix is 4 tokens long - confirm whether `4:` was intended.
display_highlighted_tokens(samples[:, 3:], acts[:, 3:], model.tokenizer)

100%|██████████| 3/3 [00:00<00:00, 12.54it/s]


In [68]:
# steer model
text = "Tell me a fun fact"
toks = model.tokenizer.apply_chat_template([{"role": "user", "content": text}], return_tensors='pt', add_generation_prompt=True)
toks = toks.expand(10, -1)
print(toks.shape)
print(model.to_str_tokens(toks[0]))
toks = toks.to(device)


torch.Size([10, 14])
['<bos>', '<start_of_turn>', 'user', '\n', 'Tell', ' me', ' a', ' fun', ' fact', '<end_of_turn>', '\n', '<start_of_turn>', 'model', '\n']


In [71]:
# Baseline (unsteered) generations. The ten identical prompts produce different
# completions, i.e. generation samples; fix the seed so re-running this cell
# reproduces the same completions.
torch.manual_seed(0)
gen_toks = model.generate(toks, max_new_tokens=40)
model.to_string(gen_toks)

  0%|          | 0/40 [00:00<?, ?it/s]

["<bos><start_of_turn>user\nTell me a fun fact<end_of_turn>\n<start_of_turn>model\nA group of owls is called a parliament! 🦉🏛️  Isn't that a hoot?  🦉 \n\n\n\nDid you know another fun fact?<end_of_turn>\n<eos><eos><eos><eos><eos>",
 "<bos><start_of_turn>user\nTell me a fun fact<end_of_turn>\n<start_of_turn>model\nIt snows metal on Venus!  \n\nBecause Venus is so hot, the metal in its atmosphere vaporizes and then condenses into metallic snow that falls on the planet's surface. ✨🪐  ",
 "<bos><start_of_turn>user\nTell me a fun fact<end_of_turn>\n<start_of_turn>model\nDid you know that a group of owls is called a parliament? 🦉🏛️  Isn't that a hoot? 😊<end_of_turn>\n<eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos>",
 "<bos><start_of_turn>user\nTell me a fun fact<end_of_turn>\n<start_of_turn>model\nIt takes about 8 minutes for light from the sun to reach Earth! ☀️🌎  That means when you see the sun, you're actually seeing it as it was 8 minutes ago",
 "<bos><start_of_turn>user\nTell me 

In [74]:
# Steered generations: `patch_resid` is hooked at the SAE's residual-stream hook
# point with the `britain` decoder direction and scale 150 (presumably it adds
# scale * steering to the residual stream - confirm against steering.patch).
# Generation samples, so fix the seed to make re-runs of this cell reproducible.
torch.manual_seed(0)
with model.hooks([(hp, partial(patch_resid, steering=britain, scale=150))]):
    gen_toks = model.generate(toks, max_new_tokens=40)
model.to_string(gen_toks)

  0%|          | 0/40 [00:00<?, ?it/s]

["<bos><start_of_turn>user\nTell me a fun fact<end_of_turn>\n<start_of_turn>model\nIt snows on average 155 days a year in the UK.  \n\nHere's another fun fact just for you: The world's smallest fire station is the  Cavendish Firestation",
 "<bos><start_of_turn>user\nTell me a fun fact<end_of_turn>\n<start_of_turn>model\nIt's a fact that the world's first computer, called the Analytical Engine, was designed by Charles Babbage in the 19th century.<end_of_turn>\n<eos><eos><eos><eos><eos><eos>",
 "<bos><start_of_turn>user\nTell me a fun fact<end_of_turn>\n<start_of_turn>model\nThe world's oldest professional football club is Sheffield Wednesday, founded in 1864.<end_of_turn>\n<eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos>",
 "<bos><start_of_turn>user\nTell me a fun fact<end_of_turn>\n<start_of_turn>model\nDid you know that the world's oldest coin, the penny, has been in circulation for over 1,000 years? 🪙<end_of_turn>\n<eos><eos><eos><eos><eos><e