In [2]:
## Updated by gavento based on code by Sudarsh

from transformer_lens import HookedTransformer
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Load the pre-trained model
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


model = HookedTransformer.from_pretrained("Qwen/Qwen2-0.5B", device=device)
# gemma_model = HookedTransformer.from_pretrained("gemma-2b", device="mps")
torch.set_printoptions(threshold=1_000_000)
torch.set_printoptions(linewidth=1_000_000)

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model Qwen/Qwen2-0.5B into HookedTransformer


In [22]:
# Define the input text
input_text = '360, when you\'re in the mirror do you like what you see'

# Tokenize the input text
tokens = model.to_tokens(input_text)[:, 1:] # Skip BOS
print(tokens.shape, tokens)
print(model.tokenizer.convert_ids_to_tokens(tokens[0]))

# Get regular embeddings
reg_embeddings = model.embed(tokens)
print(reg_embeddings.shape)



SOFT_TOKENS = 1
MODEL_DIM = 896
L2REG = 0.5

def exec_model(model, first_tokens_embedding, tokens):
    residual, _tks, _spe, _attn = model.input_to_embed(tokens)

    skip = first_tokens_embedding.shape[1]
    both = torch.concat([first_tokens_embedding, residual[:, skip-1:-1, :]], axis=1)
    #print(first_tokens_embedding.shape, residual.shape, both.shape)
    return model(both, start_at_layer=0)

def predict(model, first_tokens_embedding, num_tokens):
    tokens = []
    skip = first_tokens_embedding.shape[1]
    for i in range(num_tokens):
        toks = torch.tensor(tokens, dtype=torch.long)
        residual, _tks, _spe, _attn = model.input_to_embed(toks)
        both = torch.concat([first_tokens_embedding.detach(), residual], axis=1)
        res = model(both, start_at_layer=0)[0]
        next_token = torch.argmax(res, axis=-1)[-1].item()  # Get the last token predicted
        tokens.append(next_token)
        probs = torch.softmax(res, axis=-1)
        #print(model.tokenizer.convert_ids_to_tokens(tokens))
        #print(residual.shape, toks.shape, both.shape, res.shape, probs.shape)
    
    return (model.tokenizer.decode(tokens),
            model.tokenizer.convert_ids_to_tokens(tokens),
            torch.tensor(tokens).cpu())  # Convert the final token list to a tensor

def token_alignment_loss(logits, tokens, first_tokens_embedding, alpha=1.0, beta=0.7, gamma=0.1):
    # Calculate the cross-entropy loss
    def l2(x): return torch.sum(x ** 2) ** 0.5

    ce_loss = F.cross_entropy(logits, tokens)
    
    # Calculate the token matching loss
    pred_tokens = torch.argmax(logits, dim=-1)
    match_loss = (pred_tokens != tokens).float().mean()
    
    # L2 regularization
    l2_loss = l2(first_tokens_embedding)
    
    # Combine the losses
    total_loss = alpha * match_loss + beta * ce_loss + gamma * l2_loss
    
    return total_loss
# 
# def nearest_neighbour_loss(x): 


torch.Size([1, 16]) tensor([[   18,    21,    15,    11,   979,   498,  2299,   304,   279, 17846,   653,   498,  1075,  1128,   498,  1490]], device='mps:0')
['3', '6', '0', ',', 'Ġwhen', 'Ġyou', "'re", 'Ġin', 'Ġthe', 'Ġmirror', 'Ġdo', 'Ġyou', 'Ġlike', 'Ġwhat', 'Ġyou', 'Ġsee']
torch.Size([1, 16, 896])


In [21]:
first_tokens_embedding = torch.tensor(np.random.normal(0.0, 768**(-0.5), size=(1, SOFT_TOKENS, MODEL_DIM)))
if device == "mps":
    first_tokens_embedding = first_tokens_embedding.to(torch.float32).to(device)

first_tokens_embedding = first_tokens_embedding.requires_grad_(True)

num_steps = 5100  # Number of optimization steps
# Define the optimizer for the first tokens' embedding
optimizer = torch.optim.Adam([first_tokens_embedding], lr=0.02, amsgrad=True)
# Learning Rate Scheduler
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.005 ** (1/num_steps), verbose=True)

loss_arr = []

def l2(x): return torch.sum(x ** 2) ** 0.5

last_corr = 0
lossahead = 7
lookahead = 20
for step in range(num_steps):
    # Zero the gradients
    optimizer.zero_grad()

    # Replace the first token's embedding in the reg_embeddings tensor
    logits = exec_model(model, first_tokens_embedding, tokens)
    #print(logits.shape)
    # Calculate the loss
    # loss = F.cross_entropy(logits.flatten(0, 1)[:last_corr+lossahead], tokens.flatten(0, 1)[:last_corr+lossahead]) + L2REG * l2(first_tokens_embedding)

    flattened_logits = logits.flatten(0, 1)[:last_corr+lossahead]
    flattened_tokens = tokens.flatten(0, 1)[:last_corr+lossahead]

    loss = token_alignment_loss(flattened_logits, flattened_tokens, first_tokens_embedding)
    
    # Backpropagate the loss with retain_graph=True
    loss.backward()

    # Optimize the first token's embedding
    optimizer.step()
    lr_scheduler.step()

    loss_arr.append(loss.item())

    # Print the loss for every 100 steps
    print(step)
    if step % 10 == 0:
        ps = predict(model, first_tokens_embedding, last_corr+lookahead)
        temp_corr = (ps[2][:min(last_corr+lookahead, tokens.shape[1])] == tokens[:, :last_corr+lookahead].cpu()).sum()
        if temp_corr > last_corr:
            last_corr += 1
        print(f"\nStep {step}, Correct={temp_corr}, Loss={loss.item()}/{last_corr}, L2={l2(first_tokens_embedding).detach().cpu()}, LR={lr_scheduler.get_last_lr()}, Pred={ps[0]!r}")
    
    # just comment out this part if you want to not add tokens
    # if step % 20 == 0 and step != 0:
    #     new_tokens_embedding = torch.tensor(np.random.normal(0.0, 768**(-0.5), size=(1, 1, MODEL_DIM)), dtype=torch.float32, requires_grad=True).to(device)
    #     print("ADDING NEW TOK")
    #     first_tokens_embedding = torch.cat([ new_tokens_embedding, first_tokens_embedding], dim=1)
    #     first_tokens_embedding = first_tokens_embedding.detach().requires_grad_(True)

    #     # Reinitialize the optimizer with the updated embedding
    #     optimizer = torch.optim.Adam([first_tokens_embedding], lr=0.02, amsgrad=True)
    #     lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.005 ** (1/num_steps), verbose=True)


plt.plot(loss_arr)

0

Step 0, Correct=0, Loss=6.224291801452637/0, L2=1.2750376462936401, LR=[0.019979233073936713], Pred='\n\n\n\n\n\n\n\n\n\n...\n...\n...\n\n...\n\n...\n...\n...\n...\n'
1
2
3
4
5
6
7
8
9
10

Step 10, Correct=0, Loss=9.398530006408691/0, L2=2.5316247940063477, LR=[0.019772746105942783], Pred='.nelelelelelelelelelelelelelelelelelelele'
11
12
13
14
15
16
17
18
19
20

Step 20, Correct=0, Loss=7.612310409545898/0, L2=3.1040899753570557, LR=[0.01956839319723899], Pred='\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n'
ADDING NEW TOK
21
22
23
24
25
26
27
28
29
30

Step 30, Correct=0, Loss=7.127480983734131/0, L2=4.17234992980957, LR=[0.01979329840416818], Pred='entiation of the set of all $n\\times n$ matrices with real entries is a Ban'
31
32
33
34
35
36
37
38
39
40

Step 40, Correct=0, Loss=5.959875583648682/0, L2=4.637292385101318, LR=[0.019588733085822325], Pred=' entropy of a random variable\n\nLet $X$ be a random variable with a finite support. Let'
ADDING NEW TOK
41
42
43
44
45
46
47
48
49
50

KeyboardInterrupt: 

In [None]:
# The prediction runs only given the following tensor shaped (1,1,768):
print(first_tokens_embedding)
MAGIC = torch.tensor([[[ 0.0722,  0.0018, -0.0332,  0.0840,  0.0088,  0.0193, -0.0776,  0.0729,  0.0186, -0.0036, -0.0359,  0.1284, -0.0428,  0.0698,  0.0470, -0.0241, -0.0458, -0.0276,  0.0031, -0.0530, -0.1170, -0.0169, -0.0305,  0.0401, -0.0561, -0.0567,  0.0276,  0.0757, -0.0233,  0.1150, -0.0454,  0.0203,  0.0384,  0.0577, -0.0058, -0.0583,  0.0537, -0.0087, -0.0267, -0.0287,  0.0393, -0.0766, -0.0212, -0.0063, -0.0931,  0.0343, -0.1035,  0.0565,  0.0423,  0.0126,  0.0722, -0.0143, -0.0060,  0.0589,  0.0825, -0.0253, -0.0061,  0.0526, -0.0367,  0.0349, -0.0842, -0.0090,  0.0014,  0.0716,  0.0752, -0.0421,  0.0591, -0.0066,  0.0444,  0.0431,  0.0639, -0.0387,  0.0112, -0.0076,  0.1020,  0.0308, -0.1242,  0.0220, -0.1179, -0.0400, -0.0526,  0.0353, -0.0109, -0.0455, -0.0189,  0.0022,  0.0773,  0.0767, -0.0153,  0.0591, -0.0570,  0.0563, -0.0706, -0.0480,  0.0171, -0.0015, -0.0300, -0.0241,  0.0291, -0.0339, -0.0071,  0.0210, -0.0025, -0.0055,  0.0188,  0.0200, -0.0551,  0.0525,  0.0311, -0.0937, -0.0839, -0.1068,  0.0302, -0.0122,  0.0020, -0.0559, -0.0078, -0.0015, -0.0401,  0.0463, -0.0237,  0.0726,  0.1067, -0.0347, -0.0233,  0.0067,  0.0091, -0.0825,  0.0250,  0.0064,  0.0545, -0.0087, -0.0274,  0.1103,  0.0077,  0.0357, -0.0424,  0.0451, -0.1372, -0.0541,  0.0402,  0.0057,  0.0520,  0.0205,  0.0240, -0.0587,  0.0458,  0.0314,  0.0432,  0.0024, -0.0410,  0.0260, -0.0278, -0.0438,  0.0320, -0.0563, -0.0502,  0.0074, -0.0817,  0.0423,  0.0250, -0.0064,  0.0055,  0.0313,  0.0897,  0.0786, -0.0028,  0.0643,  0.0133,  0.0792, -0.0145, -0.0455,  0.0045,  0.0213,  0.0275,  0.0145, -0.0258, -0.0755, -0.0116,  0.0197,  0.0470, -0.0054, -0.0071, -0.0756, -0.0021, -0.0248,  0.0113, -0.0651,  0.1203, -0.1215, -0.0508,  0.0435, -0.0896, -0.0312,  0.0862,  0.0694,  0.0316,  0.0801, -0.0263,  0.0139,  0.0740, -0.0926,  0.0364,  0.0207, -0.0460,  0.0351,  0.0083,  0.0178,  0.0173, -0.0611,  0.0472,  0.0112, -0.0603, -0.0213, -0.1077,  0.0311, -0.0408, -0.0207, -0.0425, -0.0180,  0.0130, -0.0963,  0.0482,  0.0084, -0.0331,  0.0235, -0.0006, -0.0231, -0.0599, -0.0082, -0.0876,  0.0229, -0.1219, -0.0388,  0.0745, -0.0089,  0.0047, -0.0055, -0.0821, -0.0338, -0.0043, -0.0576,  0.0157,  0.0211,  0.0435, -0.0266, -0.0077,  0.0518, -0.0055, -0.0808,  0.0975, -0.0411,  0.0157,  0.0038,  0.0071,  0.0017, -0.0042, -0.0058,  0.0274, -0.1161, -0.0185,  0.0205,  0.1240, -0.0251,  0.0989,  0.0354,  0.0817, -0.0644, -0.0352, -0.1421, -0.0112, -0.0024, -0.0371, -0.0382,  0.0332, -0.0662, -0.0356,  0.0206, -0.1179, -0.0053, -0.0148, -0.0397,  0.0005,  0.0112,  0.0105,  0.0260, -0.0227,  0.0138,  0.0317, -0.0067,  0.0022,  0.0060, -0.0015, -0.0391, -0.0102, -0.0430,  0.0620, -0.0184,  0.0033,  0.0062, -0.0511,  0.0313,  0.0205, -0.0378, -0.0725, -0.0652,  0.0188,  0.0165,  0.0161, -0.0728,  0.0189,  0.0626, -0.0146, -0.0428,  0.1350,  0.0410, -0.0336,  0.1037, -0.0491,  0.0251, -0.0151, -0.0277, -0.0178, -0.0020,  0.0371, -0.0144,  0.0894,  0.0304, -0.0409,  0.1606,  0.0453, -0.0809,  0.0031, -0.0099,  0.0288, -0.0419, -0.0154, -0.0242, -0.0344,  0.0128,  0.0481, -0.0133,  0.0081, -0.0493, -0.0442,  0.0301,  0.0068, -0.0065,  0.0467, -0.0127,  0.0164, -0.0336, -0.0299, -0.1080,  0.0640,  0.0241,  0.0609,  0.0010, -0.1020,  0.1194,  0.1178,  0.0152, -0.0918, -0.0748,  0.0845, -0.0168,  0.0327, -0.0258, -0.0150, -0.0381, -0.0359,  0.0518, -0.0464,  0.0771,  0.0054,  0.0455, -0.0223,  0.0444,  0.0843, -0.0550, -0.0277, -0.0142,  0.0248, -0.0285, -0.0420, -0.0065,  0.0117, -0.0557, -0.0032, -0.1257, -0.0286,  0.0048, -0.0523,  0.0255,  0.0799,  0.0364,  0.0957, -0.0922, -0.0183,  0.0339,  0.0504,  0.0726,  0.0269, -0.0327,  0.0941,  0.0243, -0.0040,  0.0495, -0.0269,  0.0651,  0.0257,  0.0104, -0.0555,  0.0147,  0.0665, -0.0048, -0.0580,  0.0178,  0.0786,  0.0258, -0.0353,  0.0252,  0.0546,  0.0195, -0.0110, -0.0478,  0.0180,  0.0638,  0.0701,  0.0045,  0.0950,  0.0025, -0.0315, -0.0307,  0.0334, -0.0189, -0.0133, -0.0390, -0.0344,  0.0437, -0.0925,  0.0021,  0.0377,  0.0144, -0.0829, -0.0185,  0.0589, -0.0989, -0.0132, -0.0768,  0.0231,  0.0173, -0.0084, -0.0252, -0.0269, -0.0345, -0.0033,  0.0335,  0.0022,  0.0436,  0.0678, -0.0262,  0.0217,  0.0203, -0.0251, -0.0531, -0.0335, -0.0079,  0.0427,  0.0197, -0.0349,  0.0027,  0.0525, -0.0131, -0.0187, -0.1100,  0.0276, -0.0469, -0.0280,  0.0255,  0.0962, -0.0727,  0.0039, -0.0969,  0.0008,  0.0315,  0.0379,  0.0419, -0.0900, -0.0674,  0.0666,  0.0287,  0.0408, -0.0418,  0.0102,  0.0461, -0.0309,  0.0791, -0.0369,  0.1147, -0.0701, -0.0152,  0.0740,  0.0453, -0.0131, -0.0309,  0.0264,  0.0032,  0.0235, -0.1436,  0.0531, -0.0129, -0.0275, -0.0223,  0.0008, -0.0111,  0.0566,  0.0777, -0.0393,  0.0559,  0.0766,  0.1292, -0.0004,  0.1153,  0.0237, -0.0440, -0.0127,  0.0600,  0.0339,  0.0036, -0.0151, -0.0314,  0.0866, -0.0226, -0.0160, -0.0388, -0.0165,  0.0115,  0.0944, -0.0367, -0.0420, -0.0094, -0.0357, -0.0448, -0.0329,  0.0900,  0.0505, -0.0280,  0.0869,  0.0205, -0.0039, -0.0045,  0.0231,  0.0936, -0.0189, -0.0244,  0.0386, -0.0084, -0.0992, -0.0358, -0.0255, -0.0109, -0.0223,  0.0658,  0.0210, -0.0090,  0.0513,  0.0648, -0.0428,  0.0203, -0.0385,  0.0421, -0.0428, -0.0597, -0.0281,  0.1150, -0.0488, -0.0213,  0.0962,  0.0037, -0.0588, -0.0330, -0.0141,  0.0646, -0.0179, -0.0203, -0.0204,  0.0413, -0.0339,  0.0377, -0.0797,  0.0210,  0.0395, -0.0612, -0.0291, -0.0481,  0.0418,  0.0170, -0.0641, -0.0139,  0.0308,  0.0076,  0.0104,  0.0205, -0.0043,  0.0222, -0.0228,  0.0337, -0.0021,  0.0628, -0.0111,  0.0041,  0.0705, -0.0225, -0.0248,  0.0287,  0.0709,  0.0109, -0.0558, -0.0489, -0.0321,  0.0767,  0.0377, -0.0531,  0.0612, -0.0467,  0.0343,  0.0195,  0.0741,  0.0963, -0.0236,  0.0707, -0.0491, -0.0407, -0.0370,  0.0806,  0.0174, -0.0130, -0.0194,  0.0155,  0.0348,  0.0540, -0.0357,  0.0176,  0.0408,  0.0997,  0.0262, -0.0179,  0.0663,  0.0956,  0.0476, -0.0568, -0.0305,  0.0063,  0.0643,  0.0656, -0.0427, -0.1023,  0.1351, -0.0445, -0.0047, -0.0674, -0.0750,  0.0591, -0.0946,  0.0220,  0.0023,  0.0086, -0.0895,  0.0430, -0.0613,  0.0638, -0.0318,  0.0536,  0.0216,  0.0350, -0.0043, -0.0470,  0.0360,  0.0944,  0.0140,  0.0915,  0.0719,  0.0168, -0.0117, -0.0369, -0.0402, -0.0139, -0.0954, -0.0463, -0.0829,  0.0891, -0.0709,  0.0022, -0.0028,  0.0302, -0.1730,  0.0585,  0.0022,  0.0357,  0.0248,  0.0274,  0.0361, -0.0570,  0.0508,  0.0770, -0.0273,  0.0608,  0.0062, -0.0077, -0.0392,  0.0320, -0.0385, -0.0652,  0.0106,  0.0014, -0.0817, -0.1166, -0.0085,  0.0298, -0.0669,  0.0334, -0.0173, -0.0321,  0.0299, -0.0333, -0.0049,  0.0614, -0.0467,  0.0267,  0.0398,  0.0070,  0.0534, -0.0180, -0.0445,  0.0437, -0.0795, -0.0665, -0.0085, -0.0171, -0.0207,  0.0128,  0.0639, -0.0206,  0.0176,  0.0282,  0.0749,  0.0757, -0.0355, -0.0445, -0.0050, -0.0335, -0.0774, -0.0544,  0.0394, -0.0315,  0.0321,  0.0474, -0.0200,  0.0276,  0.0329, -0.0617, -0.0847, -0.0258, -0.0408,  0.0513,  0.0118,  0.0024]]], device='mps', dtype=torch.float32, requires_grad=True)
# Compare the following:
print(repr(predict(model, MAGIC, 90)[0]))
print(repr(input_text))

tensor([[[-6.0150e-04, -1.1581e-02, -3.2306e-02, -1.3725e-02, -2.2609e-02, -8.8896e-03, -1.9505e-02,  1.6679e-02,  2.7680e-04,  1.2192e-02, -2.7259e-02,  2.4730e-02,  2.8125e-03, -1.9682e-02,  7.6228e-03, -1.2064e-02, -1.2700e-02, -1.0188e-02, -2.9960e-02,  5.0544e-03, -1.8981e-02,  7.1290e-03, -3.6753e-02,  8.2953e-03,  1.0060e-02, -1.9799e-03, -1.0252e-02,  1.0263e-02,  3.0543e-02,  1.5026e-02, -2.7516e-03, -1.1322e-02, -2.3023e-02,  3.3208e-02,  9.7001e-03, -4.6102e-03, -1.4564e-03,  9.0939e-03,  2.7024e-02,  4.2478e-02, -3.4492e-03, -1.5862e-02, -2.0506e-02,  2.7238e-02,  9.4534e-03,  4.7789e-03, -9.4079e-03,  3.8492e-03, -8.3462e-03,  1.3931e-02,  8.5289e-03,  1.6567e-02,  5.5420e-03,  1.0487e-02,  1.8013e-02, -1.1775e-02, -2.1005e-03,  1.3054e-02,  4.5922e-03, -4.3402e-02, -1.8127e-02, -1.3480e-02,  2.7587e-03, -1.9453e-04,  1.3395e-02, -1.3129e-02,  1.1820e-02,  5.8265e-03,  2.0391e-03,  2.7841e-03,  2.7364e-02,  2.8516e-02, -2.6090e-02, -2.7527e-02, -1.6210e-03, -1.1307e-02, -2

  toks = torch.tensor(tokens, dtype=torch.long)


'Relative entropy is always a non-negative real number, with value 0 if and only if the two distributions in question are identical. It has diverse applications, both theoretical, such as characterising order, such as and independently, such, such, both, both, both, both, both, both, both, both, both, both, both, and both, and both, and both, and both, and both, and both,'
'Relative entropy is always a non-negative real number, with value 0 if and only if the two distributions in question are identical. It has diverse applications, both theoretical, such as characterizing the relative (Shannon) entropy in information systems, randomness in continuous time-series, and information gain when comparing statistical models of inference; and practical, such as applied statistics, fluid mechanics, neuroscience, bioinformatics, and machine learning.'
