In [1]:
# Looking at what happens to the SAE encoder's feature activations when we inject a scaled decoder vector into the residual stream.

In [2]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

import einops

from sae_lens import SAE
# from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
# from sae_lens import SparseAutoencoder, ActivationsStore

from steering.evals_utils import evaluate_completions, multi_criterion_evaluation
from steering.utils import normalise_decoder
from steering.patch import generate, scores_2d, patch_resid

# from sae_vis.data_config_classes import SaeVisConfig
# from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

# Turn autograd off globally: the cells shown here only run forward passes.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f242f62c820>

In [3]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained gemma-2b weights into a HookedTransformer on that device.
model = HookedTransformer.from_pretrained("gemma-2b", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [4]:
# Hook point whose activations are fed to the SAE.
hp6 = "blocks.6.hook_resid_post"

# Fetch the pretrained SAE for this hook point. Other releases are listed in
# sae_lens/pretrained_saes.yaml; the sae_id equals the hook name here, but it
# won't always be a hook point. Only the first of the three returned values
# is used.
sae6, _, _ = SAE.from_pretrained(release="gemma-2b-res-jb", sae_id=hp6, device="cpu")

# Loaded on the CPU above, then moved to the working device.
sae6 = sae6.to(device)

# Project helper from steering.utils; its return value is discarded, so it
# presumably rescales the decoder in place -- TODO confirm.
normalise_decoder(sae6)

In [18]:
# Candidate steering directions: one row of the SAE decoder per feature index.
# The labels are the notebook author's feature descriptions.
# (Index 10351, "intelligence and genius", is currently not loaded.)
writing, anger, london, wedding, broad_wedding, broad_poetry = (
    sae6.W_dec[feature_idx]
    for feature_idx in (
        1058,   # writing
        1062,   # anger
        10138,  # London
        8406,   # wedding
        2378,   # broad wedding
        11067,  # broad poetry
    )
)

In [29]:
# Direction added to the hooked activations, and the multiplier applied to it
# before it is added.
steer, scale = writing, 100

In [13]:
batch_size = 8

# Raw text corpus.
data = load_dataset("NeelNanda/c4-code-20k", split="train")

# Tokenise into fixed-length rows of 32 tokens, then shuffle with a fixed seed
# so the batches drawn below are reproducible.
tokenized_data = tutils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=32,
).shuffle(seed=42)

loader = DataLoader(tokenized_data, batch_size=batch_size)

In [34]:
# Estimate, for every SAE feature, the mean change in encoder activation caused
# by adding `steer * scale` to the hooked activations, averaged over all token
# positions in the first `n_steps` batches of `loader`.
n_steps = 10

# Running per-feature sum of (steered - unsteered) encoder activations.
diffs = torch.zeros(sae6.cfg.d_sae, device=sae6.W_dec.device)
# Number of token positions actually accumulated. Counting them here, rather
# than assuming n_steps * seq_len * batch_size, keeps the mean correct if the
# loader settings change, the loader runs out early, or a batch is short.
n_tokens = 0

for batch_idx, batch in enumerate(loader):

    # Cache only the activations at the hook point the SAE reads from.
    _, cache = model.run_with_cache(batch['tokens'], names_filter=hp6)
    acts = cache[hp6]
    # Flatten every leading dimension so each row is one token position.
    acts = acts.reshape(-1, acts.shape[-1])

    pre_distribution = sae6.encode(acts)
    post_distribution = sae6.encode(acts + steer*scale)

    diffs += (post_distribution - pre_distribution).sum(dim=0)
    n_tokens += acts.shape[0]

    if batch_idx >= n_steps - 1:
        break

# Turn the sum into a mean; the guard avoids 0/0 when the loader is empty.
diffs /= max(n_tokens, 1)

In [35]:
# The ten features whose mean activation increased the most under steering:
# indices first, then the corresponding mean increases.
top_v, top_i = diffs.topk(k=10, dim=-1)
for shown in (top_i, top_v):
    print(shown)


tensor([ 1058, 13838,  9203, 11338,  7050,  1988,   650, 10345,  2151,  1852],
       device='cuda:0')
tensor([94.1977,  5.8552,  4.6552,  2.8302,  2.3867,  2.0325,  1.8691,  1.7715,
         1.7508,  1.5424], device='cuda:0')


In [32]:
# The ten features whose mean activation decreased the most under steering.
# The top-k is taken over the negated diffs, so bottom_v holds negated values:
# a positive entry is the size of the decrease.
bottom_v, bottom_i = diffs.neg().topk(k=10, dim=-1)
for shown in (bottom_i, bottom_v):
    print(shown)

tensor([14376, 16277, 10971,  8322, 12520,  3659,  6037, 16150,  2229,  9278],
       device='cuda:0')
tensor([1.0137, 0.6913, 0.4905, 0.3307, 0.3178, 0.3046, 0.3045, 0.2791, 0.2618,
        0.2464], device='cuda:0')
