In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px

# No gradients are needed below: disable autograd globally so forward passes build no graphs.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f35304d6050>

In [2]:
# Prefer the GPU when one is visible; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Hook point (by its name, the residual stream after block 6); also used as the SAE id in the release.
hp6 = "blocks.6.hook_resid_post"

# Load the pretrained SAE for that hook point from the "gemma-2b-res-jb" release.
sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
# The return value is discarded, so normalise_decoder presumably rescales sae6 in place — TODO confirm.
normalise_decoder(sae6, scale_input=False)
sae6 = sae6.to(device)

In [4]:
# Steering vector: decoder row 1062 of sae6 (labelled "anger" by the author), scaled by 56,
# with two leading singleton dims added via [None, None, :]
# (presumably so it broadcasts over batch and position — TODO confirm against patch_resid).
# An earlier variant also added sae12.W_dec[12312] * 10; sae12 is not loaded in this notebook.
steering = (sae6.W_dec[1062] * 56)[None, None, :]

In [5]:
# Steering multipliers: 0 (no steering), then 0.5 through 5.0 in steps of 0.5.
scales = [0] + [step / 2 for step in range(1, 11)]

In [12]:
# How much next-token probability mass lies on the `top_n` (default 100) most likely tokens?

def in_top(logits, top_n=100):
    """Average probability mass captured by the ``top_n`` most likely entries.

    A softmax is taken over the last dimension of ``logits`` (presumably the
    vocabulary). At every remaining index (e.g. each batch element and position),
    the ``top_n`` largest probabilities are summed, and the mean of those sums
    is returned.

    Args:
        logits: Tensor of unnormalised scores, with the distribution on the last dim.
        top_n: Number of highest-probability entries to include. Values larger
            than the size of the last dimension are clamped to it.

    Returns:
        A Python float in [0, 1]; 1.0 means all mass lies inside the top ``top_n``.
    """
    # Softmax in float32 so half-precision logits do not lose mass to rounding.
    probs = torch.softmax(logits.float(), dim=-1)
    # torch.topk raises if k exceeds the dimension size; clamp instead.
    k = min(top_n, probs.shape[-1])
    top_probs_sum = probs.topk(k, dim=-1).values.sum(dim=-1)
    return top_probs_sum.mean().item()

data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = tutils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
loader = DataLoader(tokenized_data, batch_size=8)

n_batches = 10  # maximum number of batches evaluated per steering scale
# NB: despite its name, this collects the mean top-100 probability mass (from in_top) per scale.
average_non_top = []
for scale in scales:
    print('scale', scale)
    hook = partial(patch_resid,
                   steering=steering,
                   c=scale,
                   pos=None,
                   )
    total = 0
    n_seen = 0
    # The hook is identical for every batch at this scale, so register it once.
    with model.hooks(fwd_hooks=[(hp6, hook)]):
        for batch in loader:
            total += in_top(model(batch["tokens"], return_type='logits', prepend_bos=False))
            n_seen += 1
            if n_seen == n_batches:
                break
    # Divide by the number of batches actually evaluated; this is fewer than
    # n_batches when the loader runs out early. NaN if it yielded nothing.
    mean_top_mass = total / n_seen if n_seen else float('nan')
    print(mean_top_mass)
    average_non_top.append(mean_top_mass)



scale 0
0.9339406132698059
scale 0.5
0.929868471622467
scale 1.0
0.9122872173786163
scale 1.5
0.8877916097640991
scale 2.0
0.8752076029777527
scale 2.5
0.8801261603832244
scale 3.0
0.8905377984046936
scale 3.5
0.8958958446979522
scale 4.0
0.8940239906311035
scale 4.5
0.8866267561912536
scale 5.0
0.8755286931991577


In [13]:
# Unsteered Gemma baseline: top-100 probability mass averaged over the positions of a short prompt.
logits = model("I think", return_type='logits')
print(in_top(logits, top_n=100))

0.8190046548843384


In [14]:
# Reference point: the same top-100 mass metric for GPT-2 small (no steering),
# on the same source dataset tokenized with GPT-2's tokenizer.
gpt2 = HookedTransformer.from_pretrained("gpt2-small", device='cpu')
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = tutils.tokenize_and_concatenate(data, gpt2.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(42)
loader = DataLoader(tokenized_data, batch_size=8)

n_batches = 10  # maximum number of batches evaluated
total = 0
n_seen = 0
for batch in loader:
    total += in_top(gpt2(batch["tokens"], return_type='logits', prepend_bos=False))
    n_seen += 1
    if n_seen == n_batches:
        break
# Average over the batches actually evaluated (fewer than n_batches if the
# loader runs out early); NaN if it yielded nothing.
print(total / n_seen if n_seen else float('nan'))


Loaded pretrained model gpt2-small into HookedTransformer
0.8666989088058472


In [15]:
# GPT-2 small baseline on the same short prompt, for comparison with the Gemma number above.
logits = gpt2("I think", return_type='logits')
print(in_top(logits, top_n=100))

0.725292444229126
