# This notebook trains just the magic token embeddings.

In [1]:
import os, torch
from ipywidgets import Dropdown
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch.nn.functional as F
from torch.optim import SGD
import torch.nn as nn

In [2]:
# Folder that is listed for trained-model subfolders; the selected model and its
# tokenizer are loaded from a subfolder of this path.
trained_model_folder = "./out"
# Folder holding the per-language training text files (train_<language>.txt).
language_pair_staging_folder = "../data/magic_token_folder/"
# Number of leading tokens of every verse treated as magic tokens: prediction
# starts at this token position and only these first positions are written back
# into the embedding table by the training cell.
num_magic_tokens = 2

In [3]:
# Directories that sit directly under the trained-model folder.
subfolders = [
    entry
    for entry in os.listdir(trained_model_folder)
    if os.path.isdir(os.path.join(trained_model_folder, entry))
]

# Keep only folders named like a finished model or a step checkpoint;
# str.endswith accepts a tuple and matches if any suffix matches.
model_folders = [
    name
    for name in subfolders
    if name.endswith(("_model", "_model_step"))
]

# Let the user choose which of those models the rest of the notebook works on.
selected_model_dropdown = Dropdown(options=model_folders)
print( "Select which model to train" )
display(selected_model_dropdown)

Select which model to train


Dropdown(options=('hebrew_model_step', 'bsb_model_step', 'greek_model', 'greek_model_step', 'target_model_step…

In [4]:
def load_model(model_path):
    """Load a GPT-2 LM-head model from ``model_path``.

    Thin wrapper around ``GPT2LMHeadModel.from_pretrained``; the object that
    call produces is returned unchanged.
    """
    return GPT2LMHeadModel.from_pretrained(model_path)


def load_tokenizer(tokenizer_path):
    """Load a GPT-2 tokenizer from ``tokenizer_path``.

    Thin wrapper around ``GPT2Tokenizer.from_pretrained``; the object that
    call produces is returned unchanged.
    """
    return GPT2Tokenizer.from_pretrained(tokenizer_path)


def grab_line_from_file( file, index ):
    """Return line ``index`` (0-based) of the text file ``file``.

    Args:
        file: Path of the text file to read.
        index: Zero-based number of the wanted line.

    Returns:
        The requested line with leading and trailing whitespace stripped, or
        ``None`` when the file has no line with that index (this includes any
        negative ``index``).
    """
    # Read as UTF-8 explicitly instead of relying on the platform's default
    # encoding, so non-ASCII characters in the text decode the same everywhere.
    with open( file, "rt", encoding="utf-8" ) as fin:
        for line_number, line in enumerate( fin ):
            if line_number == index:
                return line.strip()
    # Ran off the end of the file without reaching ``index``.
    return None

def get_index_from_training_file( index ):
    """Return line ``index`` of the training file for the selected model.

    The language name is the dropdown's current value with every occurrence of
    "_model" and then "_step" removed; the file read is ``train_<language>.txt``
    inside ``language_pair_staging_folder``. The result is whatever
    ``grab_line_from_file`` returns for that file and index.
    """
    language = selected_model_dropdown.value
    for marker in ( "_model", "_step" ):
        language = language.replace( marker, "" )

    file_name = f"train_{language}.txt"
    training_file = os.path.join( language_pair_staging_folder, file_name )
    return grab_line_from_file( training_file, index )

In [5]:
# The tokenizer is read from the same folder as the model picked in the dropdown.
selected_tokenizer_path = os.path.join( trained_model_folder, selected_model_dropdown.value )
tokenizer = load_tokenizer( selected_tokenizer_path )

In [22]:
# Load the model picked in the dropdown from its folder under the output directory.
selected_model_path = os.path.join( trained_model_folder, selected_model_dropdown.value )
model = load_model( selected_model_path )

In [None]:
# Train only the embeddings of the leading magic tokens, one verse at a time.
#
# For every verse the model is asked to predict each token that follows the
# first `num_magic_tokens` tokens, given the prefix before it. At the first
# wrong prediction (or after the last token when every prediction was right)
# the gradient of the loss with respect to that prefix's input embeddings is
# used to nudge the embedding rows of the verse's first `num_magic_tokens`
# tokens. No other row of the embedding table is written.

slope = -.0001            # multiplier on the gradient; negative, so the update steps against it
max_runs_per_verse = 100  # upper bound on update attempts for a single verse

verse_number = 0
while True:
    # Fetch each verse exactly once; every fetch scans the training file from the top.
    next_verse = get_index_from_training_file( verse_number )
    if next_verse is None:
        break
    verse = next_verse
    # Advance now so the `continue` below can never revisit the same line.
    verse_number += 1

    print( f"\n\nWorking on verse {verse}" )
    tokenized = tokenizer( verse )
    verse_token_ids = tokenized["input_ids"]

    # A blank or very short line has nothing after the magic tokens to predict.
    # Without this guard the prediction loop would not run and the update below
    # would use an undefined (or previous-verse) `inputs_embeds`.
    if len( verse_token_ids ) <= num_magic_tokens:
        print( "\nskipping: no tokens after the magic tokens" )
        continue

    # detach() returns a tensor sharing storage with the model's embedding
    # matrix, so the row assignments further down also change that matrix.
    control_embeddings_weights = model.get_input_embeddings().weight.detach()
    control_embeddings = nn.Embedding.from_pretrained( control_embeddings_weights )

    for run in range( max_runs_per_verse ):
        model.zero_grad()
        print( f"\n{run}: ", end='' )

        found_problem = False

        for token_to_teach in range( num_magic_tokens, len( verse_token_ids ) ):
            # Prefix seen by the model and the token it should produce next.
            input_ids = torch.LongTensor( verse_token_ids[:token_to_teach] )
            correct_token = verse_token_ids[token_to_teach]

            # Float target that is zero everywhere except the correct token.
            # NOTE(review): with a float target cross_entropy presumably treats the
            # values as per-class weights, making 25.0 a loss scale - confirm.
            target_tensor = torch.zeros( len(tokenizer) )
            target_tensor[correct_token] = 25.0

            # Feed embeddings rather than ids so a gradient with respect to the
            # embedded prefix is available after backward().
            inputs_embeds = control_embeddings( input_ids )
            inputs_embeds.requires_grad = True
            result = model.forward( inputs_embeds = inputs_embeds )
            # assumes logits[-1] is the vocabulary-sized score vector of the last
            # position for this unbatched input - TODO confirm for the installed
            # transformers version
            loss = F.cross_entropy( result.logits[-1].unsqueeze(0), target_tensor.unsqueeze(0) )
            loss.backward()

            # Greedy pick of the model's prediction, echoed as progress output.
            probs = F.softmax( result.logits[-1], dim=-1 )
            sampled_token_id = torch.argmax( probs, dim=-1 ).item()
            print( tokenizer.decode( [sampled_token_id] ), end='' )

            if sampled_token_id != correct_token:
                # Stop at the first mistake; its gradient drives the update below.
                found_problem = True
                break

        # `inputs_embeds` is the prefix of the last prediction made in this run
        # (the failing one, if any); step its embeddings along the gradient.
        new_inputs_embeds = inputs_embeds + (inputs_embeds.grad * slope)

        # Write back only the rows that belong to the leading magic tokens.
        for token_i in range( num_magic_tokens ):
            control_embeddings_weights[verse_token_ids[token_i]] = new_inputs_embeds[token_i].detach()
        control_embeddings = nn.Embedding.from_pretrained( control_embeddings_weights )

        if not found_problem:
            # Every token of the verse was predicted correctly.
            break

    model.set_input_embeddings( control_embeddings )
    model.tie_weights()
    
#correct_logits = 
# tokenized_torch = 
# output = model(tokenized)



Working on verse [GEN 1:1_a][GEN 1:1_b] In the beginning God created the heavens and the earth.

0: In the beginning God created the heavens and the earth.

Working on verse [GEN 1:2_a][GEN 1:2_b] Now the earth was formless and void, and darkness was over the surface of the deep. And the Spirit of God was hovering over the surface of the waters.

0: Now the earth was formless and void, and the
1: Now the earth was formless and void, and the
2: Now the earth was formless and void, and the
3:  the
4: Now the earth was formless and void, and the
5: Now the earth was formless and void, and the
6: Now the earth was formless and void, and the
7: Now the earth was formless and void, and the
8: Now the earth was formless and void, and darkness was over it
9: Now the earth was formless and void, and darkness was over it
10: Now the earth was formless and void, and darkness was over it
11: Now the earth was formless and void, and darkness was over it
12: Now the earth was formless and void, an

In [None]:
# Install the current `control_embeddings` table as the model's input embeddings
# and re-tie the weights. The training cell makes these same two calls after each verse.
# NOTE(review): presumably kept as its own cell for when training is interrupted
# part-way through a verse - confirm.
model.set_input_embeddings(control_embeddings)
model.tie_weights()

In [None]:
# Save under the output folder, using the selected model's folder name plus "_prefixed".
prefixed_model_path = os.path.join( trained_model_folder, selected_model_dropdown.value + "_prefixed" )
model.save_pretrained( prefixed_model_path )

In [None]:
# Find the highest-scoring token in the logits of the most recent forward pass
# (`result` is left over from the training cell).
# Only the first len(tokenizer) entries are considered, as before; a single
# torch.max call replaces the per-element Python loop and its -20000 sentinel.
# Like the original, this assumes result.logits[-1] is a 1-D score vector.
candidate_logits = result.logits[-1].detach()[:len(tokenizer)]
best_value, best_index = torch.max( candidate_logits, dim=-1 )
max_value = best_value.numpy()  # 0-d NumPy value, the same kind of object the loop produced
max_index = best_index.item()   # plain Python int, usable with tokenizer.decode

In [None]:
# Display the largest logit found by the search in the previous cell.
max_value

In [None]:
# Display the text of the token whose logit was the largest.
tokenizer.decode([max_index])

In [17]:
# Sample a short continuation (max_length=20) from the two magic tokens of Genesis 1:1.
prompt_ids = tokenizer.encode( "[GEN 1:1_a][GEN 1:1_b]", return_tensors='pt' )
model_output = model.generate(
    prompt_ids,
    do_sample=True,
    max_length=20,
    pad_token_id=model.config.eos_token_id,
    top_k=50,
    top_p=.95,
)
tokenizer.decode( model_output[0] )

'[GEN 1:1_a] [GEN 1:1_b] Afterward, He looked up and saw a young man dressed in a white robe, as'

In [31]:
# Generate a long continuation (max_length=200) from the two magic tokens of Genesis 1:1.
# NOTE(review): with top_k=1 sampling presumably always takes the most likely token - confirm.
prompt_ids = tokenizer.encode( "[GEN 1:1_a][GEN 1:1_b]", return_tensors='pt' )
model_output = model.generate(
    prompt_ids,
    do_sample=True,
    max_length=200,
    pad_token_id=model.config.eos_token_id,
    top_k=1,
    top_p=.95,
)
tokenizer.decode( model_output[0] )

'[GEN 1:1_a] [GEN 1:1_b] And the LORD said to Moses, “Tell Aaron, ‘Stretch out your hand over the waters of Meribah, for the LORD has sent a message concerning you.’” So Aaron stretched out his hand over the waters of Meribah, and the LORD sent a message over the waters of Meribah. waters never again spanned the Red Sea. Flavoringities of theiph, the flies, and the creatures that inhabited the ground covered the Red Sea. Flavoringities of theiph, the flies, and the creatures that inhabited the ground covered the Red Sea. Flavoringities of the creatures that inhabited the ground were so deep that the water had no flow. Flavoringities of the creatures that inhabited the ground were so deep that the water had no flow. Flavoringities of the creatures that inhabited the ground were so deep that the water had no flow. Flavoringities of the creatures that inhabited the ground were so deep that the water had no flow. Flavoringities of'

In [40]:
# Generate a long continuation (max_length=200) from the two magic tokens of Genesis 1:3.
# NOTE(review): with top_k=1 sampling presumably always takes the most likely token - confirm.
prompt_ids = tokenizer.encode( "[GEN 1:3_a][GEN 1:3_b]", return_tensors='pt' )
model_output = model.generate(
    prompt_ids,
    do_sample=True,
    max_length=200,
    pad_token_id=model.config.eos_token_id,
    top_k=1,
    top_p=.95,
)
tokenizer.decode( model_output[0] )

'[GEN 1:3_a] [GEN 1:3_b] And the LORD said to Moses, “Tell Aaron, ‘Stretch out your hand over the waters of Meribah, for the LORD has sent a message concerning you.’” So Aaron stretched out his hand over the waters of Merib [MAT 18:14_b] etheless, the LORD did not listen to him, and the LORD did not destroy the people. Instead, the LORD stirred up the waters of Meribah and the people of the land went down to the springs.etheless, the LORD did not destroy them.asaptions of God came up, and the people of the land rejoiced.asaptions of God came up, and the LORD didrail over the people.asaptions of God came up, and the LORD did not destroy them.asaptions of God came up, and the LORD did not destroy them.asaptions of God came up, and the LORD did not destroy them.asaptions of God came up, and the LORD did not destroy them.asa'