## Imports and download

In [2]:
import transformers
import torch
from torch.optim import AdamW
from tqdm import tqdm

In [3]:
# Load the tokenizer and causal-LM weights for one Hugging Face checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_checkpoint = "distilgpt2"

# Both objects are built from the same checkpoint name so their vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

Downloading (…)lve/main/config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [10]:
# Display the module tree of the loaded model (bare last expression -> rich repr).
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

## Hook into model and optimize sentence

In [31]:
def unembed(embeds_input):
  """
  Turn embedding vectors back into text via a best-match lookup in the
  model's token-embedding table.

  Every position of `embeds_input` is scored against every row of
  `model.transformer.wte.weight` with a dot product, and the vocabulary index
  with the largest score is kept. Only the input embedding matrix is used here.

  Args:
    embeds_input: tensor of embeddings whose last axis matches the embedding
      width. The arg-max is taken over axis 2, so the score tensor must have at
      least three axes, e.g. (batch, sequence, hidden).

  Returns:
    The result of `tokenizer.batch_decode` on the selected token ids, with
    special tokens skipped (one decoded string per batch entry).
  """
  # Pure lookup: no gradients are needed, even if embeds_input requires grad.
  with torch.no_grad():
      vocab_vectors = model.transformer.wte.weight

      # Similarity of each input position to each vocabulary embedding.
      scores = embeds_input @ vocab_vectors.t()

      # Best-scoring vocabulary index along axis 2 (the vocabulary axis).
      token_ids = scores.max(dim=2).indices

  return tokenizer.batch_decode(token_ids.tolist(), skip_special_tokens=True)


In [32]:
# Module to monitor: the MLP sub-module of transformer block index 1.
layer = model.transformer.h[1].mlp

In [33]:
# One-slot mutable container so the hook can hand its captured value back to the notebook.
activation_saved = [0.0]

def hook(model, input, output):
  """
  Forward hook that records a single scalar from the hooked module's output.

  The element at index [0, 0, 0] of `output` is stored in `activation_saved[0]`,
  i.e. first index along each of the three axes (presumably batch 0, position 0,
  feature 0 for a (batch, sequence, hidden) output - confirm against the hooked
  module). `model` and `input` are accepted because the hook signature requires
  them; they are not used.

  Returns None, so the module's output is passed through unchanged.
  """
  activation_saved[0] = output[0, 0, 0]


In [34]:
# (Re-)register the forward hook on `layer`.
# If this cell has been run before, first detach the previously registered hook so
# hooks do not pile up on the module across re-runs.
try:
  handle.remove()
except NameError:
  # First execution of this cell: `handle` does not exist yet, nothing to remove.
  # Only NameError is ignored so that genuine failures are not silently swallowed.
  pass
handle = layer.register_forward_hook(hook)

In [35]:
# Tokenise the seed sentence.
inputs = tokenizer("Hello, world!", return_tensors="pt")
token_ids = inputs["input_ids"]

# Look up the token embeddings without recording a graph, so the result is a leaf tensor.
with torch.no_grad():
    embeddings = model.transformer.wte(token_ids)

# From here on the embeddings themselves are the trainable quantity.
embeddings.requires_grad_(True)

# Untouched copy of the starting point, used later to measure how far the optimiser moved.
pre_embeddings = embeddings.detach().clone()

# The optimiser updates only the embeddings tensor; the learning rate may need tuning.
# NOTE(review): no weight_decay is passed, so AdamW's library default applies to the
# inputs as well - confirm that shrinking the embeddings each step is intended.
optimizer = AdamW([embeddings], lr=1e-3)

# Sanity check: the starting embeddings should decode back to the original sentence.
print(embeddings)
print(unembed(pre_embeddings))

tensor([[[-0.0904, -0.1538,  0.0315,  ...,  0.0774, -0.0212, -0.0622],
         [ 0.0086, -0.0009,  0.0056,  ...,  0.0484, -0.0737, -0.0636],
         [-0.1725,  0.1922, -0.0372,  ..., -0.3523,  0.1989,  0.0269],
         [-0.1445, -0.0455,  0.0042,  ..., -0.1523,  0.0184,  0.0991]]],
       requires_grad=True)
['Hello, world!']


In [36]:
from tqdm.notebook import tqdm, trange

# Optimise the input embeddings against the activation captured by the forward hook.
# NOTE(review): optimizer.step() minimises `loss`, so this loop drives the hooked
# activation DOWN. If the goal is to maximise it, use the negated activation as
# the loss - confirm intent.
dist = 0.0
losses = []  # plain Python floats, one per step
pbar = trange(1000, desc="Processing")
for i in pbar:
    # The forward pass is run for its side effect: the hook stores the monitored
    # activation in activation_saved[0]. The returned logits are not used as the loss.
    outputs = model(inputs_embeds=embeddings, attention_mask=inputs.attention_mask)
    loss = activation_saved[0]
    loss.backward()
    optimizer.step()

    # Distance travelled from the starting embeddings. The L2 norm is used because a
    # plain signed sum lets positive and negative changes cancel and is not a distance.
    with torch.no_grad():
        dist = torch.linalg.norm(embeddings - pre_embeddings).item()

    # Store the scalar value rather than the graph-attached tensor.
    losses.append(loss.item())

    if i % 25 == 0:
        pbar.set_description(f"Processing (dist={dist}, loss={loss})")
        pbar.set_postfix_str(unembed(embeddings)[0])
    optimizer.zero_grad()


Processing:   0%|          | 0/1000 [00:00<?, ?it/s]

In [29]:
# Display the current values of the embeddings tensor handed to the optimizer.
embeddings

tensor([[[ 0.0199, -0.0337,  0.0911,  ...,  0.1578, -0.1568, -0.1923],
         [ 0.0086, -0.0009,  0.0056,  ...,  0.0483, -0.0737, -0.0636],
         [-0.1723,  0.1921, -0.0371,  ..., -0.3519,  0.1987,  0.0269],
         [-0.1443, -0.0455,  0.0042,  ..., -0.1521,  0.0184,  0.0990]]],
       requires_grad=True)

In [18]:
# Shape check on the input-embedding tensor.
embeddings.shape

torch.Size([1, 4, 768])

In [19]:
# Reference to the model's token-embedding weight (not a copy).
embed_matrix = model.transformer.wte.weight

In [21]:
# Shape check on the token-embedding weight matrix.
embed_matrix.shape

torch.Size([50257, 768])

tensor([[[ 1.1032e-01,  1.2014e-01,  5.9592e-02,  ...,  8.0406e-02,
          -1.3566e-01, -1.3011e-01],
         [-8.6594e-06,  8.7894e-07, -5.6089e-06,  ..., -4.8429e-05,
           7.3761e-05,  6.3330e-05],
         [ 1.7285e-04, -1.9222e-04,  3.7253e-05,  ...,  3.5167e-04,
          -1.9968e-04, -2.7008e-05],
         [ 1.4454e-04,  4.5449e-05, -4.2375e-06,  ...,  1.5199e-04,
          -1.8440e-05, -9.9093e-05]]], grad_fn=<SubBackward0>)