In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

import numpy as np


from functools import partial
from datasets import load_dataset
from tqdm import tqdm
import json

from typing import List, Callable, Union, Optional, Literal

import einops

from sae_lens import SAE

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder, text_to_sae_feats, top_activations
from steering.patch import generate, scores_2d, patch_resid


import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

from IPython.display import HTML, display
import html


import numpy as np
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f26ada7a530>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
@torch.no_grad()
def normalise_decoder(sae, scale_input=False):
    """
    Normalises the decoder weights of the SAE to have unit norm.

    Use this when loading for gemma-2b saes.

    Args:
        sae (SparseAutoencoder): The sparse autoencoder.
        scale_input (bool): Use this when loading layer 12 model.
    """
    norms = torch.norm(sae.W_dec, dim=1)
    sae.W_dec /= norms[:, None]
    sae.W_enc *= norms[None, :]
    sae.b_enc *= norms

In [4]:
# Loading in layer 6 SAE, there is also layer 11 sae and soon all layers
hp6 = "blocks.6.hook_resid_post"

sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu'
)

sae6 = sae6.to(device)
normalise_decoder(sae6)

In [7]:
prompts = [
    "",
    "I think",
    "Breaking news",
    "Last night",
    "For sale",
    "The weather",
    "Dear Sir/Madam",
    "Preheat the oven",
    "It's interesting that"
    "Assistant:",
    "I went up to",
    "New study suggests",
    ]


def gen(ft_id=None, scale=60, batch_size=64, max_toks=32, n_batches=1, verbose=False):
    if ft_id is not None:
        steer = sae6.W_dec[ft_id]
        hooks = [(hp6, partial(patch_resid, steering=steer, scale=scale))]
    else:
        hooks = []
    generated_tokens = []
    for prompt in tqdm(prompts, disable=not verbose):
        tokens = model.to_tokens(prompt, prepend_bos=True)
        prompt_batch = tokens.expand(batch_size, -1)
        for _ in range(n_batches):
            with model.hooks(hooks):
                gen_batch = model.generate(
                    prompt_batch,
                    max_new_tokens=max_toks - tokens.shape[-1],
                    top_k=50,
                    top_p=0.3,
                    verbose=False,
                )
            generated_tokens.append(gen_batch)
    return torch.cat(generated_tokens, dim=0)


def get_feature_acts(tokens, batch_size):
    assert tokens.shape[1] == 32
    all_sae_acts = torch.zeros(sae6.W_dec.shape[0], device=device)
    count = 0
    for i in range(0, tokens.shape[0], batch_size):
        batch = tokens[i:i+batch_size]
        _, acts = model.run_with_cache(batch, names_filter=hp6, stop_at_layer=7)
        acts = acts[hp6] # shape (batch_size, len, d_model)
        acts = acts.reshape(-1, acts.shape[-1]) # shape (batch_size * len, d_model)
        sae_acts = sae6.encode(acts)
        all_sae_acts += sae_acts.sum(dim=0)
        count += sae_acts.shape[0]
    return all_sae_acts / count


print("Getting baseline feature acts")
baseline_dist = get_feature_acts(gen(n_batches=10, verbose=True), 64) ### n_batches=10
print("done baseline")

Getting baseline feature acts


100%|██████████| 11/11 [03:04<00:00, 16.78s/it]


done baseline


In [None]:
# Convert to NumPy array and save
baseline_dist_np = baseline_dist.cpu().numpy()
np.save('baseline_dist.npy', baseline_dist_np)
print("Baseline distribution saved as 'baseline_dist.npy'")

In [8]:
feature_dict = {}

# british
features = [4343, 9043, 4893, 5823, 16322, 15548, 1411, 8289, 13568, 15582, 12090, 10138]

# california
features += [15691, 6427, 5388, 10576, 6367, 10164, 15601, 10200]

# canada
features += [8048, 8030, 3427, 3167]

# china
features += [5191, 11403, 280, 12017]

# taiwan
features += [5751]

# hong kong
features += [7539]

# Russia
features += [5191, 9765, 10500]

# poland
features += [13867]

# texas
features += [1010, 7389, 14089]

# europe
features += [11255]



In [9]:
import numpy as np

feature_dict = {}

for feature in tqdm(features):
    ft_dist = get_feature_acts(gen(feature), 64)
    diff = ft_dist - baseline_dist
    feature_dict[feature] = diff.cpu().numpy()  # Convert to NumPy array

# Convert the dictionary to a NumPy array
feature_array = np.array(list(feature_dict.values()))

# Save the NumPy array
np.save('feature_effects.npy', feature_array)

print("Feature effects saved as 'feature_effects.npy'")

# If you also want to save the feature IDs:
np.save('feature_ids.npy', np.array(features))
print("Feature IDs saved as 'feature_ids.npy'")

100%|██████████| 38/38 [12:42<00:00, 20.07s/it]

Feature effects saved as 'feature_effects.npy'
Feature IDs saved as 'feature_ids.npy'





In [16]:
feature_effects = np.load('feature_effects.npy')
feature_ids = np.load('feature_ids.npy')

# To get the effect for a specific feature:
def get_feature_effects(feature):
    feature_index = np.where(feature_ids == feature)[0][0]
    return feature_effects[feature_index]

def get_feature_scores(feature_effects, features=[]):
    return list(zip(feature_effects[features], features))
        

for point in get_feature_scores(get_feature_effects(5751), features):
    print(point)


(-0.07534706, 4343)
(-0.0028025392, 9043)
(-0.0010869228, 4893)
(-0.010105142, 5823)
(-0.0023378618, 16322)
(-0.0017658686, 15548)
(-0.0056179943, 1411)
(-0.0025995323, 8289)
(-0.0134180905, 13568)
(-0.0028848588, 15582)
(-0.035340704, 12090)
(-0.013917395, 10138)
(-0.014327853, 15691)
(-0.008736513, 6427)
(-0.003220141, 5388)
(-0.010402888, 10576)
(-0.0029031658, 6367)
(-0.006775529, 10164)
(-0.007083768, 15601)
(-0.0029341704, 10200)
(0.0073608505, 8048)
(-0.012369858, 8030)
(-0.0029714354, 3427)
(-0.008502226, 3167)
(0.00030197762, 5191)
(0.14417163, 11403)
(0.092411056, 280)
(0.14542435, 12017)
(0.7932664, 5751)
(0.023558091, 7539)
(0.00030197762, 5191)
(0.0009038579, 9765)
(-0.00429634, 10500)
(-0.0021733611, 13867)
(-0.0028652225, 1010)
(-0.0018561237, 7389)
(-0.0018383735, 14089)
(0.18948841, 11255)


In [18]:
positive_effects = [(effect, feature) for effect, feature in get_feature_scores(get_feature_effects(5751), features) if effect < 0]
sorted_positive_effects = sorted(positive_effects, key=lambda x: x[0], reverse=False)

for effect, feature in sorted_positive_effects:
    print(f"({effect:.6f}, {feature})")

(-0.075347, 4343)
(-0.035341, 12090)
(-0.014328, 15691)
(-0.013917, 10138)
(-0.013418, 13568)
(-0.012370, 8030)
(-0.010403, 10576)
(-0.010105, 5823)
(-0.008737, 6427)
(-0.008502, 3167)
(-0.007084, 15601)
(-0.006776, 10164)
(-0.005618, 1411)
(-0.004296, 10500)
(-0.003220, 5388)
(-0.002971, 3427)
(-0.002934, 10200)
(-0.002903, 6367)
(-0.002885, 15582)
(-0.002865, 1010)
(-0.002803, 9043)
(-0.002600, 8289)
(-0.002338, 16322)
(-0.002173, 13867)
(-0.001856, 7389)
(-0.001838, 14089)
(-0.001766, 15548)
(-0.001087, 4893)


In [21]:
get_feature_effects(5751)

array([-1.8843119e-03, -2.5397148e-03,  0.0000000e+00, ...,
        0.0000000e+00,  2.3324606e-03, -5.4859228e-09], dtype=float32)

In [25]:
import numpy as np
import plotly.graph_objects as go

# Get the feature effects for feature 5751
effects = get_feature_effects(10138)

# Create an array of indices
indices = np.arange(len(effects))

# Sort the effects and indices together
sorted_indices = np.argsort(effects)[::-1]
sorted_effects = effects[sorted_indices]
sorted_original_indices = indices[sorted_indices]

# Create the bar chart
fig = go.Figure(data=[go.Bar(
    x=np.arange(len(sorted_effects)),
    y=sorted_effects,
    hovertext=[f"Original Index: {idx}, Effect: {effect:.6f}" for idx, effect in zip(sorted_original_indices, sorted_effects)],
    hoverinfo='text'
)])

# Update the layout
fig.update_layout(
    title='Feature Effects for Feature 5751 (Sorted)',
    xaxis_title='Sorted Position',
    yaxis_title='Effect Magnitude',
    width=1200,
    height=600
)

# Show the plot
fig.show()