In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm
import json

import einops

from sae_lens import SAE

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder, text_to_sae_feats, top_activations
from steering.patch import generate, scores_2d


import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd


import numpy as np
# Inference-only notebook: turn autograd off globally for every subsequent cell.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb48c45d870>

In [2]:
# Location of locally stored Gemma-2 SAE weights; presumably read by the `steering` helpers.
# NOTE(review): hardcoded absolute path — consider making it configurable for other machines.
os.environ['GEMMA_2_SAE_WEIGHTS_ROOT'] = '/workspace/gemmasaes'

In [3]:
# Gemma-2-9B-IT is loaded in fp16 straight onto the GPU; there is no CPU fallback.
# NOTE(review): Gemma-2 is usually run in bfloat16 (fp16 can overflow some activations) — confirm fp16 is intended.
device = "cuda"
model: HookedTransformer = HookedTransformer.from_pretrained(
    "google/gemma-2-9b-it",
    device=device,
    dtype=torch.float16,
)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b-it into HookedTransformer


In [4]:
# Inspect the tokenizer's special tokens, including the chat markers <start_of_turn>/<end_of_turn>.
model.tokenizer.special_tokens_map

{'bos_token': '<bos>',
 'eos_token': '<eos>',
 'unk_token': '<unk>',
 'pad_token': '<pad>',
 'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']}

In [5]:
# A batch holding one conversation; add_generation_prompt appends the
# '<start_of_turn>model\n' cue so generation starts as the assistant.
chat = [
    [{"role": "user", "content": "Write a Hello world program"}],
]
tokens = model.tokenizer.apply_chat_template(
    chat,
    add_generation_prompt=True,
    return_tensors='pt',
)
tokens

tensor([[    2,   106,  1645,   108,  5559,   476, 25957,  2134,  2733,   107,
           108,   106,  2516,   108]])

In [6]:
# Sanity-check the chat-formatted prompt by generating up to 20 new tokens and decoding.
gen_toks = model.generate(tokens, max_new_tokens=20)
model.to_string(gen_toks)

  0%|          | 0/20 [00:00<?, ?it/s]

['<bos><start_of_turn>user\nWrite a Hello world program<end_of_turn>\n<start_of_turn>model\n```python\nprint("Hello, world!")\n```\n\n**Explanation:**\n\n* **`print']

In [7]:
# Final rows are 128 tokens: the 4-token chat prefix plus (SEQ_LEN - 1) text tokens.
# The extra +1 compensates for each row's leading <bos>, which is cut when the prefix is added.
SEQ_LEN = 128 - 4 + 1

# `load_dataset` returns a Hugging Face `datasets.Dataset` (not a torch Dataset).
data = load_dataset("NeelNanda/c4-code-20k", split="train")

# Tokenise into fixed-length rows, then shuffle with a fixed seed for reproducibility.
tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN)  # type: ignore
tokenized_data = tokenized_data.shuffle(seed=42)

all_tokens = tokenized_data["tokens"]  # (n_rows, SEQ_LEN)
assert isinstance(all_tokens, torch.Tensor)

print(all_tokens.shape)

torch.Size([222989, 125])


In [8]:
# The first 4 tokens of an empty user turn — '<bos>', '<start_of_turn>', 'user', '\n' —
# used as the prefix for every text row.
pre_tokens = model.tokenizer.apply_chat_template(
    [{"role": "user", "content": ""}],
    return_tensors='pt',
)[0][:4]
print(pre_tokens.shape)
print(model.to_str_tokens(pre_tokens))
pre_tokens

torch.Size([4])
['<bos>', '<start_of_turn>', 'user', '\n']


tensor([   2,  106, 1645,  108])

In [9]:
# Prepend the chat prefix to every row, replacing each row's own leading token (<bos>).
n_samples = all_tokens.shape[0]
ready_tokens = torch.cat(
    [
        pre_tokens.expand(n_samples, -1),  # (n_samples, 4) broadcast view of the prefix
        all_tokens[:, 1:],                 # text tokens without the leading <bos>
    ],
    dim=1,
)
ready_tokens.shape

torch.Size([222989, 128])

In [10]:
# Eyeball one prepared row: chat prefix followed by the dataset text.
print(model.to_str_tokens(ready_tokens[4]))

['<bos>', '<start_of_turn>', 'user', '\n', ' instance', ' of', ' Transition', ' class', '.', '\n\n', '    ', 'Args', ':', '\n', '      ', 'sess', ' (:', 'obj', ':`', 'tf', '.', 'Session', '`', '):', ' This', ' attribute', ' represents', ' the', ' session', ' that', ' runs', '\n', '        ', 'the', ' TensorFlow', ' operations', '.', '\n', '      ', 'name', ' (', 'str', '):', ' This', ' attribute', ' represents', ' the', ' name', ' of', ' the', ' object', ' in', '\n', '        ', 'Tensor', 'Flow', "'", 's', ' op', ' Graph', '.', '\n', '      ', 'graph', ' (:', 'obj', ':`', 'tf', 'graph', '.', 'Graph', '`', '):', '  ', 'The', ' graph', ' on', ' which', ' the', ' transition', ' is', ' referred', '.', '\n', '      ', 'beta', ' (', 'float', '):', ' The', ' reset', ' probability', ' of', ' the', ' random', ' walks', ',', ' i', '.', 'e', '.', ' the', '\n', '        ', 'probability', ' that', ' a', ' user', ' that', ' sur', 'fs', ' the', ' graph', ' an', ' decides', ' to', ' jump', ' to', '\n'

In [15]:
# Residual-stream SAE for one layer of Gemma-2-IT; `hp` names the matching TransformerLens hook.
sae_layer = 9
hp = f"blocks.{sae_layer}.hook_resid_post"

# from_pretrained returns a 3-tuple; only the first element (the SAE itself) is kept.
# Loaded on CPU first, then moved to the model's device.
sae, _, _ = SAE.from_pretrained(
    release="gemma-2-it-saes",
    sae_id=f"{sae_layer}/post_mlp_residual/16384/0_00045",
    device='cpu',
)
sae.to(device)

SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

In [12]:
# Which SAE latents fire on a short chat-formatted prompt? Encode every position.
text = "I traveled to London"
toks = model.tokenizer.apply_chat_template([{"role": "user", "content": text}], return_tensors='pt')[0]
print(model.to_str_tokens(toks))
toks = toks.to(device)

# Only the residual stream at `hp` is needed, so stop the forward pass right after block
# `sae_layer` (as the feature-search helpers in this notebook do) instead of running every layer.
_, acts = model.run_with_cache(toks, names_filter=hp, stop_at_layer=sae_layer + 1)
acts = acts[hp]

# Encode each batch row separately and restack: (batch, seq_len, n_features).
all_sae_acts = torch.stack([sae.encode(row) for row in acts], dim=0)

['<bos>', '<start_of_turn>', 'user', '\n', 'I', ' traveled', ' to', ' London', '<end_of_turn>', '\n']


In [13]:
# The ten strongest latents at every position: activations and latent ids.
top_v, top_i = all_sae_acts.topk(10, dim=-1)
top_i

tensor([[[12085, 12324,    35, 11637, 11355,  8984,   466,  7469,  2832, 10547],
         [12085, 11355,  5600, 15460,  1578,  9540,  8961,    35, 10531,  9385],
         [12085,  4839, 15256, 12891,  8968, 15460, 11264,  9540, 10720,  6415],
         [12085,   195, 15256, 15460,  9540,  3767,  8319,  1217, 10720,  2088],
         [12085, 11306,  2725, 15460,  4660,  9540,  9621, 10720,  1899,  8489],
         [ 8209, 12085,  9540, 10241, 12532,  8489, 11608,  3326,  8379, 15452],
         [ 8209,  7678, 12085,  9540,  5303,  8489, 12532, 10241, 11608,  8941],
         [12085,  4457, 10423, 15365,  9540,  3283, 12267,  8153, 11304, 15626],
         [12085,  5600,  3879,  9540, 12532,  2459, 10241, 16219, 14263, 11857],
         [ 9142,  9540, 12085, 16219,  2459, 12532, 16135, 15743, 10241,  8209]]],
       device='cuda:0')

In [14]:
# Hand-copied top-10 latent ids at the ' London' position (index 7 of the prompt tokens, i.e. top_i[0, 7]).
# NOTE: 12085 appears at every position and 9540 at almost every position, so they are unlikely to be London-specific.
potential_london_ids = [12085,  4457, 10423, 15365,  9540,  3283, 12267,  8153, 11304, 15626]

In [17]:
# Encoder weight shape: (d_in, n_features) — the feature count is the last dimension.
print(sae.W_enc.shape)

torch.Size([3584, 16384])


In [29]:
def get_top_samples(tokens, top_k: int = 10, batch_size: int = 4):
    """Shape probe: print the per-row max SAE activations for the first batch only.

    Takes the first `batch_size` rows of `tokens` (shape (n_samples, seq_len)), runs them
    up to block `sae_layer`, encodes the residual stream at `hp` with `sae`, reduces each
    feature to its maximum over positions and prints the resulting shape. `top_k` is
    accepted but unused and nothing is returned. This prototype is redefined with a full
    implementation later in the notebook.
    """
    starts = range(0, tokens.shape[0], batch_size)
    if not starts:
        return
    first = starts[0]
    first_batch = tokens[first:first + batch_size]

    _, cache = model.run_with_cache(first_batch, names_filter=hp, stop_at_layer=sae_layer + 1)
    per_row_max = torch.stack(
        [sae.encode(resid).max(dim=0).values for resid in cache[hp]],
        dim=0,
    )
    print(per_row_max.shape)  # (batch_size, n_features)



# Shape check on the first batch. `top_k` must be an int; previously the list of
# candidate feature ids was passed in its place by mistake.
get_top_samples(ready_tokens, top_k=10)



    

torch.Size([4, 16384])


In [28]:
# Sanity check: max along a dim returns a (values, indices) named tuple, not a bare tensor.
torch.max(torch.tensor([1, 2, 3]), dim=0)

torch.return_types.max(
values=tensor(3),
indices=tensor(2))

In [99]:
from heapq import heappush, heappushpop

def get_top_samples(tokens, top_k: int = 10, batch_size: int = 4):
    """For every SAE feature, find the `top_k` rows of `tokens` where it fires hardest.

    Each row is run through `model` up to block `sae_layer`, the residual stream at hook
    `hp` is encoded with `sae`, and every feature's activation is reduced to its maximum
    over positions. A running top-k over rows is maintained for all features at once with
    tensor ops, instead of a Python loop over every feature for every batch.

    Args:
        tokens: integer tensor of shape (n_samples, seq_len).
        top_k: number of rows to keep per feature.
        batch_size: rows per forward pass; must be positive.

    Returns:
        (top_k_values, top_k_indices), each a list with one entry per feature:
        - top_k_values[f]: up to `top_k` ``(max_activation, row_index)`` tuples, sorted
          descending (ties: larger row index first);
        - top_k_indices[f]: the row indices of top_k_values[f], in the same order.
        For empty `tokens` or `top_k <= 0` every per-feature list is empty.

    Raises:
        ValueError: if `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_samples = tokens.shape[0]
    n_features = sae.W_enc.shape[1]  # W_enc is (d_in, n_features)
    if n_samples == 0 or top_k <= 0:
        return [[] for _ in range(n_features)], [[] for _ in range(n_features)]

    best_vals = None  # (<= top_k, n_features): best per-row max activations seen so far
    best_rows = None  # (<= top_k, n_features): row index behind each entry of best_vals
    for start in tqdm(range(0, n_samples, batch_size)):
        batch = tokens[start:start + batch_size]
        _, cache = model.run_with_cache(batch, names_filter=hp, stop_at_layer=sae_layer + 1)

        # (batch, n_features): each feature's max over positions, per row.
        batch_max = torch.stack(
            [sae.encode(resid).max(dim=0).values for resid in cache[hp]], dim=0
        )
        batch_rows = torch.arange(
            start, start + batch_max.shape[0], device=batch_max.device
        )[:, None].expand_as(batch_max)

        if best_vals is None:
            cand_vals, cand_rows = batch_max, batch_rows
        else:
            # Earlier rows come first, so the stable sort keeps them when values tie.
            cand_vals = torch.cat([best_vals, batch_max], dim=0)
            cand_rows = torch.cat([best_rows, batch_rows], dim=0)

        sorted_vals, order = torch.sort(cand_vals, dim=0, descending=True, stable=True)
        best_vals = sorted_vals[:top_k]
        best_rows = torch.gather(cand_rows, 0, order[:top_k])

    # Convert to the per-feature Python lists callers index into (e.g. top_k_indices[ft_id]).
    top_k_values, top_k_indices = [], []
    for vals, rows in zip(best_vals.float().T.tolist(), best_rows.T.tolist()):
        ranked = sorted(zip(vals, rows), reverse=True)
        top_k_values.append(ranked)
        top_k_indices.append([row for _, row in ranked])
    return top_k_values, top_k_indices

# Usage: search the first 1000 rows of `ready_tokens`, keeping the 10 strongest rows per feature.
top_k_values, top_k_indices = get_top_samples(ready_tokens[:1000], top_k=10, batch_size=4)


100%|██████████| 250/250 [08:27<00:00,  2.03s/it]


In [100]:
# Per-feature lists of (max activation, row index) pairs, strongest first.
top_k_values

[[(35.5, 510),
  (33.75, 680),
  (33.0, 930),
  (32.0, 90),
  (31.625, 957),
  (31.625, 161),
  (29.875, 704),
  (29.875, 469),
  (29.375, 487),
  (28.875, 516)],
 [(37.25, 358),
  (36.0, 483),
  (34.75, 936),
  (33.5, 417),
  (6.625, 910),
  (6.1875, 79),
  (6.0625, 277),
  (5.75, 841),
  (5.25, 746),
  (4.875, 430)],
 [(29.875, 9),
  (29.875, 8),
  (29.875, 7),
  (29.875, 6),
  (29.875, 5),
  (29.875, 4),
  (29.875, 3),
  (29.875, 2),
  (29.875, 1),
  (29.875, 0)],
 [(19.75, 669),
  (7.90625, 485),
  (7.84375, 83),
  (7.15625, 183),
  (5.84375, 886),
  (5.84375, 258),
  (5.40625, 764),
  (5.34375, 579),
  (5.15625, 299),
  (4.71875, 524)],
 [(15.8125, 789),
  (13.375, 478),
  (13.3125, 920),
  (13.3125, 571),
  (12.4375, 110),
  (11.0, 941),
  (9.5625, 607),
  (6.5, 300),
  (6.40625, 39),
  (6.3125, 44)],
 [(47.0, 9),
  (47.0, 8),
  (47.0, 7),
  (47.0, 6),
  (47.0, 5),
  (47.0, 4),
  (47.0, 3),
  (47.0, 2),
  (47.0, 1),
  (47.0, 0)],
 [(64.5, 487),
  (62.25, 713),
  (58.25, 828),
  (

In [101]:
def acts_for_feature(tokens, feature_idx: int, batch_size: int = 4):
    """Per-token activations of a single SAE feature for every row of `tokens`.

    `tokens` must be 2-D (n_samples, seq_len). Rows are processed `batch_size` at a time:
    each batch is run up to block `sae_layer`, the residual stream at `hp` is encoded with
    `sae`, and only column `feature_idx` of the encoding is kept.

    Returns:
        Tensor of shape (n_samples, seq_len) with the feature's activation at each token.
    """
    n_rows, _ = tokens.shape
    chunks = []
    for lo in tqdm(range(0, n_rows, batch_size)):
        _, cache = model.run_with_cache(
            tokens[lo:lo + batch_size], names_filter=hp, stop_at_layer=sae_layer + 1
        )
        chunk = torch.stack(
            [sae.encode(resid)[:, feature_idx] for resid in cache[hp]], dim=0
        )
        chunks.append(chunk)

    return torch.cat(chunks, dim=0)

# Smoke test on 10 rows: feature 2's per-token activations, one row per sample.
acts_for_selected = acts_for_feature(ready_tokens[:10], feature_idx=2)
acts_for_selected

100%|██████████| 3/3 [00:00<00:00, 13.91it/s]


tensor([[29.8750,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [29.8750,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [29.8750,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [29.8750,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [30.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [30.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.bfloat16)

In [102]:
from IPython.display import HTML, display
import html
def display_highlighted_tokens(tokens, values, tokenizer):
    """Render token rows as HTML, shading each token's background by its value.

    Values are min-max normalised over the whole `values` tensor (all rows together),
    so colours are comparable across rows: white = minimum, full red = maximum.

    Args:
        tokens: (n_rows, seq_len) tensor of token ids.
        values: tensor with the same shape, one score per token.
        tokenizer: object with a `decode(list_of_ids)` method.
    """
    # Ensure tensors are on CPU and in a dtype that supports the arithmetic below.
    tokens = tokens.cpu()
    values = values.cpu().to(torch.float32)

    if values.numel() > 0:
        lo, hi = values.min(), values.max()
        span = hi - lo
        # A constant tensor (e.g. a feature that never fires) would give 0/0 = NaN,
        # which makes int() below raise; render everything white instead.
        values = (values - lo) / span if span > 0 else torch.zeros_like(values)

    def value_to_color(value):
        # Gradient from white (0) to red (1).
        color_value = int(255 * (1 - value.item()))
        return f"rgb(255, {color_value}, {color_value})"

    # Collect fragments and join once instead of repeated string concatenation.
    parts = []
    for batch_idx in range(tokens.shape[0]):
        parts.append(f"<p>Batch {batch_idx}:</p>")
        parts.append("<p>")
        for token, value in zip(tokens[batch_idx], values[batch_idx]):
            # Escape decoded text so tokens like '<bos>' are shown rather than parsed as HTML.
            word_escaped = html.escape(tokenizer.decode([token.item()]))
            color = value_to_color(value)
            parts.append(f'<span style="background-color: {color};">{word_escaped}</span>')
        parts.append("</p>")

    display(HTML("".join(parts)))

In [107]:
# Inspect the strongest samples for one candidate ' London' feature
# (other candidates tried here: 12085, 4457).
ft_id = 10423

samples = ready_tokens[top_k_indices[ft_id]]  # the top-k rows for this feature
acts = acts_for_feature(samples, ft_id)       # per-token activations, same shape as samples

# Skip the first 3 positions ('<bos>', '<start_of_turn>', 'user'): position 0 can carry very
# large activations that would swamp the colour scale.
# NOTE(review): the chat prefix is 4 tokens, so slicing at 3 keeps its trailing '\n' — confirm intended.
display_highlighted_tokens(samples[:, 3:], acts[:, 3:], model.tokenizer)

100%|██████████| 3/3 [00:00<00:00, 12.66it/s]
