In [1]:
import os
import sys
import json
sys.path.append(os.path.abspath('..'))

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
from tqdm import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from sae_lens import SparseAutoencoder, ActivationsStore

from steering.eval_utils import evaluate_completions
from steering.utils import text_to_sae_feats, top_activations, normalise_decoder, get_activation_steering
from steering.patch import generate, get_scores_and_losses, patch_resid, get_loss, scores_2d, scores_clamp_2d

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random


## Train an Almost LoRA

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gemma-2b", device=device, dtype=torch.float16)


Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [3]:
hp6 = "blocks.6.hook_resid_post"

In [39]:
def lora_hook(resid, hook):
    global A, B, bias
    resid = resid.to(torch.float32)

    mid = resid @ A
    mid = mid + bias
    mid = torch.relu(mid)


    lora_out = mid @ B
    resid = resid + lora_out

    resid = resid.to(torch.float16)
    return resid

In [40]:
with open("data_anger_wedding.json") as f:
    data = json.load(f)

In [41]:
batch_size = 4
print(len(data))
print(len(data)/batch_size)

model.to_tokens(data[:4]).shape

812
203.0


torch.Size([4, 28])

In [49]:
r = 3
A = nn.Parameter(0.1*torch.randn((model.cfg.d_model, r), device=device), requires_grad=True)
B = nn.Parameter(0.1*torch.randn((r, model.cfg.d_model), device=device), requires_grad=True)
bias = nn.Parameter(torch.zeros((r), device=device), requires_grad=True)
optimizer = optim.Adam([A, B, bias], lr=0.002)
criterion = nn.CrossEntropyLoss()

for epoch in range(8):
    total_loss = 0
    random.shuffle(data)
    for i in range(len(data)//batch_size):
        optimizer.zero_grad()

        tokens = model.to_tokens(data[i*batch_size:(i+1)*batch_size])

        with model.hooks(fwd_hooks=[(hp6, lora_hook)]):
            loss = model(tokens, return_type="loss", loss_per_token=True)
        loss = loss[:3].mean() # don't include bos and prompt
        
        loss.backward()
        optimizer.step()
        # print(f"Epoch {epoch}, iter {i}, loss {loss.item()}")
        total_loss += loss.item()
    print(f"Epoch {epoch}, avg loss {(total_loss * batch_size) / len(data)}")

Epoch 0, avg loss 3.123248922413793
Epoch 1, avg loss 2.7769492764778323
Epoch 2, avg loss 2.7107739378078817
Epoch 3, avg loss 2.68922702432266
Epoch 4, avg loss 2.651006388546798
Epoch 5, avg loss 2.6264913023399017
Epoch 6, avg loss 2.612732835591133
Epoch 7, avg loss 2.5963285098522166


In [50]:
# un-lora'd loss is 3.66

In [62]:
# save the A and B matrices
# torch.save(A, "A.pt")
# torch.save(B, "B.pt")

# A = torch.load("A.pt")
# B = torch.load("B.pt")

In [52]:
sae6 = SparseAutoencoder.from_pretrained("gemma-2b-res-jb", hp6)
normalise_decoder(sae6, scale_input=False)
sae6 = sae6.to(device)

In [53]:
anger = sae6.W_dec[1062]
wedding = sae6.W_dec[8406]

anger_enc = sae6.W_enc[:, 1062]
wedding_enc = sae6.W_enc[:, 8406]

combo = anger + wedding
combo = combo / combo.norm()

In [54]:
Bn = B / B.norm(dim=-1, keepdim=True)
An = A / A.norm(dim=0, keepdim=True)

In [55]:
print(bias)

Parameter containing:
tensor([-0.0284, -0.0371, -0.0484], device='cuda:0', requires_grad=True)


In [56]:
print(Bn @ anger)
print(Bn @ wedding)
print(Bn @ combo)

tensor([0.3209, 0.2191, 0.2593], device='cuda:0', grad_fn=<MvBackward0>)
tensor([0.2862, 0.1242, 0.2174], device='cuda:0', grad_fn=<MvBackward0>)
tensor([0.3523, 0.1992, 0.2766], device='cuda:0', grad_fn=<MvBackward0>)


In [57]:
print(An.T @ anger_enc)
print(An.T @ wedding_enc)

tensor([-0.0188, -0.0593,  0.0821], device='cuda:0', grad_fn=<MvBackward0>)
tensor([-0.1267, -0.1522, -0.0045], device='cuda:0', grad_fn=<MvBackward0>)


In [58]:
gens = []
with torch.no_grad():
    for i in range(50):
        with model.hooks(fwd_hooks=[(hp6, lora_hook)]):
            gens.append(model.generate("I think", max_new_tokens=25, top_k=50, top_p=0.3, use_past_kv_cache=False, verbose=False))
print(gens)

['I think I’m angry because he is the one who is holding the gun. (I’s my wedding ceremony.)\n\nAnd', "I think I've been angry at my husband before during the past 15 years.\n\nIn fact, some of those weddings", 'I think the guy got the wrong wife. First wife and I divorced after 18 years. I wasn. At her and took', 'I think there was a reason for it when I was at the chapel, and it still was the reason when I went up.\n\n', 'I think her is pretty angry with you. I would be upset. If we were angry over a loss my wife would be over it', 'I think my ex and wife turned into a fight over a fight about the wedding that never happened. It happened two years after and my', 'I think it’s the wedding at the same time and my own wedding. I think I would be furious over the fact that my', "I think the wife who started the argument is the only wife that cares if the wife started the argument. I don't think she", 'I think there will be no fight scene in the movie when the wedding scene is going on\ni

In [59]:

eval_1 = evaluate_completions(gens, criterion="Text is angry/upset/enraged/distressed or is about anger/hate etc.", prompt="I think", verbose=False)
eval_2 = evaluate_completions(gens, criterion="Text mentions wedding/marriage/engagement.", prompt="I think", verbose=False)
coherence = evaluate_completions(gens, criterion="Text is coherent, the grammar is correct.", prompt="I think", verbose=False)
scores_1 = [e['score'] for e in eval_1]
scores_2 = [e['score'] for e in eval_2]
coherence_scores = [e['score'] for e in coherence]

print(scores_1)
print(scores_2)
print(coherence_scores)


[8, 6, 7, 1, 6, 6, 7, 3, 6, 7, 1, 6, 7, 7, 2, 6, 1, 7, 6, 1, 2, 1, 1, 6, 8, 7, 7, 6, 6, 6, 5, 4, 4, 7, 8, 7, 8, 1, 7, 1, 7, 6, 2, 3, 2, 7, 6, 2, 5, 3]
[7, 7, 6, 2, 3, 8, 10, 5, 8, 1, 10, 10, 8, 6, 5, 1, 10, 10, 4, 5, 5, 7, 8, 8, 7, 10, 1, 1, 5, 1, 8, 8, 8, 6, 7, 7, 10, 10, 8, 10, 10, 8, 9, 8, 9, 10, 10, 7, 7, 8]
[3, 4, 3, 7, 4, 4, 3, 4, 5, 7, 4, 3, 4, 2, 3, 5, 3, 3, 3, 3, 3, 3, 5, 6, 4, 4, 4, 5, 3, 4, 3, 5, 4, 6, 7, 4, 5, 4, 3, 6, 7, 3, 4, 4, 4, 4, 4, 4, 2, 3]


In [60]:
rs1 = (np.mean(scores_1) -1)/9
rs2 = (np.mean(scores_2) -1)/9
rc = (np.mean(coherence_scores) -1)/9

print(rs1, rs2, rc)

0.43555555555555553 0.66 0.3422222222222222


In [61]:
print('mult', rs1 * rs2 * rc)

mult 0.09837748148148147
