In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses

import plotly.express as px

# Inference-only notebook: disable autograd globally. The trailing `;` stops
# Jupyter from echoing the context-manager object that set_grad_enabled returns.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f28dc4d5090>

In [2]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Layer-12 residual-stream SAE from the gemma-2b-res-jb release, with its
# decoder normalised, plus an ActivationsStore built from the SAE's config.
hp12 = "blocks.12.hook_resid_post"
sae12 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp12)
normalise_decoder(sae12)
activation_store = ActivationsStore.from_config(model, sae12.cfg)

# Rebind sae12 to the moved module. The original assigned the result of
# `.to(device)` to `sae6`, a copy-paste misnomer: this is the layer-12 SAE.
# `sae6` is kept only as a backward-compatible alias for any cell still using it.
sae12 = sae12.to(device)
sae6 = sae12

Resolving data files:   0%|          | 0/23032 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/23032 [00:00<?, ?it/s]

In [None]:
# Compute a normalisation scale factor following Anthropic's April 2024 update:
# https://transformer-circuits.pub/2024/april-update/index.html#training-saes
# Scale activations so that their expected L2 norm equals sqrt(d_model).


In [27]:
# Tokenise c4-code-20k into 128-token rows, shuffle with a fixed seed, and batch.
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = tutils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=128,
).shuffle(seed=42)
loader = DataLoader(tokenized_data, batch_size=4)

In [23]:
# Estimate the mean per-token L2 norm of the activations at hp12 over a fixed
# number of batches drawn from `loader`.
N_NORM_BATCHES = 21  # same sample as the original `if i == 20: break` loop (batches 0..20)

# hp12 is a resid_post hook of block 12, so only blocks 0..12 need to run;
# the remaining blocks and the unembed were computed before but never used.
stop_layer = int(hp12.split(".")[1]) + 1

norm_sum = 0.0
n_tokens = 0
for _, batch in zip(range(N_NORM_BATCHES), loader):
    _, cache = model.run_with_cache(
        batch["tokens"],
        prepend_bos=False,
        names_filter=hp12,
        stop_at_layer=stop_layer,
    )
    # One L2 norm per token position; computed in float32 in case the model
    # runs in reduced precision.
    token_norms = cache[hp12].float().norm(dim=-1)
    # Token-weighted accumulation stays correct even if a batch is short,
    # unlike averaging per-batch means.
    norm_sum += token_norms.sum().item()
    n_tokens += token_norms.numel()

norm_average = norm_sum / n_tokens

In [24]:
norm_average  # mean L2 norm of the hp12 (blocks.12.hook_resid_post) activations over the sampled batches

208.02650451660156

In [25]:
# Target average activation norm from the Anthropic recipe: sqrt(d_model).
target_norm = pow(model.cfg.d_model, 0.5)
target_norm

45.254833995939045

In [26]:
# Multiplying activations by `scale` rescales their average L2 norm from
# norm_average to target_norm (sqrt(d_model)).
scale = target_norm / norm_average
scale

0.21754359667341083

In [None]:
# This scale factor is now incorporated into normalise_decoder in steering/utils.py.