In [1]:
import os
import sys
import json
sys.path.append(os.path.abspath('..'))

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid, get_loss, scores_2d, scores_clamp_2d

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fae2dbdd570>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)


Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"
sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
normalise_decoder(sae6, scale_input=False)
sae6 = sae6.to(device)

In [22]:
feature_names = ["anger", "wedding", "london", "castle", "writing"]
feature_ids = [1062, 8406, 10138, 10473, 1058]

vecs = [sae6.W_enc[:, id] for id in feature_ids]

# just eye-balling coherence*score plots.
# anger 40-70 -> 55
# wedding 50-70 -> 60
# london 50-90? -> 70
# castle 50-80 -> 65
# writing 40-60 -> 50

scales = [55, 60, 70, 65, 50]



In [5]:
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
loader = DataLoader(tokenized_data, batch_size=8)

In [20]:
max_acts = [0] * len(vecs)
# get max activations for each feature
for i, batch in enumerate(loader):
    _, resid = model.run_with_cache(batch["tokens"], prepend_bos=False, names_filter=hp6)
    resid = resid[hp6][:, 1:, :] # no bos
    for vi, vec in enumerate(vecs):
        acts = resid @ vec
        m = acts.max()
        if m > max_acts[vi]:
            max_acts[vi] = m.item()
    # if i >= 100:
    #     break

KeyboardInterrupt: 

In [21]:
print(max_acts)

[53.004146575927734, 83.18545532226562, 83.4236831665039, 37.6751594543457, 70.49939727783203]


In [35]:
fig = px.scatter(x=max_acts, y=scales, labels={"x": "max activation", "y": "optimal norm"}, title="Approx optimal norm vs max activation")
fig.update_traces(marker=dict(symbol='x', size=8, color='blue'))
# Add y=x line
fig.add_shape(
    type="line",
    x0=min(max_acts), y0=min(max_acts),
    x1=max(max_acts), y1=max(max_acts),
    line=dict(color="black", dash="dash")
)
fig.show()