## This model findes n embeddings to make gpt2 produce any sentence

In [3]:
import os, torch
from ipywidgets import Dropdown
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch.nn.functional as F
from torch.optim import SGD
import torch.nn as nn

In [4]:
num_magic_tokens = 2

In [5]:
def load_model(model_path):
    model = GPT2LMHeadModel.from_pretrained(model_path)
    return model


def load_tokenizer(tokenizer_path):
    tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
    return tokenizer



In [6]:
tokenizer = load_tokenizer( "gpt2" )

In [7]:
model = load_model( "gpt2" )

In [8]:
zero_index = torch.LongTensor([0])

In [23]:
#The learned magic tokens will end up in this list
learned_magic_token_list = []
sentence_to_produce = "This sentence will become an embedding."

In [24]:
print( f"\n\nWorking on sentence {sentence_to_produce}" )

tokenized = tokenizer( sentence_to_produce )

slope = -.1


for run in range( 100 ):
    model.zero_grad()
    print( f"\n{run}: ", end='' )

    
    # if learned_magic_token_list:
    #     print( f" First couple learned values of magic token 0: {learned_magic_token_list[0][:5].detach().numpy()}", end='\n' )

    found_problem = False
    
    for token_to_teach in range( len(tokenized["input_ids"]) ):

        input_ids = torch.LongTensor(tokenized["input_ids"][:token_to_teach])
        correct_token = tokenized["input_ids"][token_to_teach]
        target_tensor = torch.zeros(len(tokenizer))
        target_tensor[correct_token] = 25.0

        inputs_embeds_list = []
        for magic_token_i in range( num_magic_tokens ):
            while magic_token_i >= len( learned_magic_token_list ):
                new_magic_token = torch.randn_like( model.get_input_embeddings()(zero_index)[0] )
                new_magic_token.requires_grad = True
                new_magic_token.retain_grad()
                learned_magic_token_list.append( new_magic_token )
            inputs_embeds_list.append( learned_magic_token_list[magic_token_i] )
        for input_index in range(token_to_teach):
            inputs_embeds_list.append( model.get_input_embeddings()( torch.LongTensor([tokenized["input_ids"][input_index]]) )[0] )
            
        inputs_embeds = torch.stack(inputs_embeds_list, dim=0)
        #inputs_embeds.requires_grad = True
        #inputs_embeds.retain_grad()
        result = model.forward( inputs_embeds = inputs_embeds )
        loss = F.cross_entropy( result.logits[-1].unsqueeze(0), target_tensor.unsqueeze(0) )
        loss.backward()
    
        probs = F.softmax(result.logits[-1], dim=-1)
        #sampled_token_id = torch.multinomial(probs, 1).item()
        sampled_token_id = torch.argmax(probs, dim=-1).item()
        print(tokenizer.decode( [sampled_token_id] ), end='')

        if sampled_token_id != correct_token:
            found_problem = True
            print( f"\nbreaking because {tokenizer.decode([sampled_token_id])} != {tokenizer.decode([correct_token])}" )
            break
    
    # new_inputs_embeds = inputs_embeds + (inputs_embeds.grad * slope)

    # for magic_token_i in range(len(magic_tokens)):
    #     magic_token = magic_tokens[magic_token_i]
    #     learned_magic_tokens[magic_token] = new_inputs_embeds[magic_token_i].detach()

    for magic_token_i in range( num_magic_tokens ):
        magic_token_tensor = learned_magic_token_list[magic_token_i]
        new_magic_token_tensor = (magic_token_tensor + (magic_token_tensor.grad * slope)).detach()
        new_magic_token_tensor.requires_grad = True
        new_magic_token_tensor.retain_grad()
        learned_magic_token_list[magic_token_i] = new_magic_token_tensor

    if not found_problem: break




Working on sentence This sentence will become an embedding.

0:  the
breaking because  the != This

1: The
breaking because The != This

2: 

breaking because 
 != This

3: The
breaking because The != This

4: The
breaking because The != This

5: The
breaking because The != This

6: The
breaking because The != This

7: This is
breaking because  is !=  sentence

8: This.
breaking because . !=  sentence

9: This is
breaking because  is !=  sentence

10: This is
breaking because  is !=  sentence

11: The
breaking because The != This

12: This.
breaking because . !=  sentence

13: This is
breaking because  is !=  sentence

14: This Act
breaking because  Act !=  sentence

15: This sentence:
breaking because : !=  will

16: This is
breaking because  is !=  sentence

17: This sentence is
breaking because  is !=  will

18: This sentence is
breaking because  is !=  will

19: This sentence is
breaking because  is !=  will

20: This sentence will be
breaking because  be !=  become

21: This is


In [25]:
print( f"The embedding for the sentence {sentence_to_produce} is \n{learned_magic_token_list}" )

The embedding for the sentence This sentence will become an embedding. is 
[tensor([ 0.1929, -0.5231, -0.0760,  2.6223,  1.9157,  2.9033,  0.1211, -0.1061,
         0.6001, -1.7923,  1.7988, -1.3788, -1.7105,  0.1941,  1.6227,  1.2291,
         0.5562,  1.9356, -0.7665,  1.6864,  0.1316,  1.5271,  3.0311, -0.0748,
        -2.1802, -1.1159,  0.7396, -0.2184,  1.1878,  1.6763, -1.9479, -0.7508,
         1.4421,  0.8478,  0.4187,  1.3319, -1.2282,  2.0423, -1.9388, -1.5317,
        -1.0179,  1.6385,  0.8261, -1.7613,  0.9579, -0.3701, -1.8082,  1.3161,
         0.1808,  1.5822, -2.0018,  1.0880,  0.6485,  1.5574,  0.3187, -1.6992,
         1.7380, -0.5354,  1.0756,  0.5897, -1.1188,  1.7005,  0.7067, -0.6838,
         0.8132, -0.2790,  2.1543,  0.5475,  0.0438,  3.2126,  1.0016,  2.1984,
         1.1711, -1.3954,  0.2343,  0.9718,  0.1050,  0.0598,  0.4093,  2.1276,
        -0.0687, -0.1722, -1.8237,  1.9616, -0.8970, -0.1906,  0.0222,  1.4738,
         2.8357, -0.6409, -0.9830,  2.0602, 