## This notebook attempts to find n embeddings that make GPT-2 produce scripture without modifying the model

In [14]:
import os, torch, pickle
from ipywidgets import Dropdown
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch.nn.functional as F
from torch.optim import SGD
import torch.nn as nn

In [59]:
# Configuration for this experiment.
# NOTE(review): trained_model_folder is not referenced by any cell visible here — confirm it is still needed.
trained_model_folder = "./out"
# Folder that holds the training text files and the pickled embeddings.
language_pair_staging_folder = "../data/magic_token_folder/"
# How many learned "magic token" embeddings are prepended to each verse.
num_magic_tokens = 1
# The pickle file name includes num_magic_tokens, so each token count gets its own file.
trained_embeddings_pickle_file = os.path.join( language_pair_staging_folder, f"magic_tokens_size_{num_magic_tokens}.pickle" )

In [60]:
# Code inserted into the training file name (train_<code>.txt) by get_index_from_training_file.
# NOTE: the variable name is misspelled ("langauge"); get_index_from_training_file reads it
# under this exact spelling, so a rename must change both places together.
selected_langauge = "bsb"

In [78]:
# Choose the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("No GPU available, using CPU.")


def load_model(model_path):
    """Load a GPT-2 language-model head from `model_path` and move it to `device`.

    `model_path` is passed straight to `GPT2LMHeadModel.from_pretrained`.
    Returns the model after `.to(device)`.
    """
    return GPT2LMHeadModel.from_pretrained(model_path).to(device)


def load_tokenizer(tokenizer_path):
    """Return the GPT-2 tokenizer that `GPT2Tokenizer.from_pretrained` builds for `tokenizer_path`."""
    return GPT2Tokenizer.from_pretrained(tokenizer_path)


def grab_line_from_file( file, index ):
    """Return line number `index` (0-based) of the text file `file`, stripped of
    surrounding whitespace.

    The file is read from the top on every call. Returns None when the file has
    no line at that position (this includes any negative `index`).
    """
    with open( file, "rt" ) as handle:
        for position, text in enumerate( handle ):
            if position == index:
                return text.strip()
    # Ran off the end of the file without reaching the requested line.
    return None

def get_index_from_training_file( index ):
    """Return line `index` of the training file for the selected language.

    The file is `train_<selected_langauge>.txt` inside
    `language_pair_staging_folder`. The result is whatever
    `grab_line_from_file` returns for that path and index.
    """
    file_name = f"train_{selected_langauge}.txt"
    return grab_line_from_file( os.path.join( language_pair_staging_folder, file_name ), index )

def split_off_magic_tokens( verse ):
    """Split the leading "[...]" magic tokens off the front of a training line.

    Only bracketed groups at the very start of `verse` are treated as magic
    tokens. Scanning stops at the first character that does not open a new
    "[...]" group, so square brackets that appear later inside the verse text
    stay in the verse instead of being swallowed into the token list.

    Returns a tuple `(magic_tokens, remainder)`: the list of bracketed tokens
    (brackets included) and the rest of the line exactly as it follows the last
    token, with no stripping. A line without leading tokens, or with an
    unclosed leading "[", yields `([], verse)`.
    """
    magic_tokens = []
    start = 0
    while start < len(verse) and verse[start] == "[":
        end = verse.find( "]", start )
        if end == -1:
            # Unclosed bracket: leave it as part of the verse text.
            break
        magic_tokens.append( verse[start:end+1] )
        start = end + 1
    return magic_tokens, verse[start:]

def change_number_of_magic_tokens( magic_tokens, new_number ):
    """Build `new_number` magic-token names from the first token in `magic_tokens`.

    The last two characters of the first token (its letter suffix and the
    closing "]") are dropped, then a fresh suffix is appended: the character
    `chr(ord('a') + i)` followed by "]" for i = 0 .. new_number-1. Only the
    first element of `magic_tokens` is read.
    """
    renamed = []
    for offset in range(new_number):
        # Look the stem up inside the loop so that new_number == 0 never touches magic_tokens.
        stem = magic_tokens[0][:-2]
        renamed.append( stem + chr(ord('a') + offset) + "]" )
    return renamed

def save_to_pickle( learned_magic_tokens ):
    """Pickle `learned_magic_tokens` to `trained_embeddings_pickle_file`, overwriting any existing file."""
    with open( trained_embeddings_pickle_file, "wb" ) as pickle_sink:
        pickle.dump( learned_magic_tokens, pickle_sink )

Using GPU: NVIDIA GeForce RTX 2080


In [79]:
# Sanity check: the two leading bracketed tokens should be split from the rest of the text.
split_off_magic_tokens( "[hi][how]are you today?" )

(['[hi]', '[how]'], 'are you today?')

In [80]:
# Sanity check: expand a two-token list to four tokens suffixed _a .. _d.
change_number_of_magic_tokens( ["[GEN 1:1_a]","[GEN 1:1_b]"], 4 )

['[GEN 1:1_a]', '[GEN 1:1_b]', '[GEN 1:1_c]', '[GEN 1:1_d]']

In [81]:
# Load the tokenizer registered under the name "gpt2"; the training cell uses it to encode and decode verse text.
tokenizer = load_tokenizer( "gpt2" )

In [83]:
# Load the model registered under the name "gpt2"; load_model also moves it to `device`.
model = load_model( "gpt2" )

In [84]:
# One-element index tensor [0] on the compute device. The training cell feeds it to the
# model's input-embedding layer and uses the resulting row as the template for
# torch.randn_like when it creates a new magic-token embedding.
zero_index = torch.LongTensor([0]).to(device)

In [85]:
# As I won't be using an expanded tokenizer or model for this experiment, we need to store
# these embeddings in a dictionary (magic-token string -> embedding tensor).
learned_magic_tokens = {}

In [86]:
# Resume from an earlier session when a pickle for this num_magic_tokens already exists;
# otherwise learned_magic_tokens keeps whatever value it already had.
# NOTE(security): pickle.load can execute arbitrary code — only load pickle files you created yourself.
if os.path.exists(trained_embeddings_pickle_file):
    with open(trained_embeddings_pickle_file, 'rb') as file:
        learned_magic_tokens = pickle.load(file)
        print( f"Loaded previouse training from {trained_embeddings_pickle_file}" )
else:
    print( f"Previouse training not found at {trained_embeddings_pickle_file}" )

Loaded previouse training from ../data/magic_token_folder/magic_tokens_size_1.pickle


In [None]:
# Initialise the verse index in its own cell so the training cell below can be interrupted
# and restarted without starting over; that cell increments verse_number after each verse.
verse_number = 0

In [None]:
# Learn the magic-token embeddings verse by verse, starting at `verse_number`.
#
# For each verse the frozen model is fed [magic embeddings] + [embeddings of the first k verse
# tokens] and the loss asks it to predict verse token k. Gradients are accumulated over k until
# the first wrong greedy prediction, then one manual gradient step is applied to the magic
# embeddings only. This repeats for up to 100 runs per verse, or until every token of the verse
# is predicted correctly.

# Step size multiplied into the accumulated gradient. It is negative so that adding
# (grad * slope) moves each embedding against the gradient.
slope = -.05
embedding_layer = model.get_input_embeddings()

while True:
    # Fetch the line once per verse: every lookup rescans the training file from the top.
    verse_with_magic_tokens = get_index_from_training_file( verse_number )
    if verse_with_magic_tokens is None:
        break
    print( f"\n\nWorking on verse {verse_with_magic_tokens}" )

    #remove the magic tokens from the verse because the tokenizer won't know what to do with them.
    magic_tokens, verse = split_off_magic_tokens( verse_with_magic_tokens )
    token_ids = tokenizer( verse )["input_ids"]

    if not magic_tokens or not token_ids:
        # Nothing to train on (no token name to derive the keys from, or an empty verse).
        # Skip the line instead of failing later on a missing key or a missing gradient.
        print( "\nSkipping: line has no magic tokens or no verse text" )
        verse_number += 1
        continue

    #make it so that we can play with a different number of magic tokens then is in the training file.
    magic_tokens = change_number_of_magic_tokens( magic_tokens, num_magic_tokens )

    # Create any embedding that has not been learned yet, using one embedding row as the template.
    for magic_token in magic_tokens:
        if magic_token not in learned_magic_tokens:
            learned_magic_tokens[magic_token] = torch.randn_like( embedding_layer(zero_index)[0] )
            learned_magic_tokens[magic_token].requires_grad = True

    # The verse-token embeddings are identical for every run and every prefix, so look them up
    # once. no_grad keeps them as constants: only the magic tokens are being trained.
    with torch.no_grad():
        verse_embeds = embedding_layer( torch.LongTensor(token_ids).to(device) )

    for run in range( 100 ):
        model.zero_grad()
        # Also clear the magic-token gradients so a run never inherits gradients left behind
        # by an interrupted or already-successful earlier run.
        for magic_token in magic_tokens:
            learned_magic_tokens[magic_token].grad = None
        print( f"\n{run}: ", end='' )

        found_problem = False

        for token_to_teach in range( len(token_ids) ):
            correct_token = token_ids[token_to_teach]
            # Float target with weight 25.0 on the correct token. With a float target of the
            # same shape as the logits, F.cross_entropy weights each class's log-probability
            # by the target value, so this scales the loss (and the gradient) by 25.
            target_tensor = torch.zeros(len(tokenizer)).to(device)
            target_tensor[correct_token] = 25.0

            # Input sequence: the magic embeddings followed by the first token_to_teach verse tokens.
            magic_embeds = torch.stack( [learned_magic_tokens[magic_token] for magic_token in magic_tokens], dim=0 ).to(device)
            inputs_embeds = torch.cat( [magic_embeds, verse_embeds[:token_to_teach]], dim=0 )

            result = model.forward( inputs_embeds = inputs_embeds )
            loss = F.cross_entropy( result.logits[-1].unsqueeze(0), target_tensor.unsqueeze(0) )
            # Gradients accumulate on the magic tokens across all prefixes of this run.
            loss.backward()

            # Greedy prediction for the next token (torch.multinomial(probs, 1) would sample instead).
            probs = F.softmax(result.logits[-1], dim=-1)
            sampled_token_id = torch.argmax(probs, dim=-1).item()
            print(tokenizer.decode( [sampled_token_id] ), end='')

            if sampled_token_id != correct_token:
                found_problem = True
                print( f"\nbreaking because {tokenizer.decode([sampled_token_id])} != {tokenizer.decode([correct_token])}" )
                break

        # Every token was predicted correctly with the current embeddings: stop before taking
        # another step, so the embeddings that get saved are exactly the ones just verified.
        if not found_problem: break

        # Manual gradient step; each embedding is replaced by a fresh leaf tensor.
        for magic_token in magic_tokens:
            new_magic_token_tensor = (learned_magic_tokens[magic_token] + (learned_magic_tokens[magic_token].grad * slope)).detach()
            new_magic_token_tensor.requires_grad = True
            learned_magic_tokens[magic_token] = new_magic_token_tensor

    # Persist progress after every verse so an interrupted session loses at most one verse.
    save_to_pickle( learned_magic_tokens )
    verse_number += 1

save_to_pickle( learned_magic_tokens )



Working on verse [GEN 2:13_a][GEN 2:13_b] The name of the second river is Gihon; it winds through the whole land of Cush.

0:  The "
breaking because  " !=  name

1:  The.
breaking because . !=  name

2:  The.
breaking because . !=  name

3:  The.
breaking because . !=  name

4:  The.
breaking because . !=  name

5:  The

breaking because 
 !=  name

6:  The,
breaking because , !=  name

7:  The name,
breaking because , !=  of

8:  The name,
breaking because , !=  of

9:  The name,
breaking because , !=  of

10:  The name,
breaking because , !=  of

11:  The name,
breaking because , !=  of

12:  The name,
breaking because , !=  of

13:  The name,
breaking because , !=  of

14:  The name,
breaking because , !=  of

15:  The name,
breaking because , !=  of

16:  The name,
breaking because , !=  of

17:  The name,
breaking because , !=  of

18:  The name name
breaking because  name !=  of

19:  The name,
breaking because , !=  of

20:  The name of the the
breaking because  the !=  secon

In [None]:
# Manual save, useful after interrupting the training cell part-way through a verse.
save_to_pickle( learned_magic_tokens )

In [23]:
# Test read: load the pickle back into a separate variable to confirm the file is readable.
# NOTE(security): pickle.load can execute arbitrary code — only load pickle files you created yourself.
with open(trained_embeddings_pickle_file, 'rb') as file:
    loaded_tensor_dict = pickle.load(file)

In [24]:
# Display the reloaded dictionary. This prints every value of every tensor, so the output is very long.
loaded_tensor_dict

{'[GEN 1:1_a]': tensor([-1.0992e-01,  2.2229e-01, -7.2535e-01, -2.1512e-01,  5.6779e-01,
          1.0467e+00, -2.4451e+00, -2.0557e+00,  5.5300e-01,  5.1697e-01,
          1.1690e+00, -7.3763e-01, -1.6142e-01,  1.7500e+00, -2.2468e+00,
         -1.6448e+00,  4.5860e-02,  1.3547e+00, -3.2388e+00, -9.9583e-01,
          6.9453e-01, -7.3981e-01, -5.4273e-01, -2.5259e-01, -2.6112e-01,
         -7.7373e-01,  1.5115e+00, -7.8226e-01, -1.0458e+00, -1.3298e+00,
         -1.3137e+00, -1.0233e+00,  1.8834e+00, -1.0056e+00, -3.2838e-01,
         -1.6602e-01, -6.2908e-01,  3.4196e-03,  1.3483e+00, -6.9642e-02,
         -1.0222e+00,  1.6092e+00,  3.0540e-01,  1.4577e+00, -3.2722e-01,
          1.8582e+00, -1.8294e+00, -8.9691e-01, -8.2996e-01,  1.5754e+00,
          2.5428e-01,  1.5008e+00,  4.9382e-01,  4.8643e-01, -1.7270e-01,
         -7.4391e-01,  4.7919e-01, -7.7876e-01,  1.1722e+00, -5.9011e-01,
          1.1370e+00, -1.4926e+00,  8.8825e-02, -5.2010e-01,  9.0567e-02,
          7.4316e-02,  