In [54]:
import os
import sys
import json
sys.path.append(os.path.abspath('..'))

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid, get_loss, scores_2d, scores_clamp_2d

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random


## Train an Almost LoRA

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device, dtype=torch.float16)


Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [9]:
hp6 = "blocks.6.hook_resid_post"

In [55]:
def lora_hook(resid, hook):
    global A, B
    resid = resid.to(torch.float32)

    mid = resid @ A
    # print(mid)
    lora_out = mid @ B
    resid = resid + lora_out

    resid = resid.to(torch.float16)
    return resid

In [21]:
with open("good_anger_wedding.json") as f:
    data = json.load(f)

In [58]:
r = 3
A = nn.Parameter(0.1*torch.randn((model.cfg.d_model, r), device=device), requires_grad=True)
B = nn.Parameter(0.1*torch.randn((r, model.cfg.d_model), device=device), requires_grad=True)
optimizer = optim.Adam([A, B], lr=0.002)
criterion = nn.CrossEntropyLoss()

for epoch in range(5):
    total_loss = 0
    random.shuffle(data)
    for i, text in enumerate(data):
        optimizer.zero_grad()

        with model.hooks(fwd_hooks=[(hp6, lora_hook)]):
            loss = model(text, return_type="loss", loss_per_token=True)
        loss = loss[:3].mean() # don't include bos and prompt
        
        loss.backward()
        optimizer.step()
        # print(f"Epoch {epoch}, iter {i}, loss {loss.item()}")
        total_loss += loss.item()
    print(f"Epoch {epoch}, avg loss {total_loss / len(data)}")

Epoch 0, avg loss 3.0414091042536535
Epoch 1, avg loss 2.7474087454331944
Epoch 2, avg loss 2.6638280434498958
Epoch 3, avg loss 2.6125860353601253
Epoch 4, avg loss 2.570153477296451


In [None]:
# un-lora'd loss is 3.66

In [72]:
# save the A and B matrices
# torch.save(A, "A.pt")
# torch.save(B, "B.pt")

A = torch.load("A.pt")
B = torch.load("B.pt")

In [60]:
sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
normalise_decoder(sae6, scale_input=False)
sae6 = sae6.to(device)

In [82]:
anger = sae6.W_dec[1062]
wedding = sae6.W_dec[8406]

anger_enc = sae6.W_enc[:, 1062]
wedding_enc = sae6.W_enc[:, 8406]

combo = anger + wedding
combo = combo / combo.norm()

In [79]:
Bn = B / B.norm(dim=-1, keepdim=True)
An = A / A.norm(dim=0, keepdim=True)

tensor([5.1754, 5.2067, 5.3686], device='cuda:0',
       grad_fn=<LinalgVectorNormBackward0>)

In [80]:
print(Bn @ anger)
print(Bn @ wedding)
print(Bn @ combo)

tensor([-0.2527,  0.1514,  0.3097], device='cuda:0', grad_fn=<MvBackward0>)
tensor([-0.2037,  0.0928,  0.2351], device='cuda:0', grad_fn=<MvBackward0>)
tensor([-0.2649,  0.1417,  0.3162], device='cuda:0', grad_fn=<MvBackward0>)


In [83]:
print(An.T @ anger_enc)
print(An.T @ wedding_enc)

tensor([ 0.0830, -0.0111,  0.0172], device='cuda:0', grad_fn=<MvBackward0>)
tensor([ 0.1281, -0.1181,  0.1041], device='cuda:0', grad_fn=<MvBackward0>)


In [92]:
gens = []
with torch.no_grad():
    for i in range(50):
        with model.hooks(fwd_hooks=[(hp6, lora_hook)]):
            gens.append(model.generate("I think", max_new_tokens=25, top_k=50, top_p=0.3, use_past_kv_cache=False, verbose=False))
print(gens)

["I think it's safe to say that the anniversary of the civil war is still being talked about!\n\nI was so angry and", 'I think that the main problem with this is that it was over the top and you wouldn the bride because it was at night.\n\n', 'I think it is a pretty serious moment, when a girl is angry over her husband. The couple are now not able to get together', 'I think the main problem with this scene is that the bride is a complete b***h. I don’s understand why it’', "I think he'd die because he's so upset over his son, you can hardly blame him. It was the only thing", 'I think all the women in this film want their husbands to be more successful at home than they are.\n\nIn this short film,', "I think I may have married a mad woman.\n \nWhen my wife goes off I snap at her. She says I'", 'I think sometimes people can be angry without reason. So when your husband is or the wife is mad, there is probably a lot of', 'I think the people at the wedding are a bit of a joke. We had 100% 

In [94]:

eval_1 = evaluate_completions(gens, criterion="Text is angry/upset/enraged/distressed or is about anger/hate etc.", prompt="I think", verbose=False)
eval_2 = evaluate_completions(gens, criterion="Text mentions wedding/marriage/engagement.", prompt="I think", verbose=False)
coherence = evaluate_completions(gens, criterion="Text is coherent, the grammar is correct.", prompt="I think", verbose=False)
scores_1 = [e['score'] for e in eval_1]
scores_2 = [e['score'] for e in eval_2]
coherence_scores = [e['score'] for e in coherence]

print(scores_1)
print(scores_2)
print(coherence_scores)


[6, 2, 7, 8, 7, 1, 8, 5, 6, 6, 7, 7, 1, 7, 1, 6, 4, 1, 1, 5, 8, 1, 5, 1, 6, 7, 4, 6, 3, 7, 1, 5, 4, 1, 5, 5, 8, 5, 6, 1, 3, 4, 3, 9, 7, 3, 4, 5, 7, 1]
[1, 6, 8, 8, 1, 4, 10, 6, 10, 9, 10, 1, 10, 8, 8, 8, 1, 10, 10, 2, 8, 1, 9, 10, 5, 10, 1, 9, 1, 7, 10, 10, 8, 7, 2, 10, 4, 9, 1, 8, 1, 7, 6, 1, 2, 9, 1, 10, 1, 8]
[5, 4, 7, 4, 5, 8, 5, 7, 4, 3, 3, 6, 6, 8, 3, 3, 4, 6, 3, 9, 4, 3, 3, 3, 3, 6, 6, 5, 4, 2, 4, 3, 3, 4, 4, 3, 3, 6, 5, 3, 6, 3, 5, 2, 5, 7, 3, 6, 7, 4]


In [99]:
rs1 = (np.mean(scores_1) -1)/9
rs2 = (np.mean(scores_2) -1)/9
rc = (np.mean(coherence_scores) -1)/9

print(rs1, rs2, rc)

0.40222222222222226 0.5711111111111111 0.3955555555555555


In [100]:
print('mult', rs1 * rs2 * rc)

mult 0.09086448285322359
