# Adversarial embedding

In image, adversarial examples can readily constructed. In LLM, the embedding layer fixes the vector value for each token, making it hard to perform this attack. Now, how should we do this? However, it's still interesting to see how the transformer layers robust against perturbations inside the embedding layers.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
model_id = "Qwen/Qwen3-4B"
device = torch.device("mps")
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
original_model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(model)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 2560)
    (layers): ModuleList(
      (0-35): 36 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=2560, out_features=4096, bias=False)
          (k_proj): Linear(in_features=2560, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2560, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=2560, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=2560, out_features=9728, bias=False)
          (up_proj): Linear(in_features=2560, out_features=9728, bias=False)
          (down_proj): Linear(in_features=9728, out_features=2560, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((2560,), eps=1e-06)
        (post_attention_layernorm): Qwe

In [3]:
prompt = "When John and Mary went to the shop, John gave the bag to"
toks = tokenizer.encode(prompt, return_tensors="pt").to(device)
print(prompt)
print(toks)

with torch.no_grad():
  out = model(toks)
  logits = out.logits
  print(tokenizer.decode(logits[0,-1].argmax()), logits[0,-1].argmax())

When John and Mary went to the shop, John gave the bag to
tensor([[ 4498,  3757,   323, 10244,  3937,   311,   279,  8061,    11,  3757,
          6551,   279,  8968,   311]], device='mps:0')
 Mary tensor(10244, device='mps:0')


## Intervene in the last token only

### Flip the prediction to a target by changing the embedding of last token

Now the question is how minimal the embedding of the last token be modified to sway the result not to "Mary" and steer towards "John". It takes quite a lot of changes to push the result to "John".

In [4]:
TARGET_LABEL_IDX = tokenizer.encode(" John")[0]
LAST_TOKEN_IDX = toks[0,-1].item()
mask = torch.zeros(model.model.embed_tokens.weight.shape[0], device=device, requires_grad=False).unsqueeze(1)
mask[LAST_TOKEN_IDX,0] = 1

for idx in range(100):
  out = model(toks)
  logits = out.logits
  with torch.no_grad():
    topk = logits[0,-1].topk(3)
    out_toks = tokenizer.batch_decode(topk.indices)
    cos = F.cosine_similarity(model.model.embed_tokens.weight[LAST_TOKEN_IDX], original_model.model.embed_tokens.weight[LAST_TOKEN_IDX], dim=0)
    print(idx, out_toks, 'cosine:', cos.item())
    if topk.indices[0].item() == TARGET_LABEL_IDX:
      model.zero_grad()
      break

  loss = F.cross_entropy(logits[0,-1], torch.tensor(TARGET_LABEL_IDX, device=device))
  loss.backward()

  model.model.embed_tokens.weight.grad *= mask
  model.model.embed_tokens.weight.data = model.model.embed_tokens.weight - 1e-3 * model.model.embed_tokens.weight.grad

  model.zero_grad()

0 [' Mary', ' the', ' a'] cosine: 0.9999999403953552
1 [' Mary', ' the', ' a'] cosine: 0.999849259853363
2 [' Mary', ' the', ' a'] cosine: 0.9994126558303833
3 [' Mary', ' the', ' a'] cosine: 0.9987058043479919
4 [' Mary', ' the', ' a'] cosine: 0.9977388978004456
5 [' Mary', ' the', ' a'] cosine: 0.9965208172798157
6 [' Mary', ' the', ' a'] cosine: 0.9950621128082275
7 [' Mary', ' the', ' a'] cosine: 0.993376612663269
8 [' the', ' Mary', ' a'] cosine: 0.9914787411689758
9 [' the', ' Mary', ' a'] cosine: 0.9893377423286438
10 [' Mary', ' the', ' a'] cosine: 0.9867074489593506
11 [' the', ' Mary', ' a'] cosine: 0.9862926006317139
12 [' Mary', ' the', ' a'] cosine: 0.9832582473754883
13 [' the', ' Mary', ' a'] cosine: 0.9832898378372192
14 [' the', ' Mary', ' a'] cosine: 0.9798016548156738
15 [' the', ' Mary', ' a'] cosine: 0.976415753364563
16 [' the', ' Mary', ' a'] cosine: 0.9742664098739624
17 [' Mary', ' the', ' a'] cosine: 0.9703006744384766
18 [' Mary', ' the', ' a'] cosine: 0.9689

### Make the prediction not correct by changing the embedding of the last token

We would want to see if how hard it is to change the prediction from correct to incorrect. We will do so by update the embedding of the last token such that it increases the loss.

With Qwen3-4B, it takes only 4 iteration to knock " Mary" from the top prediction. After some time, the model predict the input token.

In [10]:
model.zero_grad()
with torch.no_grad():
  model.model.embed_tokens.weight.data = original_model.model.embed_tokens.weight.data.clone().detach()

In [11]:
CORRECT_LABEL_IDX = tokenizer.encode(" Mary")[0]
LAST_TOKEN_IDX = toks[0,-1].item()
mask = torch.zeros(model.model.embed_tokens.weight.shape[0], device=device, requires_grad=False).unsqueeze(1)
mask[LAST_TOKEN_IDX,0] = 1

for idx in range(100):
  out = model(toks)
  logits = out.logits
  with torch.no_grad():
    topk = logits[0,-1].topk(3)
    out_toks = tokenizer.batch_decode(topk.indices)
    cos = F.cosine_similarity(model.model.embed_tokens.weight[LAST_TOKEN_IDX], original_model.model.embed_tokens.weight[LAST_TOKEN_IDX], dim=0)
    print(idx, out_toks, 'cosine:', cos.item())
    # if topk.indices[0].item() != CORRECT_LABEL_IDX:
    #   model.zero_grad()
    #   break

  loss = F.cross_entropy(logits[0,-1], torch.tensor(CORRECT_LABEL_IDX, device=device))
  loss.backward()

  model.model.embed_tokens.weight.grad *= mask
  model.model.embed_tokens.weight.data = model.model.embed_tokens.weight + 1e-3 * model.model.embed_tokens.weight.grad # here we increase the loss)

  model.zero_grad()

0 [' Mary', ' the', ' a'] cosine: 0.9999999403953552
1 [' Mary', ' the', ' a'] cosine: 0.9999438524246216
2 [' Mary', ' the', ' a'] cosine: 0.9997327327728271
3 [' the', ' Mary', ' a'] cosine: 0.9992661476135254
4 [' the', ' Mary', ' a'] cosine: 0.9983550310134888
5 [' the', ' Mary', ' a'] cosine: 0.9966298341751099
6 [' the', ' a', ' Mary'] cosine: 0.9933741688728333
7 [' the', ' a', ' Mary'] cosine: 0.9873926639556885
8 [' the', ' a', ' Alice'] cosine: 0.9774008989334106
9 [' the', ' a', ' an'] cosine: 0.9657547473907471
10 [' the', ' a', ' an'] cosine: 0.9581483602523804
11 [' the', ' a', ' Mary'] cosine: 0.9406974911689758
12 [' to', ' the', ' Mary'] cosine: 0.67897629737854
13 [' to', ' the', ' and'] cosine: 0.6499879360198975
14 [' to', ' the', ' and'] cosine: 0.6323633790016174
15 [' to', ' and', ','] cosine: 0.6142152547836304
16 [' to', ' and', ','] cosine: 0.6003431081771851
17 [' to', ' and', ','] cosine: 0.5893852710723877
18 [' to', ' and', ','] cosine: 0.5773136019706726


### Randomly jittering the embedding of the last token.

The prediction is very robust against random jittering. Randomly jittering within 100% of value doesn't change much the cosine similarity. It seems the before and after embeddings still point to relatively similar direction. Increasing the magnitude of random jittering does decreases before-after cosine similarity, and has some effect on the prediction, but not by much.

This is pretty interesting. Maybe during training, the self-attention constantly "induces noises", and this makes the model much more robust against random jittering.

In [39]:
@torch.no_grad
def randomize_embedding(target_emb, src_emb, token_idx, pct: float):
  token_emb = src_emb.weight.data[token_idx].clone()
  jittering = torch.rand(token_emb.shape) * pct * 2 + 1 - pct
  jittering = jittering.to(token_emb.device)
  target_emb.weight.data[token_idx] = token_emb * jittering

LAST_TOKEN_IDX = toks[0,-1].item()

with torch.no_grad():
  for idx in range(200):
    randomize_embedding(model.model.embed_tokens, original_model.model.embed_tokens, LAST_TOKEN_IDX, idx / 100)
    out = model(toks)
    logits = out.logits

    topk = logits[0,-1].topk(3)
    out_toks = tokenizer.batch_decode(topk.indices)
    cos = F.cosine_similarity(model.model.embed_tokens.weight[LAST_TOKEN_IDX], original_model.model.embed_tokens.weight[LAST_TOKEN_IDX], dim=0)
    print(idx, out_toks, 'cosine:', cos.item())

0 [' Mary', ' the', ' a'] cosine: 0.9999999403953552
1 [' Mary', ' the', ' a'] cosine: 0.9999831318855286
2 [' Mary', ' the', ' a'] cosine: 0.9999349117279053
3 [' Mary', ' the', ' a'] cosine: 0.9998526573181152
4 [' Mary', ' the', ' a'] cosine: 0.9997234344482422
5 [' Mary', ' the', ' a'] cosine: 0.9995722770690918
6 [' Mary', ' the', ' a'] cosine: 0.9994065761566162
7 [' Mary', ' the', ' a'] cosine: 0.9992039203643799
8 [' Mary', ' the', ' a'] cosine: 0.9988850355148315
9 [' Mary', ' the', ' a'] cosine: 0.9987077713012695
10 [' Mary', ' the', ' a'] cosine: 0.9983139634132385
11 [' Mary', ' the', ' a'] cosine: 0.9979837536811829
12 [' Mary', ' the', ' a'] cosine: 0.9976602792739868
13 [' Mary', ' the', ' a'] cosine: 0.9971832633018494
14 [' Mary', ' the', ' a'] cosine: 0.9967246651649475
15 [' Mary', ' the', ' a'] cosine: 0.9962208271026611
16 [' Mary', ' the', ' a'] cosine: 0.995842456817627
17 [' Mary', ' the', ' a'] cosine: 0.9951059222221375
18 [' Mary', ' the', ' a'] cosine: 0.99

In [41]:
print("Before", original_model.model.embed_tokens.weight[LAST_TOKEN_IDX])

Before tensor([-0.0181, -0.0189,  0.0119,  ...,  0.0098, -0.0247,  0.0221],
       device='mps:0', grad_fn=<SelectBackward0>)


In [42]:
print("After", model.model.embed_tokens.weight[LAST_TOKEN_IDX])

After tensor([-0.0014, -0.0250, -0.0112,  ...,  0.0125, -0.0251,  0.0121],
       device='mps:0', grad_fn=<SelectBackward0>)


In [43]:
with torch.no_grad():
  print("Norm before", idx / 100, "jittering", original_model.model.embed_tokens.weight[LAST_TOKEN_IDX].norm(p=2))
  print("Norm after", idx / 100, "jittering", model.model.embed_tokens.weight[LAST_TOKEN_IDX].norm(p=2))

Norm before 1.99 jittering tensor(0.8725, device='mps:0')
Norm after 1.99 jittering tensor(1.3438, device='mps:0')


## Intervene embedding of all tokens

The previous experiments attempt to make minimal change to the last token to see how hard it's to flip the correct answer. The experiments below do the same but allow make changes to embeddings of all tokens.

In [92]:
@torch.no_grad()
def summarize_grad_accumulation(grad_accumulation, model, original_model):
  grad_changes = (grad_accumulation ** 2).sum(dim=1)
  grad_changes_mask = (grad_changes > 0.01)
  changed_tokens = grad_changes_mask.nonzero().squeeze()

  print(changed_tokens.shape[0], "tokens have embedding changed. They are:")
  _changed = []
  for t_ in changed_tokens.cpu().tolist():
    cos_ = F.cosine_similarity(original_model.model.embed_tokens.weight[t_], model.model.embed_tokens.weight[t_], dim=0).item()
    _changed.append((t_, tokenizer.decode(t_), grad_accumulation[t_,0].norm(p=1).item(), cos_))

  print("(idx, token str, grad_accumulation p1, cos before-after)")
  print()
  for t_, s_, v_, c_ in sorted(_changed, key=lambda obj: obj[3]):
    print(t_, s_, v_, c_)

### Flip the prediction to a target token

It's much easier to flip the token if we change all of the embeddings because it contains the token that we want to flip into.

In [48]:
model.zero_grad()
with torch.no_grad():
  model.model.embed_tokens.weight.data = original_model.model.embed_tokens.weight.data.clone()

In [49]:
TARGET_LABEL_IDX = tokenizer.encode(" John")[0]
grad_accumulation = torch.zeros(model.model.embed_tokens.weight.data.shape, requires_grad=False).to(device)

for idx in range(100):
  out = model(toks)
  logits = out.logits
  with torch.no_grad():
    topk = logits[0,-1].topk(3)
    out_toks = tokenizer.batch_decode(topk.indices)
    print(idx, out_toks)
    if topk.indices[0].item() == TARGET_LABEL_IDX:
      model.zero_grad()
      break

  loss = F.cross_entropy(logits[0,-1], torch.tensor(TARGET_LABEL_IDX, device=device))
  loss.backward()

  with torch.no_grad():
    grad_accumulation += model.model.embed_tokens.weight.grad
  model.model.embed_tokens.weight.data = model.model.embed_tokens.weight - 1e-3 * model.model.embed_tokens.weight.grad
  model.zero_grad()

0 [' Mary', ' the', ' a']
1 [' John', 'Mary', ' his']


In [50]:
grad_changes = (grad_accumulation ** 2).sum(dim=1)
grad_changes_mask = (grad_changes > 0.01)
changed_tokens = grad_changes_mask.nonzero().squeeze()

In [73]:
print(changed_tokens.shape[0], "tokens have embedding changed. They are:")
_changed = []
for t_ in changed_tokens.cpu().tolist():
  cos_ = F.cosine_similarity(original_model.model.embed_tokens.weight[t_], model.model.embed_tokens.weight[t_], dim=0).item()
  _changed.append((t_, tokenizer.decode(t_), grad_accumulation[t_,0].norm(p=1).item(), cos_))

print("idx, token str, grad_accumulation p1, cos before-after")
print()
for t_, s_, v_, c_ in sorted(_changed, key=lambda obj: obj[3]):
  print(t_, s_, v_, c_)

25 tokens have embedding changed. They are:
idx, token str, grad_accumulation p1, cos before-after

3757  John 1.466503381729126 0.9879211783409119
10244  Mary 0.27848052978515625 0.9963998794555664
279  the 1.0227277278900146 0.9983216524124146
4498 When 1.4908576011657715 0.9995152950286865
8061  shop 0.0740446001291275 0.9997081160545349
3937  went 0.3508458435535431 0.9998031854629517
6551  gave 1.1857423782348633 0.9998382329940796
323  and 0.7384849190711975 0.9998487234115601
311  to 0.00997190922498703 0.999849259853363
8968  bag 0.4924362301826477 0.9999417662620544
11 , 0.23656544089317322 0.9999722838401794
264  a 0.005127974320203066 0.9999829530715942
806  his 0.0041706436313688755 0.9999920725822449
29405  Alice 0.0010434952564537525 0.9999995827674866
8224  Sam 0.00013742889859713614 0.9999998807907104
20445  Sarah 0.0005189783405512571 0.9999998807907104
7801  James 0.0002439873933326453 0.9999999403953552
23016  Maria 0.00013226382725406438 0.9999999403953552
458  an 0

It seems the token " John" get massive update. What if we suppress changing the token " John"? The " to" token is not updated as much as we expect (we expect because this token lays the ground for prediction " Mary"). The " gave" token is updated quite a lot, which is reasonable now when I think of it.

One interesting thing is that some tokens that do not appear in this sentence get gradient?

In [74]:
t1 = tokenizer.encode(" John")[0]
t2 = tokenizer.encode(" Mary")[0]

with torch.no_grad():
  # before
  print("John-Mary cos Before", F.cosine_similarity(original_model.model.embed_tokens.weight[t1], original_model.model.embed_tokens.weight[t2], dim=0))
  print("John-Mary cos After", F.cosine_similarity(model.model.embed_tokens.weight[t1], model.model.embed_tokens.weight[t2], dim=0))

John-Mary cos Before tensor(0.2099, device='mps:0')
John-Mary cos After tensor(0.2098, device='mps:0')


#### Suppressing update embedding of the target token

Let's check how hard it is if the embedding of the target token " John" not modified. It takes more effort, but not that much. But the end result is that it requires the embedding of much more tokens to change to get this effect.

In [76]:
model.zero_grad()
with torch.no_grad():
  model.model.embed_tokens.weight.data = original_model.model.embed_tokens.weight.data.clone()

In [77]:
TARGET_LABEL_IDX = tokenizer.encode(" John")[0]
mask = torch.ones(model.model.embed_tokens.weight.shape[0], device=device, requires_grad=False).unsqueeze(1)
mask[TARGET_LABEL_IDX,0] = 0

grad_accumulation = torch.zeros(model.model.embed_tokens.weight.data.shape, requires_grad=False).to(device)

for idx in range(100):
  out = model(toks)
  logits = out.logits
  with torch.no_grad():
    topk = logits[0,-1].topk(3)
    out_toks = tokenizer.batch_decode(topk.indices)
    print(idx, out_toks)
    if topk.indices[0].item() == TARGET_LABEL_IDX:
      model.zero_grad()
      break

  loss = F.cross_entropy(logits[0,-1], torch.tensor(TARGET_LABEL_IDX, device=device))
  loss.backward()

  model.model.embed_tokens.weight.grad *= mask

  with torch.no_grad():
    grad_accumulation += model.model.embed_tokens.weight.grad
  
  model.model.embed_tokens.weight.data = model.model.embed_tokens.weight - 1e-3 * model.model.embed_tokens.weight.grad
  model.zero_grad()

0 [' Mary', ' the', ' a']
1 [' a', ' his', ' Alice']
2 ['Mary', ' his', ' a']
3 [' mary', 'Mary', ' his']
4 [' James', 'Mary', ' mary']
5 [' mary', 'Mary', ' John']
6 [' James', ' John', ' mary']
7 [' John', ' James', ' Mark']


In [78]:
grad_changes = (grad_accumulation ** 2).sum(dim=1)
grad_changes_mask = (grad_changes > 0.01)
changed_tokens = grad_changes_mask.nonzero().squeeze()

print(changed_tokens.shape[0], "tokens have embedding changed. They are:")
_changed = []
for t_ in changed_tokens.cpu().tolist():
  cos_ = F.cosine_similarity(original_model.model.embed_tokens.weight[t_], model.model.embed_tokens.weight[t_], dim=0).item()
  _changed.append((t_, tokenizer.decode(t_), grad_accumulation[t_,0].norm(p=1).item(), cos_))

print("idx, token str, grad_accumulation p1, cos before-after")
print()
for t_, s_, v_, c_ in sorted(_changed, key=lambda obj: obj[3]):
  print(t_, s_, v_, c_)

526 tokens have embedding changed. They are:
idx, token str, grad_accumulation p1, cos before-after

10244  Mary 2.1146278381347656 0.9933505654335022
279  the 1.5858529806137085 0.9963732957839966
323  and 0.7634978890419006 0.9983670711517334
4498 When 3.5042643547058105 0.9983714818954468
8061  shop 0.4943886399269104 0.9991201162338257
3937  went 0.7949550151824951 0.9994505643844604
6551  gave 0.6530000567436218 0.999481737613678
311  to 0.15373677015304565 0.999482274055481
264  a 0.033509328961372375 0.9995958805084229
806  his 0.034367725253105164 0.9997066259384155
8968  bag 0.3553487956523895 0.9997640252113342
41484 Mary 0.029053980484604836 0.9997926950454712
7801  James 0.02342626452445984 0.9998242855072021
29405  Alice 0.02315627969801426 0.9998542666435242
11 , 0.18387183547019958 0.9999076128005981
84630  mary 0.021174168214201927 0.999914288520813
20445  Sarah 0.01612580008804798 0.9999176263809204
21475  Jane 0.015967853367328644 0.9999343156814575
23016  Maria 0.015

### Make the prediction not correct

It's relatively easy to make the prediction not correct.

In [82]:
model.zero_grad()
with torch.no_grad():
  model.model.embed_tokens.weight.data = original_model.model.embed_tokens.weight.data.clone().detach()

In [83]:
CORRECT_LABEL_IDX = tokenizer.encode(" Mary")[0]
grad_accumulation = torch.zeros(model.model.embed_tokens.weight.data.shape, requires_grad=False).to(device)

for idx in range(100):
  out = model(toks)
  logits = out.logits
  with torch.no_grad():
    topk = logits[0,-1].topk(3)
    out_toks = tokenizer.batch_decode(topk.indices)
    print(idx, out_toks)
    if topk.indices[0].item() != CORRECT_LABEL_IDX:
      model.zero_grad()
      break

  loss = F.cross_entropy(logits[0,-1], torch.tensor(CORRECT_LABEL_IDX, device=device))
  loss.backward()
  with torch.no_grad():
    grad_accumulation += model.model.embed_tokens.weight.grad

  model.model.embed_tokens.weight.data = model.model.embed_tokens.weight + 1e-3 * model.model.embed_tokens.weight.grad
  model.zero_grad()

0 [' Mary', ' the', ' a']
1 [' the', ' a', ' his']


In [91]:
summarize_grad_accumulation(grad_accumulation, model, original_model)

25 tokens have embedding changed. They are:
(idx, token str, grad_accumulation p1, cos before-after)

10244  Mary 0.46845513582229614 0.9986592531204224
279  the 0.03706904500722885 0.9988864660263062
8061  shop 0.11342021822929382 0.9998895525932312
4498 When 0.05344673991203308 0.9999078512191772
3937  went 0.032146092504262924 0.9999401569366455
311  to 0.2873799502849579 0.9999438524246216
3757  John 0.04006228968501091 0.9999444484710693
6551  gave 0.18170922994613647 0.9999549388885498
8968  bag 0.09388038516044617 0.9999561309814453
323  and 0.15183570981025696 0.9999767541885376
264  a 0.005127974320203066 0.9999830722808838
806  his 0.0041706436313688755 0.9999920725822449
11 , 0.04456475377082825 0.999993085861206
29405  Alice 0.0010434952564537525 0.9999995231628418
20445  Sarah 0.0005189783405512571 0.9999998211860657
8224  Sam 0.00013742889859713614 0.9999998807907104
7801  James 0.0002439873933326453 0.9999999403953552
458  an 0.00014632524107582867 1.0
4325  someone 0.00

### Randomly jittering the embedding

We know that just randomly jittering the embedding of the last token is already too hard to flip prediction. What if randomly jitter embeddings of all the tokens in the sentence?

Still very hard to change prediction.

In [98]:
@torch.no_grad
def randomize_embedding_multiple_toks(target_emb, src_emb, token_idxs, pct: float):
  for token_idx in token_idxs:
    token_emb = src_emb.weight.data[token_idx].clone()
    jittering = torch.rand(token_emb.shape) * pct * 2 + 1 - pct
    jittering = jittering.to(token_emb.device)
    target_emb.weight.data[token_idx] = token_emb * jittering

ALL_TOKS = list(set(toks.squeeze().cpu().tolist()))

with torch.no_grad():
  for idx in range(200):
    randomize_embedding_multiple_toks(model.model.embed_tokens, original_model.model.embed_tokens, ALL_TOKS, idx / 100)
    out = model(toks)
    logits = out.logits

    topk = logits[0,-1].topk(3)
    out_toks = tokenizer.batch_decode(topk.indices)
    cos = F.cosine_similarity(model.model.embed_tokens.weight[LAST_TOKEN_IDX], original_model.model.embed_tokens.weight[LAST_TOKEN_IDX], dim=0)
    print(idx, out_toks, 'cosine last tok:', cos.item())

0 [' Mary', ' the', ' a'] cosine last tok: 0.9999999403953552
1 [' Mary', ' the', ' a'] cosine last tok: 0.9999843835830688
2 [' Mary', ' the', ' a'] cosine last tok: 0.9999337196350098
3 [' Mary', ' the', ' a'] cosine last tok: 0.999849259853363
4 [' Mary', ' the', ' a'] cosine last tok: 0.9997230768203735
5 [' Mary', ' the', ' a'] cosine last tok: 0.9995887279510498
6 [' Mary', ' the', ' a'] cosine last tok: 0.9994329810142517
7 [' Mary', ' the', ' a'] cosine last tok: 0.9991805553436279
8 [' Mary', ' the', ' a'] cosine last tok: 0.9989150762557983
9 [' Mary', ' the', ' a'] cosine last tok: 0.9987057447433472
10 [' Mary', ' the', ' a'] cosine last tok: 0.9983136653900146
11 [' Mary', ' the', ' a'] cosine last tok: 0.9980168342590332
12 [' Mary', ' the', ' a'] cosine last tok: 0.9976080656051636
13 [' Mary', ' the', ' a'] cosine last tok: 0.9972718358039856
14 [' Mary', ' the', ' a'] cosine last tok: 0.9966927170753479
15 [' Mary', ' the', ' a'] cosine last tok: 0.9963391423225403
16 