In [19]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder, text_to_sae_feats, top_activations
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

# The cells below only run forward passes / generation, so turn autograd off globally
# to avoid building graphs. The trailing `;` suppresses the context-manager repr that
# this call would otherwise print as the cell's output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb3bc6220b0>

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load gemma-2b as a TransformerLens HookedTransformer on that device.
model = HookedTransformer.from_pretrained(
    "gemma-2b",
    device=device,
)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
# Hook point patched by the steering experiments below (residual stream after block 6).
hp6 = "blocks.6.hook_resid_post"

# Fetch the matching SAE from the "gemma-2b-res-jb" release; other releases are
# listed in sae_lens/pretrained_saes.yaml. In this release the sae_id equals the
# hook name, which is not true for every release. Only the SAE object is kept.
sae6, _, _ = SAE.from_pretrained(
    release="gemma-2b-res-jb",
    sae_id=hp6,
    device="cpu",
)

# Loaded on CPU first, then moved to the model's device.
sae6 = sae6.to(device)

# NOTE(review): the return value is ignored, so normalise_decoder presumably
# rescales sae6's decoder in place — confirm in steering.utils.
normalise_decoder(sae6)

In [4]:
# Decoder directions (rows of sae6.W_dec) for hand-picked layer-6 features, used as
# steering vectors. The labels are the authors' interpretations of each feature.
writing, anger, london, wedding, broad_wedding = (
    sae6.W_dec[feature_id]
    for feature_id in (
        1058,   # writing
        1062,   # anger
        10138,  # London
        8406,   # wedding
        2378,   # broad wedding
    )
)

In [5]:
# Generate 16 continuations of "I think" while steering the layer-6 residual stream
# with the London feature direction.
# Sampling is stochastic (identical prompts give different completions), so seed the
# global torch RNG to make re-running this cell reproduce the same texts.
torch.manual_seed(0)

# 16 copies of the same prompt (expand returns a view; no data is copied).
prompt_batch = model.to_tokens("I think", prepend_bos=True).expand(16, -1)

steer_london = partial(patch_resid, steering=london, scale=50)
with model.hooks([(hp6, steer_london)]):
    generated_tokens = model.generate(
        prompt_batch,
        use_past_kv_cache=True,
        max_new_tokens=29,
        top_k=50,
        top_p=0.3,
    )
texts = model.to_string(generated_tokens)
texts

  0%|          | 0/29 [00:00<?, ?it/s]

["<bos>I think the biggest challenge we have in today is when we're facing a global pandemic and in particular, in the United Kingdom, we've started",
 '<bos>I think you should check out the following site, it has a forum with hundreds of posts on just this topic.\n\nhttp://forums.bmwusa.',
 "<bos>I think I'll go with the black paint to go with my black rims\nWhat a fantastic looking car this is.\nI'm so fond",
 "<bos>I think it's ok for a girl!\n\nI liked to wear a dress, but i think that for a girl is a bit boring\n\nI",
 '<bos>I think it is a reasonable thing to be very worried about, especially if any other relatives who have the disease. One should discuss with the doctor and have',
 '<bos>I think these guys were the ones who went out and hired the band for the party. They were very professional. They took the hassle out a whole night',
 "<bos>I think they’d take them to court and get the money back\nI hope they'll be put to work in a meat factory\nI'",
 '<bos>I think this place i

In [6]:
# Same London-steered generation as above, but with the model temporarily cast to
# float16. The original precision is restored in `finally`, so a failure inside
# generation cannot leave the rest of the notebook running on a half-precision model.
original_dtype = next(model.parameters()).dtype

prompt_batch = model.to_tokens("I think", prepend_bos=True).expand(16, -1)

# Seed so the stochastic sampling is reproducible on re-run.
torch.manual_seed(0)
model.to(torch.float16)
try:
    # The steering vector must match the model's dtype while it is in float16.
    steer_london_fp16 = partial(patch_resid, steering=london.to(torch.float16), scale=50)
    with model.hooks([(hp6, steer_london_fp16)]):
        generated_tokens = model.generate(
            prompt_batch,
            use_past_kv_cache=True,
            max_new_tokens=29,
            top_k=50,
            top_p=0.3,
        )
    texts = model.to_string(generated_tokens)
finally:
    model.to(original_dtype)

texts

Changing model dtype to torch.float16


  0%|          | 0/29 [00:00<?, ?it/s]

Changing model dtype to torch.float32


["<bos>I think you'll have to leave the tubes with the lights still inside or the tube will come out on its own and hit the floor . If you",
 '<bos>I think I’m going to start off today by being very dramatic and emotional: in July London I saw the new collection from Vivienne Westwood at the',
 "<bos>I think there were two different times when our nation reached the breaking point, both during the pandemic and last year's racial uprising. There was one year",
 '<bos>I think at this point we are all aware of the popularity of the London look. This was a huge hit last summer during the Olympics, then in the',
 '<bos>I think the 034 and 58 both come with electric fans. What was on mine?\n\n\nThere are two fans fitted to the standard',
 '<bos>I think it’s fair to say that most people will not have heard of the West Coast of Britain; a stretch of land known around the world as',
 '<bos>I think I am ready to be on the other side of the planning experience with my wedding.\n\nLondon is callin

In [81]:
# Completion task: each prompt stops mid-contraction, and the target is the single
# token that completes it (e.g. "I think you'" -> "re"). Pairing them explicitly
# keeps prompts and targets aligned.
prompt_targets = [
    ("I think you'", "re"),
    ("If she won'", "t"),
    ("I think that'", "s"),
    ("I think it'", "s"),
    ("If you couldn'", "t"),
]
prompts = [p for p, _ in prompt_targets]
prompt = prompts[0]  # NOTE(review): unused by the cells shown; kept in case later cells read it

prompt_tokens = model.to_tokens(prompts, prepend_bos=True)
correct_ids = model.to_tokens([t for _, t in prompt_targets], prepend_bos=False)

# Logits are read at position -1 for every row. That is only the last real token if
# no prompt was padded, so check it instead of silently reading a pad position.
assert (prompt_tokens[:, -1] != model.tokenizer.pad_token_id).all(), \
    "prompts tokenize to different lengths; [:, -1] would read padding"
# Each target must be exactly one token; a multi-token target would add columns here.
assert correct_ids.shape == (len(prompts), 1), "every target must be a single token"

# Baseline: unsteered probability of the correct completion token for each prompt.
baseline_logits = model(prompt_tokens)[:, -1, :]
baseline = F.softmax(baseline_logits, dim=-1)[torch.arange(len(prompts)), correct_ids[:, 0]]

baseline

tensor([[   478],
        [235251],
        [235256],
        [235256],
        [235251]], device='cuda:0')
torch.Size([5])


tensor([0.5197, 0.9987, 0.9917, 0.9746, 0.9997], device='cuda:0')

In [83]:
def eval(ft_id, scale=80, hook_point=None, sae=None):
    """Return P(correct next token) for each prompt while steering with one SAE feature.

    NOTE(review): this shadows Python's builtin ``eval`` for the rest of the kernel
    session. It is not renamed because later cells call it by this name.

    Args:
        ft_id: index of the feature, i.e. the row of the SAE decoder ``W_dec`` used as
            the steering direction.
        scale: forwarded to ``patch_resid`` as ``scale`` (its exact meaning is defined
            in steering.patch).
        hook_point: hook name to patch; defaults to the notebook global ``hp6``.
        sae: SAE whose decoder supplies the direction; defaults to ``sae6``.

    Returns:
        1-D tensor with one entry per row of ``prompt_tokens``. Each entry is the softmax
        probability of the matching ``correct_ids`` token at the final position.

    Reads the notebook globals ``model``, ``prompt_tokens`` and ``correct_ids``. It
    assumes no row of ``prompt_tokens`` is padded, since position -1 is read.
    """
    hook_point = hp6 if hook_point is None else hook_point
    sae = sae6 if sae is None else sae

    steer = partial(patch_resid, steering=sae.W_dec[ft_id], scale=scale)
    with model.hooks([(hook_point, steer)]):
        last_logits = model(prompt_tokens)[:, -1, :]

    # Only the forward pass needs the steering hook; post-processing happens outside it.
    probs = F.softmax(last_logits, dim=-1)
    rows = torch.arange(prompt_tokens.shape[0])
    return probs[rows, correct_ids[:, 0]]

eval(10138, scale=100)  # feature 10138 is labelled "London" where the steering vectors are defined

tensor([3.1380e-03, 9.3456e-05, 6.9293e-01, 8.4291e-01, 6.0839e-05],
       device='cuda:0')

In [84]:
# Sweep the steering scale for feature 10138 (London) and plot P(correct token) per prompt.
scales = range(120)
evals_for_scales = [eval(10138, scale=scale) for scale in scales]
# Rows = scale, columns = prompt; moved to CPU so plotly receives plain floats.
probs_by_scale = pd.DataFrame(
    torch.stack(evals_for_scales).cpu().numpy(),
    index=pd.Index(scales, name="scale"),
    columns=pd.Index(prompts, name="prompt"),
)
px.line(
    probs_by_scale,
    title="Feature 10138 (London): P(correct token) vs steering scale",
    labels={"value": "P(correct token)"},
)

In [85]:
# Sweep the steering scale for feature 1062 (anger) and plot P(correct token) per prompt.
scales = range(120)
evals_for_scales = [eval(1062, scale=scale) for scale in scales]
# Rows = scale, columns = prompt; moved to CPU so plotly receives plain floats.
probs_by_scale = pd.DataFrame(
    torch.stack(evals_for_scales).cpu().numpy(),
    index=pd.Index(scales, name="scale"),
    columns=pd.Index(prompts, name="prompt"),
)
px.line(
    probs_by_scale,
    title="Feature 1062 (anger): P(correct token) vs steering scale",
    labels={"value": "P(correct token)"},
)

In [86]:
# Sweep the steering scale for feature 8406 (wedding) and plot P(correct token) per prompt.
scales = range(120)
evals_for_scales = [eval(8406, scale=scale) for scale in scales]
# Rows = scale, columns = prompt; moved to CPU so plotly receives plain floats.
probs_by_scale = pd.DataFrame(
    torch.stack(evals_for_scales).cpu().numpy(),
    index=pd.Index(scales, name="scale"),
    columns=pd.Index(prompts, name="prompt"),
)
px.line(
    probs_by_scale,
    title="Feature 8406 (wedding): P(correct token) vs steering scale",
    labels={"value": "P(correct token)"},
)

In [87]:
# Sweep the steering scale for feature 2378 (broad wedding) and plot P(correct token) per prompt.
scales = range(120)
evals_for_scales = [eval(2378, scale=scale) for scale in scales]
# Rows = scale, columns = prompt; moved to CPU so plotly receives plain floats.
probs_by_scale = pd.DataFrame(
    torch.stack(evals_for_scales).cpu().numpy(),
    index=pd.Index(scales, name="scale"),
    columns=pd.Index(prompts, name="prompt"),
)
px.line(
    probs_by_scale,
    title="Feature 2378 (broad wedding): P(correct token) vs steering scale",
    labels={"value": "P(correct token)"},
)