In [15]:
import os
import sys
import json
sys.path.append(os.path.abspath('..'))

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, get_loss, scores_2d, scores_clamp_2d

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Turn autograd off globally for the rest of the notebook: nothing visible
# below calls backward(), so gradient tracking is unnecessary overhead.
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x7fa20113ffa0>

In [2]:
# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained "gemma-2b" weights as a TransformerLens HookedTransformer
# placed on `device`.
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [18]:
def patch_resid(resid, hook, steering, c=1, pos=None):
    """Add a scaled steering tensor to ``resid`` in place and return it.

    Args:
        resid: Activation tensor indexed as ``[batch, position, d_model]``.
            It is modified in place.
        hook: Accepted but unused; presumably required by the calling
            convention of whatever invokes this as a hook function.
        steering: 3-D tensor ``[batch (or 1), n_positions, d_model]`` that is
            added to ``resid`` after multiplication by ``c``.
        c: Scalar multiplier applied to ``steering``.
        pos: Where to insert the steering tensor along the position axis.
            ``None`` adds it at every position (``steering`` must then have
            exactly one position so it can broadcast). An integer is the
            first position to patch; negative values count from the end of
            the sequence, as in normal Python indexing. Steering positions
            that would run past the end of ``resid`` are dropped. If the
            start position lies outside the sequence, ``resid`` is returned
            unchanged.

    Returns:
        The same ``resid`` tensor object, after the in-place update.

    Raises:
        AssertionError: If ``steering`` is not 3-D, or if ``pos`` is ``None``
            and ``steering`` has more than one position.
    """
    assert len(steering.shape) == 3  # [batch_size, sequence_length, d_model]

    if pos is None:
        # Insert at all positions: a single steering position broadcasts
        # across the whole sequence.
        assert steering.shape[1] == 1
        resid[:, :, :] = resid[:, :, :] + c * steering
        return resid

    seq_len = resid.shape[1]
    # Normalise negative indices. Without this, a negative ``pos`` produced an
    # empty slice, which either silently did nothing or raised a broadcast
    # error depending on the steering length.
    start = pos + seq_len if pos < 0 else pos
    if not 0 <= start < seq_len:
        # Start is outside the current sequence: nothing to patch.
        return resid

    # Truncate the steering tensor so it never runs past the end of resid.
    n_toks = min(seq_len - start, steering.shape[1])
    stop = start + n_toks
    resid[:, start:stop, :] = resid[:, start:stop, :] + c * steering[:, :n_toks, :]

    return resid

In [4]:
# Name of the hook point the SAE is attached to (by its name, the
# post-block residual stream of block 6).
hp6 = "blocks.6.hook_resid_post"

# Load the pretrained sparse autoencoder for that hook point from the
# "gemma-2b-res-jb" release.
sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
# Project helper from steering.utils; the return value is discarded, so it
# presumably modifies sae6 in place (likely normalising decoder rows) — TODO
# confirm what scale_input=False controls.
normalise_decoder(sae6, scale_input=False)
# Move the SAE to the same device as the model.
sae6 = sae6.to(device)

In [7]:
# Decoder rows for two SAE features, picked by index. The variable names
# suggest feature 1062 relates to anger and 8406 to weddings — verify against
# a feature dashboard.
anger = sae6.W_dec[1062]
wedding = sae6.W_dec[8406]
# Combine both directions with equal weight; 50 is a hand-picked scale.
steer = anger * 50 + wedding * 50
# Add two leading singleton axes so the vector is 3-D, as patch_resid asserts
# (and with a size-1 position axis, as its pos=None branch requires).
steer = steer[None, None, :]

In [19]:
# Produce 10 completions of the prompt "I think" while applying `steer` at
# hook point hp6 through patch_resid (generate is a project helper from
# steering.patch; exact semantics of `scale` not visible here — TODO confirm).
# NOTE(review): no random seed is set in the visible cells, so the sampled
# completions are presumably not reproducible across runs.
gens = generate(model, hp6, "I think", steer, scale=1, patch_fn=patch_resid,n_samples=10)
gens

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:13<00:00,  1.32s/it]


['I think that’s a pretty good and hilarious photo. I just had my anniversary but I’ was so mad. I know I',
 "I think I got sick from a virus last night. Don' was in my voice. I think he was mad that someone ruined their",
 'I think it will come down to the wedding between the couple. And if both are mad it. They will try to push their wedding',
 'I think there’s a lot of people in <em>The Walking Dead</em> and and even more that to the wedding dress.\n\n',
 'I think the one where over the last three years my father died because the wedding bells turned to to the the and at that point he',
 'I think you might be the first person I ever got at, so I was very angry. I will be! The day I wanted',
 "I think it's been a whole year since I last posted in this section!\n\nHow I reacted is pretty. The photos in",
 "I think about my mom.\n\nNot at your wedding (for she's just as upset), but with the anger in their eyes",
 "I think about this stuff sometimes.\n\nI do. I even think about my wife.

In [20]:
# Inspect how the model's tokenizer splits a contraction: the apostrophe comes
# out as its own token, after a prepended <bos>.
model.to_str_tokens("don't do that")

['<bos>', 'don', "'", 't', ' do', ' that']