In [15]:
import os
import sys
# Add the parent of the current working directory to the import path.
# NOTE(review): presumably this is what lets the local `steering` package
# imported below resolve — confirm against the repository layout.
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import evaluate_completions
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Switch autograd off globally, so every later cell runs without gradient tracking.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x179d83cd0>

In [2]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained "gemma-2b" checkpoint as a HookedTransformer on that device.
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.01s/it]


Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Hook-point name; it doubles as the `sae_id` passed to the loader below.
hp6 = "blocks.6.hook_resid_post"

# Load the pretrained SAE on the CPU first. Only the first element of the
# returned 3-tuple is kept; the other two are discarded.
sae6, _, _ = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = hp6, # won't always be a hook point
    device = 'cpu'
)

# Move the SAE onto the target `device`.
sae6 = sae6.to(device)
# NOTE(review): presumably rescales the decoder weights in place (the return
# value is ignored) — confirm against steering.utils.normalise_decoder.
normalise_decoder(sae6)

In [4]:
# Candidate steering directions: each one is the entry of the SAE decoder
# matrix `W_dec` at a single feature index. The trailing comments are the
# author's labels for those features.
# intelligence = sae6.W_dec[10351]   # intelligence and genius
writing = sae6.W_dec[1058]  # writing
anger = sae6.W_dec[1062]  # anger
london = sae6.W_dec[10138]  # London
wedding = sae6.W_dec[8406]  # wedding
broad_wedding = sae6.W_dec[2378] # broad wedding

In [25]:
# Steering strengths to sweep: every integer from 0 to 119 inclusive.
scales = list(range(120))
# Prompt handed to the generation and evaluation helpers.
prompt = "I think"

# Criterion the completions are scored against. Keep exactly one of the
# alternatives below active, matching the steering vector under test.
eval_criterion = "Mentions writing or anything related to writing" # writing
# eval_criterion = "Mentions London or anything related to London" # London
# eval_criterion = "Mentions wedding or anything related to wedding" # wedding

# Second criterion, intended to be scored alongside `eval_criterion`.
coherence_criterion = "Text is coherent, the grammar is correct."

In [33]:
def rate(scale, steering_vector, *, criterion=None, prompt_text=None,
         max_new_tokens=25, batch_size=64, n_samples=128):
    """Generate steered completions and score them on two criteria.

    The residual stream at hook point ``hp6`` is patched with
    ``steering_vector`` at strength ``scale`` while ``model`` generates
    completions. The same completions are then passed to
    ``evaluate_completions`` twice: once with the target criterion and once
    with the notebook-level ``coherence_criterion``.

    Args:
        scale: Steering strength forwarded to ``patch_resid``.
        steering_vector: Direction forwarded to ``patch_resid`` as ``steering``.
        criterion: Target criterion text. ``None`` (default) uses the
            notebook-level ``eval_criterion``.
        prompt_text: Prompt used for generation and evaluation. ``None``
            (default) uses the notebook-level ``prompt``.
        max_new_tokens: Forwarded to ``generate``.
        batch_size: Forwarded to ``generate``.
        n_samples: Forwarded to ``generate``.

    Returns:
        ``(scores, coherence_scores)``: two lists holding the ``'score'``
        field of every item returned by the respective evaluation call.
    """
    # Resolve the notebook-level defaults at call time so that edits to the
    # configuration cell are picked up without redefining this function.
    if criterion is None:
        criterion = eval_criterion
    if prompt_text is None:
        prompt_text = prompt

    steering_hook = partial(patch_resid, steering=steering_vector, scale=scale)
    texts = generate(
        model,
        hooks=[(hp6, steering_hook)],
        max_new_tokens=max_new_tokens,
        prompt=prompt_text,
        batch_size=batch_size,
        n_samples=n_samples,
    )

    # Local names deliberately avoid `eval`, which would shadow the builtin.
    criterion_ratings = evaluate_completions(
        texts, criterion=criterion, prompt=prompt_text, verbose=False
    )
    coherence_ratings = evaluate_completions(
        texts, criterion=coherence_criterion, prompt=prompt_text, verbose=False
    )
    scores = [rating['score'] for rating in criterion_ratings]
    coherence_scores = [rating['score'] for rating in coherence_ratings]
    return scores, coherence_scores

In [None]:
# Per-scale means, appended in the same order as `scales`.
avg_scores = []
avg_coherence = []

# One `rate` call per steering strength, always with the `writing` vector.
for scale in tqdm(scales):
    scores, coherence = rate(scale, writing)
    # Plain arithmetic mean; raises ZeroDivisionError if `rate` returns an empty list.
    avg_scores.append(sum(scores) / len(scores))
    avg_coherence.append(sum(coherence) / len(coherence))



In [None]:
# Mean criterion score at each steering scale for the "writing" vector.
# With raw sequences (no DataFrame) Plotly names the axes "x" and "y";
# `labels` replaces those with descriptive axis titles.
px.line(
    x=scales,
    y=avg_scores,
    title="Writing steering vector",
    labels={"x": "Steering scale", "y": "Mean criterion score"},
)