In [1]:
# Looking at what happens to the SAE encoder's activations when we inject a decoder vector into its input.

In [2]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

# Turn autograd off globally; the cells visible in this notebook only run
# forward passes / generation, so no gradients are needed.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fcf716d8400>

In [3]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained gemma-2b wrapped as a HookedTransformer, placed on `device`.
model = HookedTransformer.from_pretrained(
    "gemma-2b",
    device=device,
)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [4]:
# Hook point whose activations are read (and later patched): layer-6 resid_post.
hp6 = "blocks.6.hook_resid_post"

# Load the pretrained SAE on the CPU first; only the first element of the
# returned 3-tuple is kept.
sae6, _, _ = SAE.from_pretrained(
    release="gemma-2b-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id=hp6,  # won't always be a hook point
    device="cpu",
)

# Move the SAE next to the model, then apply the project helper to it.
# NOTE(review): normalise_decoder presumably rescales W_dec in place — confirm in steering.utils.
sae6 = sae6.to(device)
normalise_decoder(sae6)

In [5]:
# Candidate steering vectors: each name is one row of the SAE decoder matrix
# `sae6.W_dec`, selected by feature index. The trailing labels are the author's
# shorthand for what each feature appears to represent.
# intelligence = sae6.W_dec[10351]   # intelligence and genius (currently disabled)
writing = sae6.W_dec[1058]  # feature 1058: writing
anger = sae6.W_dec[1062]  # feature 1062: anger
london = sae6.W_dec[10138]  # feature 10138: London
wedding = sae6.W_dec[8406]  # feature 8406: wedding
broad_wedding = sae6.W_dec[2378] # feature 2378: broad wedding

broad_poetry = sae6.W_dec[11067]  # feature 11067: broad poetry

In [11]:
# Vector that gets added to the `hp6` activations in the encoder comparison.
steer = broad_wedding
# Multiplier applied to `steer` in that comparison (acts + steer*scale).
# NOTE(review): the generation cell passes a literal scale=80 to patch_resid
# instead of this value — confirm the mismatch is intended.
scale = 100

In [7]:
batch_size = 8

# Raw text corpus used to sample residual-stream activations.
data = load_dataset("NeelNanda/c4-code-20k", split="train")

# Tokenize with the model's tokenizer (max_length=32), then shuffle with a
# fixed seed so the same batches are drawn on every run.
tokenized_data = tutils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=32,
).shuffle(seed=42)

loader = DataLoader(tokenized_data, batch_size=batch_size)

In [12]:
# Measure how adding `steer * scale` to the `hp6` activations changes the SAE
# encoder output, averaged over token positions from up to `n_steps` batches.
n_steps = 10  # maximum number of batches drawn from `loader`

# Running per-feature sum of (steered encoding - unsteered encoding).
diffs = torch.zeros(sae6.cfg.d_sae, device=sae6.W_dec.device)
# Count of token positions actually accumulated. Used as the denominator of the
# mean so it stays correct if batch_size, max_length or the number of available
# batches changes (previously hard-coded as n_steps*32*8).
n_tokens = 0

for batch_idx, batch in enumerate(loader):

    # Cache only the hook point we need.
    _, acts = model.run_with_cache(batch['tokens'], names_filter=hp6)
    acts = acts[hp6]
    # Collapse all leading dims so each row is one token position.
    acts = acts.reshape(-1, acts.shape[-1])

    pre_distribution = sae6.encode(acts)
    post_distribution = sae6.encode(acts + steer*scale)

    diff = post_distribution - pre_distribution
    diffs += diff.sum(dim=0)
    n_tokens += acts.shape[0]

    if batch_idx >= n_steps - 1:
        break

# Mean change per token position; max(..., 1) keeps an empty loader from
# producing NaNs (diffs simply stays all-zero in that case).
diffs /= max(n_tokens, 1)

In [13]:
# The ten largest entries of `diffs`: features whose mean activation rose the
# most under steering. Indices are printed first, then the matching values.
top_v, top_i = diffs.topk(k=10, dim=-1)
print(top_i, top_v, sep="\n")


tensor([ 2378,  4115,  1302, 13791,   722, 12110,  7215,  8406, 13882,  9367],
       device='cuda:0')
tensor([92.2852, 11.4228,  9.0107,  8.4798,  6.7220,  6.1105,  5.6003,  5.4764,
         5.0357,  4.9481], device='cuda:0')


In [14]:
# The ten largest entries of -diffs: features whose mean activation fell the
# most under steering. `bottom_v` holds values of the negated tensor, so a
# positive number printed here is the size of the decrease.
bottom_v, bottom_i = diffs.neg().topk(k=10, dim=-1)
print(bottom_i, bottom_v, sep="\n")

tensor([11631, 15417, 15481, 11683,  4857,  3869, 14173,  2813,  7690, 12661],
       device='cuda:0')
tensor([2.7186, 1.8626, 1.5036, 0.8218, 0.7397, 0.5358, 0.4955, 0.4613, 0.4226,
        0.4091], device='cuda:0')


In [16]:
# Sample 32 completions of 'I think' in a single batch, with `patch_resid`
# hooked at `hp6` and given `steer` as the steering vector.
# NOTE(review): the hook uses a literal scale of 80, not the notebook-level `scale`.
texts = generate(
    model,
    hooks=[(hp6, partial(patch_resid, steering=steer, scale=80))],
    prompt='I think',
    n_samples=32,
    batch_size=32,
)
texts

['I think a lot of you have heard this before and its no secret! but my heart honestly just loves the',
 'I think I am SO lucky that our wedding day was a day we got to be able to have- a',
 'I think you have a point that you should add into your photos. A sweet vintage or modern design is a',
 "I think it's a beautiful classic and a must do! It was a picture on their Pinterest board.",
 'I think this is one word that can be described perfectly by any word and that is what we all want.',
 'I think that the new trend is a "no photo booth! " The photos you get is an out side',
 'I think I have always love it when I had a picture of my picture in my instagram album. I love',
 'I think this is an absolutely stunning wedding day at <strong><em>the talented couple blog</em></strong>, all about',
 "I think it's absolutely adorable.\n\nYou girls at your new location and 1 hour! This is",
 'I think this is an amazing quote! I always love the moments when your wedding day comes but I have one',
