In [53]:
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities

from functools import partial
from datasets import load_dataset
from tqdm import tqdm
from jaxtyping import Float

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# from steering.eval_utils import evaluate_completions
from steering.utils import get_activation_steering, get_sae_diff_steering, remove_sae_feats, text_to_sae_feats, top_activations
from steering.preview import preview_next_step

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x129a29890>

In [54]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HookedTransformer.from_pretrained('gpt2-small', device=device)



Loaded pretrained model gpt2-small into HookedTransformer


In [55]:
hp_6 = tutils.get_act_name("resid_pre", 6)
sae_6 = get_gpt2_res_jb_saes(hp_6)[0][hp_6]
sae_6 = sae_6.to(model.W_E.device)

100%|██████████| 1/1 [00:01<00:00,  1.13s/it]


In [56]:
sae_feats = text_to_sae_feats(model, sae_6, hp_6, "I am so happy")
top_v, top_i = top_activations(sae_feats, 10)

print(top_i)
print(top_v[0,-1])

# anger feature = (10131, 28.0792), 6415
# happy feature = (20985, 32.62), (9995, 12.4177)

tensor([[[23123,   979,   316,  7496, 23111, 23373,  9088, 16196,  2039, 10423],
         [23409, 19151,  4422,  6144, 21687, 11355, 13648,  1781, 21952,  1622],
         [18490, 19117,  1622,  7574,   144, 21060, 15396,  1738, 14511, 19151],
         [ 4003, 23672,  2312,  7574,  1622,   396, 14732, 15396, 19136, 24191],
         [20985,  9995, 21393,  4492,  8120,  1738,  7574,  4512, 24191, 19136]]])
tensor([32.6204, 12.4177, 12.2635,  8.3271,  6.1568,  6.1488,  5.6456,  4.6021,
         3.0081,  2.9267])


In [57]:
def patch_position(
    value: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    steering_vectors: Float[torch.Tensor, "num d_model"],
    activations: Float[torch.Tensor, "num"],
    c: float,
    position: int
    
) -> Float[torch.Tensor, "batch pos d_head"]:
    # add all feature_vectors to vector
    for i, steering_vector in enumerate(steering_vectors):
      value[:, position, :] += steering_vector * activations[i] * c
    return value

In [58]:
features = [[20985, 32.62]]
steering_vectors = [sae_6.W_dec[feature[0]] for feature in features]

In [59]:
vocab_index = model.tokenizer.encode("happy")[0]
print(vocab_index)

# anger: 2564
# happy: 34191

34191


In [27]:
def generate(
    model: HookedTransformer, prompt: str, fwd_hooks=[], n_samples=5, max_length=20
):
    gen_texts = []
    with model.hooks(fwd_hooks=fwd_hooks):
        for _ in tqdm(range(n_samples)):
            output = model.generate(prompt,
                                    prepend_bos=True,
                                    use_past_kv_cache=False,
                                    max_new_tokens=max_length,
                                    verbose=False,
                                    )
            gen_texts.append(output)
    return gen_texts


    


In [52]:
hook = (
    hp_6,
    partial(
        patch_position, 
        steering_vectors=steering_vectors, 
        activations=[feature[1] for feature in features], 
        c=10,
        position=3
    )
)

fwd_hooks = [hook]
prompt = "I feel very"

generate(model, prompt, fwd_hooks=fwd_hooks)

100%|██████████| 5/5 [00:14<00:00,  2.92s/it]


['I feel very ending with my students getting an education. They got education on legal access, started school brings them to',
 "I feel very at ease. I'll eat mine\n\nand someday he'll be eating my cake or mixed with",
 'I feel very to be yelling that I am wearing the WWEwwkk from Wrestlemania XXII. Those WWE fans',
 'I feel very and proud to have worked hard for a lifelong career in film and television development as a Sundance Centre',
 'I feel very with one arm when I breastfeed and sound like a buzzing bear. To caregivers, my fetus looks']

In [32]:
preview_next_step(model, prompt, fwd_hooks=fwd_hooks, watch_logits=[2564, 34191])

Positions,Token,Act
1,and,16.932461
2,",",16.372375
3,towards,15.523688
4,at,15.224799
5,ing,15.043599
6,toward,14.765697
7,directed,13.932105
8,or,13.598452
9,ful,13.59123
10,-,13.402057


Positions,Token,Act
8877,anger,2.680283
