In [1]:
import os, time, torch, random
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
import torch.nn as nn
import ipywidgets as widgets

In [2]:
trained_model_folder = "./out"
language_pair_staging_folder = "../data/magic_token_folder/"
only_copy_back_magic_token_embeddings = True

In [3]:
def load_model(model_path):
    model = GPT2LMHeadModel.from_pretrained(model_path)
    return model

def load_tokenizer(tokenizer_path):
    tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
    return tokenizer

In [4]:
def sample_token_numbers( model_name, token = "[GEN 1:1_a]" ):
    model_path = os.path.join( trained_model_folder, model_name )
    tokenizer = load_tokenizer(model_path)
    model = load_model(model_path)
    token_ids = tokenizer.encode( token )
    assert (len(token_ids) == 1)
    token_id = token_ids[0]
    input_embeddings = model.get_input_embeddings()
    token_embedding = input_embeddings(torch.LongTensor([token_id])).detach().numpy()[0]
    return token_embedding[:4]

Iterate through all the _step models and add up their embeddings for the magic tokens.

In [5]:
if os.path.exists(trained_model_folder) and os.path.isdir(trained_model_folder):
    # Get a list of all subfolders in language_pair_staging_folder
    subfolders = [folder for folder in os.listdir(trained_model_folder) if os.path.isdir(os.path.join(trained_model_folder, folder))]

    # Filter subfolders that end with "_model"
    model_folders = [folder for folder in subfolders if folder.endswith( "_step" ) ]

    # Print or use the list of model folders
    print("Folders ending with '_step':", model_folders)
else:
    print(f"The folder '{trained_model_folder}' does not exist or is not a directory.")

Folders ending with '_step': ['hebrew_model_step', 'bsb_model_step', 'greek_model_step', 'target_model_step']


In [6]:
embeddings_sum = {}
embeddings_count = {}

out = widgets.Output(layout={'border': '1px solid black'})
display(out)

for model_name in model_folders:
    print( f"Grabbing embeddings from {model_name}" )
    model_path = os.path.join( trained_model_folder, model_name )
    tokenizer = load_tokenizer(model_path)
    input_embeddings = load_model(model_path).get_input_embeddings()
    target_model = model_name.replace( "_model", "" ).replace( "_step", "" )
    train_file_path = os.path.join( language_pair_staging_folder, f"train_{target_model}.txt" )
    with open( train_file_path, "rt" ) as fin:
        for line in fin:
            line_tokenized = tokenizer.encode(line)
            #print( f"We got this for the line tokenized {line_tokenized}" )
            #now decode the token and see if it is one of our tokens if it starts with [
            for token_id in line_tokenized:
                token = tokenizer.decode([token_id])
                is_magic_token = token.startswith( "[" ) and token.endswith( "]" ) and token[-3] == "_"
                if( is_magic_token ):
                    if( random.random() > .97 ):
                        out.clear_output()
                        with out:
                            print( f"Getting embedding for {token}" )
                    #print( f"I see magic token {token}" )
                    #pull out the embedding for this specific token.
                    embedding = input_embeddings(torch.LongTensor([token_id])).detach().numpy()
                    #stash away the value so we can average it out.
                    if token not in embeddings_count:
                        embeddings_count[token] = 1
                        embeddings_sum[token] = embedding
                    else:
                        embeddings_count[token] += 1
                        embeddings_sum[token] += embedding
                        

Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid black', border_right='1px solid b…

Grabbing embeddings from hebrew_model_step
Grabbing embeddings from bsb_model_step
Grabbing embeddings from greek_model_step
Grabbing embeddings from target_model_step


In [7]:
#compute the actual average of the embeddings.
embeddings_average = {}
for token,embedding_sum in embeddings_sum.items():
    embeddings_average[token] = embedding_sum/embeddings_count[token]

In [8]:
#print out a sample of the embeddings before updating them so we can see if they changed and how.
subfolders = [folder for folder in os.listdir(trained_model_folder) if os.path.isdir(os.path.join(trained_model_folder, folder))]
all_model_folders = [folder for folder in subfolders if folder.endswith( "_model" ) or folder.endswith( "_step" ) ]
for model_name in all_model_folders:
    print( f"{model_name}: {sample_token_numbers( model_name )}" )

hebrew_model_step: [ 0.00468181 -0.09646945  0.23584266 -0.0736285 ]
bsb_model_step: [ 0.01782305 -0.06168462  0.08194134 -0.02929468]
greek_model: [ 0.01380756 -0.06270172  0.05579207 -0.01744836]
greek_model_step: [ 0.04952674 -0.06531593  0.11450687 -0.05230689]
target_model_step: [ 0.00336675 -0.05822144  0.10511835 -0.02001085]
bsb_model: [ 0.01380756 -0.06270172  0.05579207 -0.01744836]
target_model: [ 0.01380756 -0.06270172  0.05579207 -0.01744836]
hebrew_model: [ 0.01821761 -0.03325709 -0.01717107 -0.01796533]


In [9]:
model_folders

['hebrew_model_step',
 'bsb_model_step',
 'greek_model_step',
 'target_model_step']

In [10]:
#now run back through all the models and inject the average embeddings.
for model_name in model_folders:
    
    #load the tokenizer and model
    model_path = os.path.join( trained_model_folder, model_name )    
    output_dir = model_path.replace( "_step", "" )

    model_to_splice = output_dir if only_copy_back_magic_token_embeddings else model_path
    
    print( f"Pulling in {model_to_splice}" )
    tokenizer = load_tokenizer(model_to_splice)
    model = load_model(model_to_splice)
    input_embeddings = model.get_input_embeddings()
    
    reconstructed_embedding_list = []
    #run through all the tokens
    for token_id in range(len(tokenizer)):
        token = tokenizer.decode([token_id])
        #and grab them from the average if it is in there.
        if token in embeddings_average:
            reconstructed_embedding_list.append( embeddings_average[token][0] )
        else:
            reconstructed_embedding_list.append( input_embeddings(torch.LongTensor([token_id])).detach().numpy()[0] )
            
    #construct it into some weights.
    new_weights = torch.FloatTensor(reconstructed_embedding_list)
    updated_embedding = nn.Embedding.from_pretrained(new_weights)
    #put it back into the model
    model.set_input_embeddings( updated_embedding )
    model.tie_weights()

    #save the model back to the non _step locations so the loop can begin again.
    print( f"Updating {output_dir}" )
    model.save_pretrained(output_dir)

Pulling in ./out/hebrew_model_step


  new_weights = torch.FloatTensor(reconstructed_embedding_list)


Updating ./out/hebrew_model
Pulling in ./out/bsb_model_step
Updating ./out/bsb_model
Pulling in ./out/greek_model_step
Updating ./out/greek_model
Pulling in ./out/target_model_step
Updating ./out/target_model


In [11]:
#now run through all the models and sample their numbers again to see how they changed.
subfolders = [folder for folder in os.listdir(trained_model_folder) if os.path.isdir(os.path.join(trained_model_folder, folder))]
all_model_folders = [folder for folder in subfolders if folder.endswith( "_model" ) or folder.endswith( "_step" ) ]
for model_name in all_model_folders:
    print( f"{model_name}: {sample_token_numbers( model_name )}" )

hebrew_model_step: [ 0.00468181 -0.09646945  0.23584266 -0.0736285 ]
bsb_model_step: [ 0.01782305 -0.06168462  0.08194134 -0.02929468]
greek_model: [ 0.01125243 -0.07907704  0.158892   -0.05146159]
greek_model_step: [ 0.04952674 -0.06531593  0.11450687 -0.05230689]
target_model_step: [ 0.00336675 -0.05822144  0.10511835 -0.02001085]
bsb_model: [ 0.01125243 -0.07907704  0.158892   -0.05146159]
target_model: [ 0.01125243 -0.07907704  0.158892   -0.05146159]
hebrew_model: [ 0.01125243 -0.07907704  0.158892   -0.05146159]


In [14]:
embeddings_average["[GEN 1:1_a]"][0][:4]

array([ 0.01125243, -0.07907704,  0.158892  , -0.05146159], dtype=float32)