In [1]:
## Updated by gavento based on code by Sudarsh

from transformer_lens import HookedTransformer
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.neighbors import BallTree

# Choose the compute device: Apple MPS is preferred, then CUDA, otherwise CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Load the pre-trained model onto the selected device.
model = HookedTransformer.from_pretrained("Qwen/Qwen2-0.5B", device=device)
# Alternative model kept for reference:
# gemma_model = HookedTransformer.from_pretrained("gemma-2b", device="mps")

# Print tensors in full and on very wide lines, so printed embeddings can be
# copied back into code without elision or wrapping.
torch.set_printoptions(threshold=1_000_000, linewidth=1_000_000)



Loaded pretrained model Qwen/Qwen2-0.5B into HookedTransformer


In [2]:
# Define the input text
input_text = """Consensus is lacking among historians with regard to what the actual events surrounding this event. The portrayal of Ahuntisc and Viel as martyrs in popular culture is objected to by those researchers who reject the notion that they were murdered. The ethnic identity and the actual phonetic name of the missionary referred to as Ahuntsic are also not universally agreed upon by authors. The deaths of Ahuntsic and the Recollet Father Nicholas Viel in a single event have been commemorated as martyrs with a statue at the Church of La Visitation-de-la-Bienheureuse-Vierge-Marie, a painting by Georges Delfosse housed at the Cathedral-Basilica of Mary in Montreal, and a cross memorial erected at Parc de lÎle-de-la-Visitation. According to the Martyrologe des Recollets, he was attacked and drowned for his faith and work as a missionary, along with father Viel who is publicly regarded as the first Canadian martyr, by three Huron Indians who were enemies of Christianity. Except for the intent of accompanying Viel on his trip to Quebec, records of the early life of Ahuntsic or his missionary work prior to the drowning event are not found. Notably and with regard to references from that incident, there is some debate surrounding the pronunciation and associated phonetic spelling of his name as well as his ethnic identity from the limited mention of his existence in historical accounts. There are at least three different pronunciations, and consequently three different phonetic spellings, used by historical authors for the name for the chronicled neophyte missionary Ahuntsic, although the etymology and meaning for all three are the same as they are basically variations of understanding for the same Huron expression. Several historians alternately refer in writing to the missionary Ahuntsic as “Auhaitsique,” which is translated as 'little fish', as the actual name or at least an alternative synonymous with Ahunstic. 
Father Le Clercq who was an associate of Viel, however, used Auhaustic as the written form of the name and Father Arthur Edward Jones, who in his 1909 book Old Huronia, maintained that the popular Ahuntsic pronunciation did not exist in the Huron language and that it is rather a degeneration of the correct sounding term intended to mean 'little fish' from the Huron dialect. The ethnic identity of Ahuntsic as a Huron converted to Christianity by Father Viel has also been contested by some historians who counter that Ahuntsic was actually a young French assistant who was given a Huron name. This opinion was based on an account written by Brother Sagard who was sent to Canada with Viel. Those who argue that Ahuntsic has a Native American heritage also use Sagard's history but with a different interpretation. Many of those arguing for a French Ahuntsic also reject the claim that Ahuntsic and Viel were assassinated. At least one novel entitled The Conquest of Canada: A Novel of Discovery written by Wendel Messer depicts 17th century dialogue with a native named Ahuntsic, though the book is admittedly a blend of history and fiction according to the foreword."""

# Tokenize the input text and drop the first position (the BOS token, per the
# original author's note), leaving a (1, n_tokens) tensor of target token ids.
tokens = model.to_tokens(input_text)[:, 1:] # Skip BOS
print(tokens.shape, tokens)
print(model.tokenizer.convert_ids_to_tokens(tokens[0]))

# Only needed for the (currently disabled) nearest-neighbour loss; W_E is not
# defined anywhere in the visible notebook, so this stays commented out.
# embeddings_tree = BallTree(W_E.numpy(), leaf_size=1)

# Number of trainable soft-prompt embedding vectors placed before the text.
SOFT_TOKENS = 5
# Width of one embedding vector. Read from the loaded model rather than
# hard-coded so that changing the model does not leave a stale value behind
# (this is 896 for Qwen2-0.5B, the value that used to be written here).
MODEL_DIM = model.cfg.d_model
# L2 regularisation weight; only referenced by a commented-out loss variant.
L2REG = 0.5

def exec_model(model, first_tokens_embedding, tokens):
    """Run `model` on the soft-prompt embeddings followed by the embedded `tokens`.

    `tokens` is embedded with `model.input_to_embed`, the result is appended
    after `first_tokens_embedding` along dimension 1, and the combined
    embeddings are fed to the model with `start_at_layer=0`.

    Returns whatever `model(...)` returns for that input.
    """
    # input_to_embed yields a 4-tuple; only the first item (the embeddings) is used.
    token_embeds, _unused_tks, _unused_spe, _unused_attn = model.input_to_embed(tokens)
    prefixed = torch.cat((first_tokens_embedding, token_embeds), dim=1)
    return model(prefixed, start_at_layer=0)

def predict(model, first_tokens_embedding, num_tokens):
    """Greedily decode `num_tokens` tokens conditioned only on the soft prompt.

    At every step the tokens generated so far are embedded with
    `model.input_to_embed`, appended after the (detached) soft-prompt
    embeddings along dimension 1, and the whole sequence is run through the
    model again with `start_at_layer=0`. The arg-max of the logits at the last
    position becomes the next token.

    Args:
        model: object providing `input_to_embed`, `__call__(..., start_at_layer=0)`
            and a `tokenizer` with `decode` / `convert_ids_to_tokens`.
        first_tokens_embedding: soft-prompt embeddings; dimension 1 indexes
            the prompt positions.
        num_tokens: number of tokens to generate (0 yields empty results).

    Returns:
        A tuple `(decoded_text, token_strings, token_ids)` where `token_ids`
        is an int64 CPU tensor of length `num_tokens`.
    """
    prefix = first_tokens_embedding.detach()
    tokens = []
    # Decoding never needs gradients; disabling them here avoids building an
    # autograd graph through the model when the caller did not do so itself.
    with torch.no_grad():
        for _ in range(num_tokens):
            toks = torch.tensor(tokens, dtype=torch.long)
            residual, _tks, _spe, _attn = model.input_to_embed(toks)
            both = torch.cat([prefix, residual], dim=1)
            logits = model(both, start_at_layer=0)[0]
            # Only the last position's prediction is needed for the next token.
            tokens.append(int(logits[-1].argmax(dim=-1).item()))

    # An explicit dtype keeps the id tensor int64 even when no token was generated.
    return (model.tokenizer.decode(tokens),
            model.tokenizer.convert_ids_to_tokens(tokens),
            torch.tensor(tokens, dtype=torch.long).cpu())

def nearest_neighbour_loss(x, ball_tree, W_E):
    """L2 distance between the rows of `x` and their nearest rows of `W_E`.

    Args:
        x: tensor of shape (n, d); each row is looked up independently.
        ball_tree: object whose `query(array, k=1)` returns `(dist, ind)` with
            `ind` of shape (n, 1) indexing rows of `W_E` (as scikit-learn's
            `BallTree.query` does).
        W_E: tensor of shape (vocab, d) holding the candidate rows.

    Returns:
        A scalar tensor: the 2-norm of `x - nearest`, taken over all rows
        together. Gradients flow through `x` only via the difference; the
        neighbour lookup itself is done on a detached copy.
    """
    # The tree works on NumPy data, so the lookup happens outside autograd.
    _dist, ind = ball_tree.query(x.detach().cpu().numpy(), k=1)

    # `ind` is (n, 1). Flatten it to (n,) so the gathered neighbours are (n, d)
    # and line up row-for-row with `x`, instead of broadcasting to (n, n, d).
    idx = torch.as_tensor(ind, dtype=torch.long).reshape(-1).to(W_E.device)
    nearest_neighbour = W_E[idx].to(x.device)

    return torch.norm(x - nearest_neighbour, p=2)

def token_alignment_loss(logits, tokens, first_tokens_embedding, alpha=1.0, beta=0.7, kappa=0.1, gamma=0.2):
    """Weighted loss tying the model's next-token predictions to `tokens`.

    Args:
        logits: 2-D tensor (positions, vocab) covering the soft-prompt
            positions followed by the text positions.
        tokens: 1-D tensor of target token ids.
        first_tokens_embedding: the soft-prompt embeddings; dimension 1 is the
            number of soft-prompt positions.
        alpha: weight of the token mismatch rate. That term is built from an
            arg-max comparison, so it changes the reported value but carries
            no gradient.
        beta: weight of the cross-entropy term.
        kappa: weight reserved for the nearest-neighbour term, which is
            currently disabled; the argument is accepted but unused.
        gamma: weight of the L2 norm of `first_tokens_embedding`.

    Returns:
        Scalar tensor `alpha * mismatch + beta * cross_entropy + gamma * l2`.
    """
    # The logit at position p predicts the token at p + 1, so the last
    # soft-prompt position predicts tokens[0] and the final logit row has no
    # target. The prefix length is taken from the embedding that was actually
    # passed in, so the alignment stays right if the prefix is ever resized.
    n_soft = first_tokens_embedding.shape[1]
    logits = logits[n_soft - 1:-1, :]

    # Cross-entropy against the target tokens.
    ce_loss = F.cross_entropy(logits, tokens)

    # Fraction of positions whose arg-max prediction differs from the target.
    pred_tokens = torch.argmax(logits, dim=-1)
    match_loss = (pred_tokens != tokens).float().mean()

    # L2 regularisation: Euclidean norm over all soft-prompt entries.
    l2_loss = torch.sum(first_tokens_embedding ** 2) ** 0.5

    # Disabled nearest-neighbour term (would be weighted by kappa):
    # nn_loss = 0
    # for i in range(first_tokens_embedding.shape[1]):
    #     nn_loss += nearest_neighbour_loss(first_tokens_embedding[:, i, :], embeddings_tree, W_E)

    total_loss = alpha * match_loss + beta * ce_loss + gamma * l2_loss #+ kappa * nn_loss

    return total_loss



torch.Size([1, 652]) tensor([[15220, 13626,   374, 31061,  4221, 50701,   448,  5250,   311,  1128,   279,  5042,  4357, 14590,   419,  1538,    13,   576, 73933,   315, 16366,  3850,  3427,   323, 11401,   301,   438, 59349,  5428,   304,  5411,  7674,   374, 75779,   311,   553,  1846, 11811,   879,  7850,   279, 22240,   429,   807,  1033, 31385,    13,   576, 21551,  9569,   323,   279,  5042, 50823,  5298,   829,   315,   279, 72898, 13862,   311,   438, 16366, 36940,   292,   525,  1083,   537, 60428,  7230,  5193,   553, 12014,    13,   576, 16375,   315, 16366, 36940,   292,   323,   279,  4067,   337,  1149, 20322, 39696, 11401,   301,   304,   264,  3175,  1538,   614,  1012, 80054,   657,   438, 59349,  5428,   448,   264, 34272,   518,   279,  9257,   315,  4929,  7656,  7556,  6810, 52826,  7671,  3591, 74776,   810, 19625, 86003, 47435,   645,    11,   264, 18824,   553, 94259,   422,   490,   436,   325, 51158,   518,   279, 56729,  7671, 29049,  3001,   315, 10244,   30

In [4]:
import math

# Soft prompt: SOFT_TOKENS trainable embedding vectors, randomly initialised
# directly on the target device as a leaf tensor. No further .cuda()/.to()
# call is needed because dtype and device are already set here.
first_tokens_embedding = torch.randn(
    size=(1, SOFT_TOKENS, MODEL_DIM), dtype=torch.float32, device=device, requires_grad=True
)

num_steps = 1500  # Number of optimization steps
# Only the soft-prompt embedding is optimised; the model weights are not passed in.
optimizer = torch.optim.Adam([first_tokens_embedding], lr=0.02 * SOFT_TOKENS, amsgrad=True)
# Exponential decay chosen so that after num_steps the learning rate is 0.5% of
# its initial value. The scheduler's `verbose` flag is deprecated in recent
# PyTorch releases; the current LR is reported through get_last_lr() below.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.005 ** (1 / num_steps))

loss_arr = []  # loss value of every optimisation step, plotted at the end


def l2(x):
    """Euclidean norm over all entries of `x`."""
    return torch.sum(x ** 2) ** 0.5


# Curriculum state: the loss only covers the first `last_corr + lookahead`
# tokens, and `last_corr` grows as greedy decoding reproduces more of the text.
last_corr = 0
max_corr = 0
lookahead = 20
# Loss weights passed to token_alignment_loss.
alpha = 0.0
beta = 1.0
# Not used by the current loss; kept for the disabled loss variants.
lossahead = 5 + SOFT_TOKENS * 2
kappa = 0.1

for step in range(num_steps):
    # Zero the gradients
    optimizer.zero_grad()

    # Forward pass over soft prompt + full text
    logits = exec_model(model, first_tokens_embedding, tokens)

    # Restrict the loss to the current curriculum window. The logits keep the
    # SOFT_TOKENS leading positions; token_alignment_loss aligns the two.
    window = last_corr + lookahead
    flattened_logits = logits.flatten(0, 1)[:window + SOFT_TOKENS]
    flattened_tokens = tokens.flatten(0, 1)[:window]

    loss = token_alignment_loss(flattened_logits, flattened_tokens, first_tokens_embedding, alpha=alpha, beta=beta)

    # Backpropagate and update the soft-prompt embedding
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

    loss_arr.append(loss.item())

    print(step)
    # Every 10 steps: greedy-decode from the soft prompt alone and count how
    # many positions match the target text.
    if step % 10 == 0:
        with torch.no_grad():  # Disable gradient computation for prediction
            # Never decode more tokens than the target text contains.
            n_eval = min(window, tokens.shape[1])
            ps = predict(model, first_tokens_embedding, n_eval)
            temp_corr = int((ps[2] == tokens[0, :n_eval].cpu()).sum().item())
            max_corr = max(temp_corr, max_corr)
            if temp_corr > last_corr:
                # Widen the window gradually: a quarter of the gain, rounded up.
                last_corr += math.ceil((temp_corr - last_corr) / 4)
            print(f"\nStep {step}, Maximum Correct={max_corr}, Correct={temp_corr}, Loss={loss.item()}/{last_corr}, L2={l2(first_tokens_embedding).detach().cpu()}, LR={lr_scheduler.get_last_lr()}, Pred={ps[0]!r}")


    # Disabled experiment: periodically prepend one more soft token.
    # NOTE: enabling this makes the prefix longer than SOFT_TOKENS, so every
    # slice that uses SOFT_TOKENS must use first_tokens_embedding.shape[1] instead.
    # if step % 20 == 0 and step != 0:
    #     new_tokens_embedding = torch.tensor(np.random.normal(0.0, 768**(-0.5), size=(1, 1, MODEL_DIM)), dtype=torch.float32, requires_grad=True).to(device)
    #     print("ADDING NEW TOK")
    #     first_tokens_embedding = torch.cat([ new_tokens_embedding, first_tokens_embedding], dim=1)
    #     first_tokens_embedding = first_tokens_embedding.detach().requires_grad_(True)

    #     # Reinitialize the optimizer with the updated embedding
    #     optimizer = torch.optim.Adam([first_tokens_embedding], lr=0.02, amsgrad=True)
    #     lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.005 ** (1/num_steps))


# Loss curve over all optimisation steps
plt.plot(loss_arr)
plt.xlabel("optimisation step")
plt.ylabel("loss")
plt.title("Soft-prompt training loss");

0

Step 0, Maximum Correct=0, Correct=0, Loss=21.244489669799805/0, L2=65.1295394897461, LR=[0.0996474019343147], Pred=' I have been working on a project for a few weeks now. I have been working on a project'
1
2
3
4
5
6
7
8
9
10

Step 10, Maximum Correct=4, Correct=4, Loss=12.794429779052734/1, L2=50.631568908691406, LR=[0.09618908201212216], Pred='Consensus is lacking. The only thing that seems to be agreed upon is that the woman was a'
11
12
13
14
15
16
17
18
19
20

Step 20, Maximum Correct=17, Correct=17, Loss=8.271393775939941/5, L2=38.99287414550781, LR=[0.0928507850554267], Pred="Consensus is lacking among historians with regards to what the actual events surrounding this event. The 'History'"
21
22
23
24
25
26
27
28
29
30

Step 30, Maximum Correct=17, Correct=0, Loss=9.240863800048828/5, L2=34.93156051635742, LR=[0.08962834559875058], Pred='UES is evident from the depiction of the individuals who made this statement. The description of the individuals is flawed. The description

In [4]:
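# MAGIC below is a previously saved soft-token embedding with stated shape (1, 1, 768); feeding it alone to predict() is meant to reproduce the start of input_text.
# NOTE(review): 768 does not match MODEL_DIM = 896 used for Qwen2-0.5B above, so MAGIC presumably comes from an earlier run with a different model - confirm before reusing it.
# NOTE(review): the literal is pinned to device='mps'; on a machine without MPS support this raises the RuntimeError recorded in this cell's output. Use device=device instead.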
print(first_tokens_embedding)
MAGIC = torch.tensor([[[ 0.0722,  0.0018, -0.0332,  0.0840,  0.0088,  0.0193, -0.0776,  0.0729,  0.0186, -0.0036, -0.0359,  0.1284, -0.0428,  0.0698,  0.0470, -0.0241, -0.0458, -0.0276,  0.0031, -0.0530, -0.1170, -0.0169, -0.0305,  0.0401, -0.0561, -0.0567,  0.0276,  0.0757, -0.0233,  0.1150, -0.0454,  0.0203,  0.0384,  0.0577, -0.0058, -0.0583,  0.0537, -0.0087, -0.0267, -0.0287,  0.0393, -0.0766, -0.0212, -0.0063, -0.0931,  0.0343, -0.1035,  0.0565,  0.0423,  0.0126,  0.0722, -0.0143, -0.0060,  0.0589,  0.0825, -0.0253, -0.0061,  0.0526, -0.0367,  0.0349, -0.0842, -0.0090,  0.0014,  0.0716,  0.0752, -0.0421,  0.0591, -0.0066,  0.0444,  0.0431,  0.0639, -0.0387,  0.0112, -0.0076,  0.1020,  0.0308, -0.1242,  0.0220, -0.1179, -0.0400, -0.0526,  0.0353, -0.0109, -0.0455, -0.0189,  0.0022,  0.0773,  0.0767, -0.0153,  0.0591, -0.0570,  0.0563, -0.0706, -0.0480,  0.0171, -0.0015, -0.0300, -0.0241,  0.0291, -0.0339, -0.0071,  0.0210, -0.0025, -0.0055,  0.0188,  0.0200, -0.0551,  0.0525,  0.0311, -0.0937, -0.0839, -0.1068,  0.0302, -0.0122,  0.0020, -0.0559, -0.0078, -0.0015, -0.0401,  0.0463, -0.0237,  0.0726,  0.1067, -0.0347, -0.0233,  0.0067,  0.0091, -0.0825,  0.0250,  0.0064,  0.0545, -0.0087, -0.0274,  0.1103,  0.0077,  0.0357, -0.0424,  0.0451, -0.1372, -0.0541,  0.0402,  0.0057,  0.0520,  0.0205,  0.0240, -0.0587,  0.0458,  0.0314,  0.0432,  0.0024, -0.0410,  0.0260, -0.0278, -0.0438,  0.0320, -0.0563, -0.0502,  0.0074, -0.0817,  0.0423,  0.0250, -0.0064,  0.0055,  0.0313,  0.0897,  0.0786, -0.0028,  0.0643,  0.0133,  0.0792, -0.0145, -0.0455,  0.0045,  0.0213,  0.0275,  0.0145, -0.0258, -0.0755, -0.0116,  0.0197,  0.0470, -0.0054, -0.0071, -0.0756, -0.0021, -0.0248,  0.0113, -0.0651,  0.1203, -0.1215, -0.0508,  0.0435, -0.0896, -0.0312,  0.0862,  0.0694,  0.0316,  0.0801, -0.0263,  0.0139,  0.0740, -0.0926,  0.0364,  0.0207, -0.0460,  0.0351,  0.0083,  0.0178,  0.0173, -0.0611,  0.0472,  0.0112, -0.0603, -0.0213, -0.1077,  0.0311, -0.0408, -0.0207, -0.0425, 
-0.0180,  0.0130, -0.0963,  0.0482,  0.0084, -0.0331,  0.0235, -0.0006, -0.0231, -0.0599, -0.0082, -0.0876,  0.0229, -0.1219, -0.0388,  0.0745, -0.0089,  0.0047, -0.0055, -0.0821, -0.0338, -0.0043, -0.0576,  0.0157,  0.0211,  0.0435, -0.0266, -0.0077,  0.0518, -0.0055, -0.0808,  0.0975, -0.0411,  0.0157,  0.0038,  0.0071,  0.0017, -0.0042, -0.0058,  0.0274, -0.1161, -0.0185,  0.0205,  0.1240, -0.0251,  0.0989,  0.0354,  0.0817, -0.0644, -0.0352, -0.1421, -0.0112, -0.0024, -0.0371, -0.0382,  0.0332, -0.0662, -0.0356,  0.0206, -0.1179, -0.0053, -0.0148, -0.0397,  0.0005,  0.0112,  0.0105,  0.0260, -0.0227,  0.0138,  0.0317, -0.0067,  0.0022,  0.0060, -0.0015, -0.0391, -0.0102, -0.0430,  0.0620, -0.0184,  0.0033,  0.0062, -0.0511,  0.0313,  0.0205, -0.0378, -0.0725, -0.0652,  0.0188,  0.0165,  0.0161, -0.0728,  0.0189,  0.0626, -0.0146, -0.0428,  0.1350,  0.0410, -0.0336,  0.1037, -0.0491,  0.0251, -0.0151, -0.0277, -0.0178, -0.0020,  0.0371, -0.0144,  0.0894,  0.0304, -0.0409,  0.1606,  0.0453, -0.0809,  0.0031, -0.0099,  0.0288, -0.0419, -0.0154, -0.0242, -0.0344,  0.0128,  0.0481, -0.0133,  0.0081, -0.0493, -0.0442,  0.0301,  0.0068, -0.0065,  0.0467, -0.0127,  0.0164, -0.0336, -0.0299, -0.1080,  0.0640,  0.0241,  0.0609,  0.0010, -0.1020,  0.1194,  0.1178,  0.0152, -0.0918, -0.0748,  0.0845, -0.0168,  0.0327, -0.0258, -0.0150, -0.0381, -0.0359,  0.0518, -0.0464,  0.0771,  0.0054,  0.0455, -0.0223,  0.0444,  0.0843, -0.0550, -0.0277, -0.0142,  0.0248, -0.0285, -0.0420, -0.0065,  0.0117, -0.0557, -0.0032, -0.1257, -0.0286,  0.0048, -0.0523,  0.0255,  0.0799,  0.0364,  0.0957, -0.0922, -0.0183,  0.0339,  0.0504,  0.0726,  0.0269, -0.0327,  0.0941,  0.0243, -0.0040,  0.0495, -0.0269,  0.0651,  0.0257,  0.0104, -0.0555,  0.0147,  0.0665, -0.0048, -0.0580,  0.0178,  0.0786,  0.0258, -0.0353,  0.0252,  0.0546,  0.0195, -0.0110, -0.0478,  0.0180,  0.0638,  0.0701,  0.0045,  0.0950,  0.0025, -0.0315, -0.0307,  0.0334, -0.0189, -0.0133, -0.0390, -0.0344,  0.0437, -0.0925,  
0.0021,  0.0377,  0.0144, -0.0829, -0.0185,  0.0589, -0.0989, -0.0132, -0.0768,  0.0231,  0.0173, -0.0084, -0.0252, -0.0269, -0.0345, -0.0033,  0.0335,  0.0022,  0.0436,  0.0678, -0.0262,  0.0217,  0.0203, -0.0251, -0.0531, -0.0335, -0.0079,  0.0427,  0.0197, -0.0349,  0.0027,  0.0525, -0.0131, -0.0187, -0.1100,  0.0276, -0.0469, -0.0280,  0.0255,  0.0962, -0.0727,  0.0039, -0.0969,  0.0008,  0.0315,  0.0379,  0.0419, -0.0900, -0.0674,  0.0666,  0.0287,  0.0408, -0.0418,  0.0102,  0.0461, -0.0309,  0.0791, -0.0369,  0.1147, -0.0701, -0.0152,  0.0740,  0.0453, -0.0131, -0.0309,  0.0264,  0.0032,  0.0235, -0.1436,  0.0531, -0.0129, -0.0275, -0.0223,  0.0008, -0.0111,  0.0566,  0.0777, -0.0393,  0.0559,  0.0766,  0.1292, -0.0004,  0.1153,  0.0237, -0.0440, -0.0127,  0.0600,  0.0339,  0.0036, -0.0151, -0.0314,  0.0866, -0.0226, -0.0160, -0.0388, -0.0165,  0.0115,  0.0944, -0.0367, -0.0420, -0.0094, -0.0357, -0.0448, -0.0329,  0.0900,  0.0505, -0.0280,  0.0869,  0.0205, -0.0039, -0.0045,  0.0231,  0.0936, -0.0189, -0.0244,  0.0386, -0.0084, -0.0992, -0.0358, -0.0255, -0.0109, -0.0223,  0.0658,  0.0210, -0.0090,  0.0513,  0.0648, -0.0428,  0.0203, -0.0385,  0.0421, -0.0428, -0.0597, -0.0281,  0.1150, -0.0488, -0.0213,  0.0962,  0.0037, -0.0588, -0.0330, -0.0141,  0.0646, -0.0179, -0.0203, -0.0204,  0.0413, -0.0339,  0.0377, -0.0797,  0.0210,  0.0395, -0.0612, -0.0291, -0.0481,  0.0418,  0.0170, -0.0641, -0.0139,  0.0308,  0.0076,  0.0104,  0.0205, -0.0043,  0.0222, -0.0228,  0.0337, -0.0021,  0.0628, -0.0111,  0.0041,  0.0705, -0.0225, -0.0248,  0.0287,  0.0709,  0.0109, -0.0558, -0.0489, -0.0321,  0.0767,  0.0377, -0.0531,  0.0612, -0.0467,  0.0343,  0.0195,  0.0741,  0.0963, -0.0236,  0.0707, -0.0491, -0.0407, -0.0370,  0.0806,  0.0174, -0.0130, -0.0194,  0.0155,  0.0348,  0.0540, -0.0357,  0.0176,  0.0408,  0.0997,  0.0262, -0.0179,  0.0663,  0.0956,  0.0476, -0.0568, -0.0305,  0.0063,  0.0643,  0.0656, -0.0427, -0.1023,  0.1351, -0.0445, -0.0047, -0.0674, -0.0750,  
0.0591, -0.0946,  0.0220,  0.0023,  0.0086, -0.0895,  0.0430, -0.0613,  0.0638, -0.0318,  0.0536,  0.0216,  0.0350, -0.0043, -0.0470,  0.0360,  0.0944,  0.0140,  0.0915,  0.0719,  0.0168, -0.0117, -0.0369, -0.0402, -0.0139, -0.0954, -0.0463, -0.0829,  0.0891, -0.0709,  0.0022, -0.0028,  0.0302, -0.1730,  0.0585,  0.0022,  0.0357,  0.0248,  0.0274,  0.0361, -0.0570,  0.0508,  0.0770, -0.0273,  0.0608,  0.0062, -0.0077, -0.0392,  0.0320, -0.0385, -0.0652,  0.0106,  0.0014, -0.0817, -0.1166, -0.0085,  0.0298, -0.0669,  0.0334, -0.0173, -0.0321,  0.0299, -0.0333, -0.0049,  0.0614, -0.0467,  0.0267,  0.0398,  0.0070,  0.0534, -0.0180, -0.0445,  0.0437, -0.0795, -0.0665, -0.0085, -0.0171, -0.0207,  0.0128,  0.0639, -0.0206,  0.0176,  0.0282,  0.0749,  0.0757, -0.0355, -0.0445, -0.0050, -0.0335, -0.0774, -0.0544,  0.0394, -0.0315,  0.0321,  0.0474, -0.0200,  0.0276,  0.0329, -0.0617, -0.0847, -0.0258, -0.0408,  0.0513,  0.0118,  0.0024]]], device='mps', dtype=torch.float32, requires_grad=True)
# Compare the following:
print(repr(predict(model, MAGIC, 90)[0]))
print(repr(input_text))

tensor([[[ 1.8864e-01,  1.1630e-02, -1.8917e-02,  5.0527e-02,  4.3470e-02,  2.8263e-01, -4.3007e-02, -1.3526e-01,  2.7977e-02, -3.6790e-02,  1.4854e-01, -1.8984e-01,  2.8756e-01,  9.2666e-03, -6.2816e-02, -1.9388e-01,  1.3469e-01,  1.3483e-01,  5.7368e-02, -8.1124e-02, -1.3257e-01, -1.6508e-01,  1.4563e-01, -2.3088e-01,  1.3334e-01,  1.8826e-01,  8.9671e-02, -1.5939e-01, -8.0470e-02,  3.5115e-02, -3.0809e-01,  9.3301e-02, -9.1369e-02, -2.0310e-02, -1.3733e-01,  1.4521e-01, -1.4210e-01,  4.2299e-02, -1.2297e-01,  2.0042e-01, -2.0942e-01,  6.7794e-02, -1.2003e-01, -5.8772e-02, -6.3539e-02, -1.1414e-01, -6.9904e-02, -2.5917e-01,  9.3445e-02,  2.4092e-01,  4.7314e-02,  2.6794e-01, -1.2988e-01, -4.1823e-02, -6.5939e-02, -8.9051e-02,  1.2554e-01, -2.4257e-02,  1.3885e-01,  9.1738e-03, -3.2695e-01, -1.2574e-01, -1.3315e-01,  9.8945e-02,  2.5345e-01, -4.9235e-02, -3.7965e-02, -9.8906e-02, -2.4002e-02,  1.1618e-01,  1.4084e-01,  9.0612e-02,  1.4316e-01,  9.0570e-02,  1.1573e-01,  1.0323e-01,  2

RuntimeError: PyTorch is not linked with support for mps devices