## Imports and download

In [2]:
import transformers
import torch
from torch.optim import AdamW
from tqdm import tqdm

In [3]:
model_checkpoint = "distilgpt2"
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

Downloading (…)lve/main/config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [4]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

## Hook into model and optimize sentence

In [None]:
def unembed(embeds_input):
  """
  Given an embedding vector, decode each token by using the transpose of the embedding matrix
  and grabbing the vocab token with the highest probability on each token.

  Also do this with the unembedding matrix as well.
  """
  with torch.no_grad():
      # Get the pre-trained embeddings
      pretrained_embeddings = model.transformer.wte.weight

      # Calculate dot product between input embeddings and pre-trained embeddings
      dot_product = torch.matmul(embeds_input, pretrained_embeddings.t())

      # Get the index of the highest value along dimension 2 (tokens)
      _, tokens = torch.max(dot_product, dim=2)

  # Decode tokens into text using the tokenizer
  text = tokenizer.batch_decode(tokens.tolist(), skip_special_tokens=True)

  return text


In [9]:
layer = model.transformer.h[1].mlp

In [10]:
activation_saved = [0.0]
def hook(model, input, output):
  # The output is a tensor. You can index it to get the activation of a specific neuron.
  # Here we're getting the activation of the 0th neuron.
  activation = output[0, 0, 0]  # Gets the the 0th neuron of the 0th example in the batch
  activation_saved[0] = activation


In [14]:
# Register the hook
try:
  handle.remove() # deletes the handle when we are done with it.
except:
  pass
handle = layer.register_forward_hook(hook)

In [27]:
inputs = tokenizer("Hello, world!", return_tensors="pt")

# Get embeddings
with torch.no_grad():
    embeddings = model.transformer.wte(inputs["input_ids"])

# Make embeddings require gradient
embeddings.requires_grad_(True)

# Create an optimizer for the embeddings
optimizer = AdamW([embeddings], lr=1e-3)  # You may need to adjust the learning rate
pre_embeddings = embeddings.detach().clone()
print(embeddings)
print(unembed(pre_embeddings))

tensor([[[-0.0904, -0.1538,  0.0315,  ...,  0.0774, -0.0212, -0.0622],
         [ 0.0086, -0.0009,  0.0056,  ...,  0.0484, -0.0737, -0.0636],
         [-0.1725,  0.1922, -0.0372,  ..., -0.3523,  0.1989,  0.0269],
         [-0.1445, -0.0455,  0.0042,  ..., -0.1523,  0.0184,  0.0991]]],
       requires_grad=True)
['Hello, world!']


In [36]:
dist = 0.0
losses = []
for i in tqdm(range(1000)):
  outputs = model(inputs_embeds=embeddings, attention_mask=inputs.attention_mask)
  loss = activation_saved[0]
  loss.backward()
  optimizer.step()
  dist = torch.sum(embeddings - pre_embeddings).item()
  losses.append(loss)
  if i % 25 == 0:
    tqdm.write(f"\n{dist} and then {loss}")
    tqdm.write(unembed(embeddings)[0])
  optimizer.zero_grad()


  0%|          | 2/1000 [00:00<02:29,  6.69it/s]


2.7882938385009766 and then -19.504741668701172
Hello, world!


  3%|▎         | 26/1000 [00:03<03:53,  4.18it/s]


3.07511568069458 and then -22.590717315673828
Hello, world!


  5%|▌         | 51/1000 [00:10<03:33,  4.45it/s]


2.9936022758483887 and then -25.630956649780273
Hello, world!


  8%|▊         | 77/1000 [00:14<02:04,  7.40it/s]


2.8437581062316895 and then -28.83483123779297
Hello, world!


 10%|█         | 101/1000 [00:18<02:34,  5.83it/s]


2.955322265625 and then -31.79642105102539
Hello, world!


 13%|█▎        | 126/1000 [00:22<02:10,  6.69it/s]


3.09163236618042 and then -34.41019058227539
Hello, world!


 15%|█▌        | 151/1000 [00:26<01:34,  8.95it/s]


3.2975564002990723 and then -36.8447380065918
Hello, world!


 18%|█▊        | 177/1000 [00:29<01:26,  9.57it/s]


3.650947093963623 and then -39.280479431152344
Hello, world!


 20%|██        | 201/1000 [00:31<01:22,  9.64it/s]


4.020668983459473 and then -41.6020622253418
Hello, world!


 23%|██▎       | 226/1000 [00:35<02:24,  5.37it/s]


4.214365005493164 and then -43.8235969543457
Hello, world!


 25%|██▌       | 252/1000 [00:39<01:48,  6.92it/s]


4.230597972869873 and then -45.99296569824219
Hello, world!


 28%|██▊       | 276/1000 [00:41<01:11, 10.16it/s]


4.341536521911621 and then -48.01271438598633
 Thumbnails, world!


 30%|███       | 301/1000 [00:44<01:13,  9.51it/s]


4.607449531555176 and then -50.02605438232422
 Thumbnails, world!


 33%|███▎      | 327/1000 [00:46<01:09,  9.72it/s]


4.857081890106201 and then -51.72468948364258
 Thumbnails, world!


 35%|███▌      | 352/1000 [00:49<01:08,  9.40it/s]


5.018745422363281 and then -53.10089111328125
 Thumbnails, world!


 38%|███▊      | 376/1000 [00:53<02:12,  4.72it/s]


5.139378547668457 and then -54.26509094238281
 Thumbnails, world!


 40%|████      | 401/1000 [00:56<01:08,  8.81it/s]


5.250577449798584 and then -55.30712890625
 Thumbnails, world!


 43%|████▎     | 427/1000 [00:59<00:58,  9.86it/s]


5.369259834289551 and then -56.300384521484375
 Thumbnails, world!


 45%|████▌     | 451/1000 [01:02<00:59,  9.22it/s]


5.505671977996826 and then -57.29262161254883
 Thumbnails, world!


 48%|████▊     | 476/1000 [01:05<01:01,  8.55it/s]


5.628506183624268 and then -58.29790115356445
 sidx, world!


 50%|█████     | 502/1000 [01:08<01:07,  7.43it/s]


5.7250075340271 and then -59.360107421875
 sidx, world!


 53%|█████▎    | 526/1000 [01:11<00:53,  8.82it/s]


5.897246360778809 and then -60.438682556152344
 sidx, world!


 55%|█████▌    | 552/1000 [01:14<00:44, 10.05it/s]


6.06976318359375 and then -61.50028991699219
 sidx, world!


 58%|█████▊    | 576/1000 [01:16<00:50,  8.37it/s]


6.153295040130615 and then -62.52595138549805
 sidx, world!


 60%|██████    | 601/1000 [01:19<00:50,  7.96it/s]


6.161953926086426 and then -63.42646026611328
 sidx, world!


 63%|██████▎   | 627/1000 [01:22<00:54,  6.81it/s]


6.127947807312012 and then -64.1803970336914
gypt, world!


 65%|██████▌   | 652/1000 [01:25<00:38,  9.03it/s]


6.037476539611816 and then -64.8541030883789
gypt, world!


 68%|██████▊   | 676/1000 [01:28<00:49,  6.59it/s]


5.916409969329834 and then -65.47130584716797
gypt, world!


 70%|███████   | 701/1000 [01:31<00:36,  8.27it/s]


5.826769828796387 and then -66.04135131835938
gypt, world!


 73%|███████▎  | 727/1000 [01:34<00:28,  9.74it/s]


5.775008678436279 and then -66.57836151123047
gypt, world!


 75%|███████▌  | 751/1000 [01:38<00:45,  5.51it/s]


5.736426830291748 and then -67.08463287353516
gypt, world!


 78%|███████▊  | 776/1000 [01:42<00:23,  9.68it/s]


5.699711799621582 and then -67.56282806396484
gypt, world!


 80%|████████  | 801/1000 [01:44<00:22,  8.66it/s]


5.663791179656982 and then -68.0208969116211
gypt, world!


 83%|████████▎ | 827/1000 [01:47<00:17,  9.67it/s]


5.623499870300293 and then -68.47293090820312
resents, world!


 85%|████████▌ | 852/1000 [01:51<00:22,  6.46it/s]


5.576695442199707 and then -68.93406677246094
resents, world!


 88%|████████▊ | 877/1000 [01:55<00:21,  5.85it/s]


5.542847633361816 and then -69.41598510742188
resents, world!


 90%|█████████ | 902/1000 [01:58<00:15,  6.39it/s]


5.553949356079102 and then -69.94158935546875
resents, world!


 93%|█████████▎| 927/1000 [02:02<00:11,  6.46it/s]


5.636106491088867 and then -70.55557250976562
resents, world!


 95%|█████████▌| 951/1000 [02:05<00:07,  6.24it/s]


5.759215354919434 and then -71.21534729003906
resents, world!


 98%|█████████▊| 977/1000 [02:09<00:03,  7.32it/s]


5.9168291091918945 and then -71.8554916381836
resents, world!


100%|██████████| 1000/1000 [02:12<00:00,  7.56it/s]


In [29]:
embeddings

tensor([[[ 0.0199, -0.0337,  0.0911,  ...,  0.1578, -0.1568, -0.1923],
         [ 0.0086, -0.0009,  0.0056,  ...,  0.0483, -0.0737, -0.0636],
         [-0.1723,  0.1921, -0.0371,  ..., -0.3519,  0.1987,  0.0269],
         [-0.1443, -0.0455,  0.0042,  ..., -0.1521,  0.0184,  0.0990]]],
       requires_grad=True)

In [18]:
embeddings.shape

torch.Size([1, 4, 768])

In [19]:
embed_matrix = model.transformer.wte.weight

In [21]:
embed_matrix.shape

torch.Size([50257, 768])

tensor([[[ 1.1032e-01,  1.2014e-01,  5.9592e-02,  ...,  8.0406e-02,
          -1.3566e-01, -1.3011e-01],
         [-8.6594e-06,  8.7894e-07, -5.6089e-06,  ..., -4.8429e-05,
           7.3761e-05,  6.3330e-05],
         [ 1.7285e-04, -1.9222e-04,  3.7253e-05,  ...,  3.5167e-04,
          -1.9968e-04, -2.7008e-05],
         [ 1.4454e-04,  4.5449e-05, -4.2375e-06,  ...,  1.5199e-04,
          -1.8440e-05, -9.9093e-05]]], grad_fn=<SubBackward0>)