In [23]:
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities

from functools import partial
from datasets import load_dataset
from tqdm import tqdm
from jaxtyping import Float

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# from steering.eval_utils import evaluate_completions
from steering.utils import get_activation_steering, get_sae_diff_steering, remove_sae_feats, text_to_sae_feats, top_activations
from steering.preview import preview_next_step, generate

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x148879110>

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HookedTransformer.from_pretrained('gpt2-small', device=device)



Loaded pretrained model gpt2-small into HookedTransformer


In [25]:
hp_6 = tutils.get_act_name("resid_pre", 6)
sae_6 = get_gpt2_res_jb_saes(hp_6)[0][hp_6]
sae_6 = sae_6.to(model.W_E.device)

100%|██████████| 1/1 [00:01<00:00,  1.13s/it]


In [26]:
sae_feats = text_to_sae_feats(model, sae_6, hp_6, "I am so happy")
top_v, top_i = top_activations(sae_feats, 10)

print(top_i)
print(top_v[0,-1])

# anger feature = (10131, 28.0792), 6415
# happy feature = (20985, 32.62), (9995, 12.4177)

tensor([[[23123,   979,   316,  7496, 23111, 23373,  9088, 16196,  2039, 10423],
         [23409, 19151,  4422,  6144, 21687, 11355, 13648,  1781, 21952,  1622],
         [18490, 19117,  1622,  7574,   144, 21060, 15396,  1738, 14511, 19151],
         [ 4003, 23672,  2312,  7574,  1622,   396, 14732, 15396, 19136, 24191],
         [20985,  9995, 21393,  4492,  8120,  1738,  7574,  4512, 24191, 19136]]])
tensor([32.6204, 12.4177, 12.2635,  8.3271,  6.1568,  6.1488,  5.6456,  4.6021,
         3.0081,  2.9267])


In [27]:
def patch_position(
    value: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    steering_vectors: Float[torch.Tensor, "num d_model"],
    activations: Float[torch.Tensor, "num"],
    c: float,
    position: int
    
) -> Float[torch.Tensor, "batch pos d_head"]:
    # add all feature_vectors to vector
    for i, steering_vector in enumerate(steering_vectors):
      value[:, position, :] += steering_vector * activations[i] * c
    return value

In [32]:
def generate(
    model: HookedTransformer, prompt: str, fwd_hooks=[], n_samples=5, max_length=20
):
    gen_texts = []
    with model.hooks(fwd_hooks=fwd_hooks):
        for _ in tqdm(range(n_samples)):
            output = model.generate(
                prompt,
                prepend_bos=True,
                use_past_kv_cache=False,
                max_new_tokens=max_length,
                verbose=False,
            )
            gen_texts.append(output)
    return gen_texts


In [28]:
features = [[10131, 28.0792]]
steering_vectors = [sae_6.W_dec[feature[0]] for feature in features]

In [29]:
vocab_index = model.tokenizer.encode("happy")[0]
print(vocab_index)

# anger: 2564
# happy: 34191

34191


In [33]:
hook = (
    hp_6,
    partial(
        patch_position, 
        steering_vectors=steering_vectors, 
        activations=[feature[1] for feature in features], 
        c=10,
        position=-2
    )
)

fwd_hooks = [hook]
prompt = "I can't believe how"
print(["<endoftext>"] + list(model.tokenizer.tokenize(prompt)))

generate(model, prompt, fwd_hooks=fwd_hooks, n_samples=3)


['<endoftext>', 'I', 'Ġcan', "'t", 'Ġbelieve', 'Ġhow']


100%|██████████| 3/3 [00:09<00:00,  3.09s/it]


["I can't believe how good it went down there. Just when I finally thought the woman working in the hall might be having",
 "I can't believe how much I'm not actually interested in… to speak of which I feel is wholly appropriate for… shit",
 "I can't believe how many times I've gotten sunburns, sweating on my own, thinking that I may have to"]

In [34]:
preview_next_step(model, prompt, fwd_hooks=fwd_hooks, watch_logits=[2564, 34191])

Positions,Token,Act
1,much,14.461168
2,many,14.322128
3,little,13.198481
4,long,12.777117
5,badly,12.76373
6,bad,12.615878
7,the,12.552763
8,this,12.508451
9,stupid,12.288016
10,hard,12.170626


Positions,Token,Act
39592,anger,-3.15961
34434,happy,-1.718482
