In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm
import json

import einops

from sae_lens import SAE

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder, text_to_sae_feats, top_activations
from steering.patch import generate, scores_2d, patch_resid


import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

from IPython.display import HTML, display
import html


import numpy as np
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f4d2c5e70a0>

In [12]:
os.environ['GEMMA_2_SAE_WEIGHTS_ROOT'] = '/workspace/gemmasaes/'

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model: HookedTransformer = HookedTransformer.from_pretrained("google/gemma-2-9b-it", device=device, dtype=torch.float16)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b-it into HookedTransformer


In [4]:
model.tokenizer.special_tokens_map

{'bos_token': '<bos>',
 'eos_token': '<eos>',
 'unk_token': '<unk>',
 'pad_token': '<pad>',
 'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']}

In [5]:
chat = [[
    {"role": "user", "content": "Write a Hello world program"},
]]
tokens = model.tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')
tokens

tensor([[    2,   106,  1645,   108,  5559,   476, 25957,  2134,  2733,   107,
           108,   106,  2516,   108]])

In [6]:
gen_toks = model.generate(tokens, max_new_tokens=20)
model.to_string(gen_toks)

  0%|          | 0/20 [00:00<?, ?it/s]

['<bos><start_of_turn>user\nWrite a Hello world program<end_of_turn>\n<start_of_turn>model\n```python\nprint("Hello, world!")\n```\n\nThis is a simple "Hello, world']

In [7]:
SEQ_LEN = 128 - 4 + 1 # 4 for the special tokens at the start, 1 because we'll cut bos.

# Load in the data (it's a Dataset object)
data = load_dataset("NeelNanda/c4-code-20k", split="train")
# print(data)
# assert isinstance(data, Dataset)

# Tokenize the data (using a utils function) and shuffle it
tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=SEQ_LEN) # type: ignore
tokenized_data = tokenized_data.shuffle(42)

# Get the tokens as a tensor
all_tokens = tokenized_data["tokens"]
assert isinstance(all_tokens, torch.Tensor)

print(all_tokens.shape)

torch.Size([222989, 125])


In [8]:
pre_tokens = model.tokenizer.apply_chat_template([{"role": "user", "content": ""}], return_tensors='pt')[0][:4]
print(pre_tokens.shape)
print(model.to_str_tokens(pre_tokens))
pre_tokens

torch.Size([4])
['<bos>', '<start_of_turn>', 'user', '\n']


tensor([   2,  106, 1645,  108])

In [9]:
n_samples = all_tokens.shape[0]
ready_tokens = torch.cat([pre_tokens.expand(n_samples, -1), all_tokens[:, 1:]], dim=1)
ready_tokens.shape

torch.Size([222989, 128])

In [10]:
print(model.to_str_tokens(ready_tokens[4]))

['<bos>', '<start_of_turn>', 'user', '\n', ' instance', ' of', ' Transition', ' class', '.', '\n\n', '    ', 'Args', ':', '\n', '      ', 'sess', ' (:', 'obj', ':`', 'tf', '.', 'Session', '`', '):', ' This', ' attribute', ' represents', ' the', ' session', ' that', ' runs', '\n', '        ', 'the', ' TensorFlow', ' operations', '.', '\n', '      ', 'name', ' (', 'str', '):', ' This', ' attribute', ' represents', ' the', ' name', ' of', ' the', ' object', ' in', '\n', '        ', 'Tensor', 'Flow', "'", 's', ' op', ' Graph', '.', '\n', '      ', 'graph', ' (:', 'obj', ':`', 'tf', 'graph', '.', 'Graph', '`', '):', '  ', 'The', ' graph', ' on', ' which', ' the', ' transition', ' is', ' referred', '.', '\n', '      ', 'beta', ' (', 'float', '):', ' The', ' reset', ' probability', ' of', ' the', ' random', ' walks', ',', ' i', '.', 'e', '.', ' the', '\n', '        ', 'probability', ' that', ' a', ' user', ' that', ' sur', 'fs', ' the', ' graph', ' an', ' decides', ' to', ' jump', ' to', '\n'

In [13]:
sae_layer = 20
hp = f"blocks.{sae_layer}.hook_resid_post"
sae, _, _ = SAE.from_pretrained(
    release = "gemma-2-it-saes-old",
    sae_id = f"20/post_mlp_residual/16384/0_00045",
    device = 'cpu'
)
sae.to(device)

SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

In [81]:
def top_activations(text):
    toks = model.tokenizer.apply_chat_template([{"role": "user", "content": text}], return_tensors='pt')[0]
    toks = toks.to(device)

    _, acts = model.run_with_cache(toks, names_filter=hp)
    acts = acts[hp]

    all_sae_acts = []
    for batch in acts:
        sae_acts = sae.encode(batch)
        all_sae_acts.append(sae_acts)

    all_sae_acts = torch.stack(all_sae_acts, dim=0)
    top_v, top_i = torch.topk(all_sae_acts, 10, dim=-1)
    return top_v[:, 4:-2], top_i[:, 4:-2]

top_v, top_i = top_activations("Give me a random city name")
print(top_i)

tensor([[[ 6802,  4230, 15911,  5388, 12650, 12880, 11693, 14213,  1740,  5672],
         [ 6802,  4230, 15911,  5157, 11316,  5388, 11693, 13568, 11623, 12650],
         [ 5075,  7567,  6802,  4230, 15911,  8298,  5906, 13820,  5388,   204],
         [ 4223,  6802, 11085,  5682, 15911,  3442,  5906, 10396,  6228,  7075],
         [ 6802,  9220, 10312, 11311,  5906, 14221, 15882,  1024,  5272,  4303],
         [ 6802,    96,  7547, 10312,  5906,  2216, 11085, 13854,  3442, 13711]]],
       device='cuda:0')


In [56]:
text = "I traveled to London"
toks = model.tokenizer.apply_chat_template([{"role": "user", "content": text}], return_tensors='pt')[0]
print(model.to_str_tokens(toks))
toks = toks.to(device)

_, acts = model.run_with_cache(toks, names_filter=hp)
acts = acts[hp]

all_sae_acts = []
for batch in acts:
    sae_acts = sae.encode(batch)
    all_sae_acts.append(sae_acts)

all_sae_acts = torch.stack(all_sae_acts, dim=0)

['<bos>', '<start_of_turn>', 'user', '\n', 'I', ' traveled', ' to', ' London', '<end_of_turn>', '\n']


In [57]:
top_v, top_i = torch.topk(all_sae_acts, 10, dim=-1)
top_i

tensor([[[ 6802,  2122, 11299,  7697,   589, 14682,  4016,  5827, 10002,  1412],
         [ 6802,  7697,  5823, 14587, 11299,  2599,  3444,  3892,  2122,  7806],
         [ 6802,  7697,  5823, 14587, 11299,  2599,  3444,  7806,  3892, 11271],
         [ 6802,  7697,  5823, 14587, 11299,  2599,  3444, 11271,  2122,  3892],
         [ 6802,  4230, 16078,  2231, 13568, 12571, 14587,  2599,     0,     1],
         [ 6802, 11233,  2028, 13568, 10025,  7247,  5662, 15709,  6356, 14211],
         [14006,  6802,  2028, 10097,  7247,  7041, 11233, 15355,  9158,  2457],
         [ 6802,  5771,   484,  5687,  1106,  8177, 10688, 15882,  5726, 10342],
         [ 6802,  8163,  2738,  5662, 13491,  2060,  3312,  7556, 16362,  4637],
         [ 6802,  7247,   436,  3312,  2028,  7720,  5687, 14289,  5662, 15760]]],
       device='cuda:0')

In [58]:
potential_london_ids = [ 6802,  5771,   484,  5687,  1106,  8177, 10688, 15882,  5726, 10342],

In [17]:
print(sae.W_enc.shape)

torch.Size([3584, 16384])


In [60]:
def get_top_samples(tokens, top_k: int = 10, batch_size: int = 4):
    # tokens is shape (n_samples, seq_len)

    for i in range(0, tokens.shape[0], batch_size):
        batch = tokens[i:i+batch_size]
        _, acts = model.run_with_cache(batch, names_filter=hp, stop_at_layer=sae_layer+1)
        acts = acts[hp]

        max_sae_acts = []
        for sample in acts:
            sae_acts = sae.encode(sample)
            max_sae_acts.append(sae_acts.max(dim=0).values)
        
        max_sae_acts = torch.stack(max_sae_acts, dim=0)
        print(max_sae_acts.shape) # (batch_size, n_features)
        print(max_sae_acts)
        break



get_top_samples(ready_tokens, potential_london_ids)



    

torch.Size([4, 16384])
tensor([[ 0.0000, 45.7500,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, 45.7500,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, 45.7500,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, 45.7500,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.bfloat16)


In [20]:
torch.tensor([1, 2, 3]).max(dim=0)

torch.return_types.max(
values=tensor(3),
indices=tensor(2))

In [24]:
from heapq import heappush, heappushpop

def get_top_samples(tokens, top_k: int = 10, batch_size: int = 4):
    n_samples, seq_len = tokens.shape
    n_features = None  # We'll set this when we get the first batch of results
    
    # Initialize lists to store the top k values and indices for each feature
    top_k_values = []
    top_k_indices = []
    
    for i in tqdm(range(0, n_samples, batch_size)):
        batch = tokens[i:i+batch_size]
        _, acts = model.run_with_cache(batch, names_filter=hp, stop_at_layer=sae_layer+1)
        acts = acts[hp]
        max_sae_acts = []
        for sample in acts:
            sae_acts = sae.encode(sample)
            max_sae_acts.append(sae_acts.max(dim=0).values)
        
        max_sae_acts = torch.stack(max_sae_acts, dim=0)
        
        if n_features is None:
            n_features = max_sae_acts.shape[1]
            top_k_values = [[] for _ in range(n_features)]
            top_k_indices = [[] for _ in range(n_features)]
        
        # Update top k for each feature
        for feature_idx in range(n_features):
            feature_values = max_sae_acts[:, feature_idx]
            for sample_idx, value in enumerate(feature_values):
                global_idx = i + sample_idx
                if len(top_k_values[feature_idx]) < top_k:
                    heappush(top_k_values[feature_idx], (value.item(), global_idx))
                    top_k_indices[feature_idx] = [idx for _, idx in top_k_values[feature_idx]]
                elif value > top_k_values[feature_idx][0][0]:
                    heappushpop(top_k_values[feature_idx], (value.item(), global_idx))
                    top_k_indices[feature_idx] = [idx for _, idx in top_k_values[feature_idx]]
    
    # Sort the results for each feature
    for feature_idx in range(n_features):
        sorted_results = sorted(zip(top_k_values[feature_idx], top_k_indices[feature_idx]), reverse=True)
        top_k_values[feature_idx] = [value for value, _ in sorted_results]
        top_k_indices[feature_idx] = [idx for _, idx in sorted_results]
    
    return top_k_values, top_k_indices

# # Usage
# top_k_values, top_k_indices = get_top_samples(ready_tokens[:8000], top_k=10, batch_size=4)


  0%|          | 10/2000 [00:16<53:06,  1.60s/it]


KeyboardInterrupt: 

In [54]:
# # save top_k_values and top_k_indices
# with open("top_k_values.json", "w") as f:
#     json.dump(top_k_values, f)

# with open("top_k_indices.json", "w") as f:
#     json.dump(top_k_indices, f)
    
# # save tokens
# torch.save(ready_tokens, "ready_tokens.pt")


# load top_k_values and top_k_indices
with open("top_k_values_20.json", "r") as f:
    top_k_values = json.load(f)

with open("top_k_indices_20.json", "r") as f:
    top_k_indices = json.load(f)


# load tokens
ready_tokens = torch.load("ready_tokens_20.pt")




  ready_tokens = torch.load("ready_tokens_20.pt")


In [61]:
def acts_for_feature(tokens, feature_idx: int, batch_size: int = 4):
    n_samples, seq_len = tokens.shape
    all_acts = []
    for i in tqdm(range(0, n_samples, batch_size)):
        batch = tokens[i:i+batch_size]
        _, acts = model.run_with_cache(batch, names_filter=hp, stop_at_layer=sae_layer+1)
        acts = acts[hp]
        batch_sae_acts = []
        for sample in acts:
            sae_acts = sae.encode(sample)
            batch_sae_acts.append(sae_acts[:, feature_idx])
        
        batch_sae_acts = torch.stack(batch_sae_acts, dim=0)
        all_acts.append(batch_sae_acts)
    
    return torch.cat(all_acts, dim=0)

acts_for_selected = acts_for_feature(ready_tokens[:10], 6802)
acts_for_selected

100%|██████████| 3/3 [00:00<00:00, 10.44it/s]


tensor([[3152., 1088., 1144.,  ...,    0.,    0.,    0.],
        [3152., 1088., 1144.,  ...,   75.,   50.,   36.],
        [3152., 1088., 1144.,  ...,   69.,   67.,    0.],
        ...,
        [3152., 1088., 1144.,  ...,   59.,   54.,   49.],
        [3152., 1088., 1144.,  ...,    0.,   48.,    0.],
        [3152., 1088., 1144.,  ...,   75.,    0.,    8.]], device='cuda:0',
       dtype=torch.bfloat16)

In [62]:
def display_highlighted_tokens(tokens, values, tokenizer):
    # Ensure tensors are on CPU
    tokens = tokens.cpu()
    values = values.cpu().to(torch.float32)
    
    original_values = values.clone()
    
    # Normalize values to range [0, 1] for coloring
    values = (values - values.min()) / (values.max() - values.min())
    
    # Function to convert value to RGB color
    def value_to_color(value):
        # Using a gradient from white to red
        color_value = int(255 * (1 - value.item()))
        return f"rgb(255, {color_value}, {color_value})"
    
    html_output = ""
    
    for batch_idx in range(tokens.shape[0]):
        # Use the original values for max activation
        max_activation = original_values[batch_idx].max().item()
        html_output += f"<p>Max Activation: {max_activation:.4f}:</p>"
        html_output += "<p>"
        
        for token, value in zip(tokens[batch_idx], values[batch_idx]):
            word = tokenizer.decode([token.item()])
            # Escape special characters to prevent HTML conflicts
            word_escaped = html.escape(word)
            color = value_to_color(value)
            html_output += f'<span style="background-color: {color};">{word_escaped}</span>'
        
        html_output += "</p>"
    
    display(HTML(html_output))

In [113]:
top_v, top_i = top_activations("""Do not go gentle into that good night,
Old age should burn and rave at close of day;
Rage, rage against the dying of the light.

Though wise men at their end know dark is right,
Because their words had forked no lightning they
Do not go gentle into that good night.

Good men, the last wave by, crying how bright
Their frail deeds might have danced in a green bay,
Rage, rage against the dying of the light.
""")
print(top_i)

tensor([[[ 6802,  4230, 12650,  7985, 11350, 15376, 15701,  3234,  5906,  8641],
         [14213,  6802, 15911, 12358,  7311,  3323, 11836, 14915,  8682, 13718],
         [ 4505, 12358,  8682, 11087, 15709,   101, 12414, 10556,    23, 10561],
         [16083, 11087,   101,  1178, 12342,  6802, 10783,  3125, 15882, 10098],
         [12568, 11087, 10162, 10294, 12166,  1178,  6919,  1898,  3307, 10007],
         [ 6802,   161, 11087,  1962, 12166,  8042,  3025,  1178,  9363, 13442],
         [ 6767,  6802,  1962, 11176,  9680, 12206, 12166, 11087,  3025,  7002],
         [12393, 11087,  1178,  3928,  3125, 12166,   101,  3307,  6919, 10098],
         [ 6802,   101, 11087,   824,  9348, 15718, 12988,  7155,  2738,  1178],
         [ 6802,   101,  9348,   824, 14650,  7792,  2738,  1178,  7069, 12013],
         [ 3800,  3539,   517,  6919,  1470,  1178, 14054,  5324,  3125,  3307],
         [ 3539,  6802, 12904,  3800,  6919,    13, 11160, 11176,  8129,  5611],
         [15998,  6919,   74

In [132]:
def visualize_feature(ft_id: int):
    samples = ready_tokens[top_k_indices[ft_id]]
    acts = acts_for_feature(samples, ft_id)
    display_highlighted_tokens(samples[:, 3:], acts[:, 3:], model.tokenizer)

In [120]:
# ft_id = 4457
ft_id = 101

samples = ready_tokens[top_k_indices[ft_id]]


acts = acts_for_feature(samples, ft_id)

display_highlighted_tokens(samples[:, 3:], acts[:, 3:], model.tokenizer)

100%|██████████| 3/3 [00:00<00:00, 10.65it/s]

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 13.7500],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 10.3125],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  7.2500],
        [ 0.0000,  0.0000,  0.0000,  ..., 12.7500,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.bfloat16)





In [96]:
# steer model
text = "Tell me a story"
toks = model.tokenizer.apply_chat_template([{"role": "user", "content": text}], return_tensors='pt', add_generation_prompt=True)
toks = toks.expand(10, -1)
print(toks.shape)
print(model.to_str_tokens(toks[0]))
toks = toks.to(device)


torch.Size([10, 13])
['<bos>', '<start_of_turn>', 'user', '\n', 'Tell', ' me', ' a', ' story', '<end_of_turn>', '\n', '<start_of_turn>', 'model', '\n']


In [97]:
gen_toks = model.generate(toks, max_new_tokens=40)
model.to_string(gen_toks)

  0%|          | 0/40 [00:00<?, ?it/s]

["<bos><start_of_turn>user\nTell me a story<end_of_turn>\n<start_of_turn>model\nThe old lighthouse keeper, Elias, squinted at the churning sea, its grey waves crashing against the jagged rocks below. For 50 years, he'd stood watch from this lonely post,",
 "<bos><start_of_turn>user\nTell me a story<end_of_turn>\n<start_of_turn>model\nThe old lighthouse keeper, Elias, squinted at the graying sky. For fifty years, he'd watched the sun rise and set over the treacherous Rockhaven Reef, its stark white lighthouse beam",
 '<bos><start_of_turn>user\nTell me a story<end_of_turn>\n<start_of_turn>model\nThe wind whispered secrets through the long, dry grass of the savanna, secrets Elara strained to hear. She was a young gazelle, barely a year old, and today was her first solo hunt',
 '<bos><start_of_turn>user\nTell me a story<end_of_turn>\n<start_of_turn>model\nThe old woman, Elara, knelt beside the overgrown rosebush, her wrinkled hands brushing against thorns as she searched for a bloom. The t

In [122]:
# # Layer 9
# london = sae.W_dec[4457]
# britain = sae.W_dec[10423]

# Layer 20
britain = sae.W_dec[5771]
# london = sae.W_dec[5771]
happy = sae.W_dec[9022]
poems = sae.W_dec[101]

In [123]:
with model.hooks([(hp, partial(patch_resid, steering=poems, scale=150))]):
    gen_toks = model.generate(toks, max_new_tokens=40)
model.to_string(gen_toks)

  0%|          | 0/40 [00:00<?, ?it/s]

['<bos><start_of_turn>user\nTell me a story<end_of_turn>\n<start_of_turn>model\nThe old lighthouse stood sentinel on the craggy cliff,\nA lonely heart, a weathered soul,\nWhispering tales of crashing waves and drifting ships.\n\nIt spoke of storms that tore the sky',
 "<bos><start_of_turn>user\nTell me a story<end_of_turn>\n<start_of_turn>model\nThe old lighthouse keeper, bent with age and salt, watched the tide creep in,\n\nHe'd seen a thousand sunsets on this craggy, lonely shore,\nFelt the ocean's pulse",
 '<bos><start_of_turn>user\nTell me a story<end_of_turn>\n<start_of_turn>model\nThe old lighthouse stood weathered and worn,\nUpon a craggy cliff, its spirit torn.\nIt watched the waves crash, ceaseless, strong,\nA testament to love, both sweet and long',
 '<bos><start_of_turn>user\nTell me a story<end_of_turn>\n<start_of_turn>model\nThe old lighthouse keeper, weathered and worn like the cliffs he climbed, watched the waves crash against the shore. The sea was restless tonight, its

In [124]:
unsteered_texts = generate(model,
        hooks=[],
        max_new_tokens=35,
        prompt=text,
        batch_size=64,
        n_samples=200,
        )


In [125]:

happy_texts = generate(model,
        hooks=[(hp, partial(patch_resid, steering=happy, scale=60))],
        max_new_tokens=35,
        prompt=text,
        batch_size=64,
        n_samples=200,
        )

In [140]:
poem_texts = generate(model,
                      hooks=[(hp, partial(patch_resid, steering=poems, scale=150))],
                      max_new_tokens=35,
                      prompt=text,
                      batch_size=64,
                      n_samples=200,
                      )

In [126]:
(happy_texts[:5])

['Tell me a story about a little girl who loved to draw.\n\nLila was a whirlwind of glitter and crayons. From the moment her tiny toes touched the floor each morning, she was armed with',
 'Tell me a story about a baker who bakes bread with magical ingredients for magical beasts.\n\nBertram Buttercrust wasn\'t your average baker. His shop, "Bertram\'s Bites',
 "Tell me a story about a time when you helped someone.\n\nThough I don't have personal experiences like humans do, I can tell you a story about a time I helped someone using the",
 'Tell me a story about a young girl named Emma.\n\nEmma lived in a tiny village nestled among rolling hills, where the air was sweet with the scent of fresh bread and blooming wildflowers. She',
 'Tell me a story about an old man who finds an empty bottle at the beach.\n\nElias, with his weathered face etched with the lines of a life well-lived and sun-kissed']

In [127]:
@torch.no_grad()
def get_feature_freqs(texts: list[str], model: HookedTransformer, sae: SAE, hook_point: str):
    all_sae_acts = torch.zeros(sae.cfg.d_sae, device=sae.W_enc.device)
    count = 0

    for text in tqdm(texts):
        _, acts = model.run_with_cache(text, names_filter=hook_point)
        acts = acts[hook_point]

        for batch in acts:
            sae_acts = sae.encode(batch)
            all_sae_acts += sae_acts.sum(dim=0)
            count += acts.shape[0]
    return all_sae_acts / count


In [129]:
unsteered_freqs = get_feature_freqs(unsteered_texts, model, sae, hp)

100%|██████████| 200/200 [00:22<00:00,  8.85it/s]


In [130]:
happy_freqs = get_feature_freqs(happy_texts, model, sae, hp)

100%|██████████| 200/200 [00:21<00:00,  9.14it/s]


In [141]:
poem_freqs = get_feature_freqs(poem_texts, model, sae, hp)

100%|██████████| 200/200 [00:22<00:00,  8.90it/s]


In [142]:
# happy_diff = happy_freqs - unsteered_freqs

poem_diff = poem_freqs - unsteered_freqs
top_v, top_i = torch.topk(poem_diff, 10, dim=-1)
print(top_i)
print(top_v)

tensor([  990,  5611,  1178, 10871,  7644, 12169,  7603,  6919,  1401,  9702],
       device='cuda:0')
tensor([384.7408, 258.7422, 130.3225, 109.7412,  89.7206,  79.5050,  78.9244,
         76.9000,  76.6184,  73.4022], device='cuda:0')


In [144]:
poem_texts[:10]

["Tell me a story about the moon,\n\nA silent watchtower of the night's vast sky,\n\n**The Moon's Lament**\n\nMy silver skin remembers every tear that falls,",
 'Tell me a story about the sun and the Moon.\n\nThe sun, a fiery king, ruled the day,\nHis golden chariot climbed high and bright,\nHe sang of joy, of',
 'Tell me a story about a broken world.\n\nThe sky hangs gray, a tapestry of ash and dust,\nA world choked with sorrow, a heart turned to rust.\n\nThe fields lie',
 'Tell me a story about a love unrequited.\n\nThe wind whispers your name, oh willow,\nAnd my heart, a captive, falls in line.\nYou dance of freedom,',
 'Tell me a story about a robot, looking out over a city.\n\nThe city hummed beneath its metallic gaze, \na tapestry of light and shadow spun with haste.\nSteel bones',
 'Tell me a story about a lonely cloud,\n\n\nA lonely cloud drifts on the breeze,\n\nA tale of longing and unseen ease.\n\nUpon a windswept heath it sighed,\nAching',
 'Tell me a story about a tiny hou

In [147]:
visualize_feature(10871)

100%|██████████| 3/3 [00:00<00:00, 10.49it/s]


ValueError: cannot convert float NaN to integer