In [1]:
import argparse
import time
import torch
print(torch.__version__)
from Models import get_model
from Process import *
import torch.nn.functional as F
from Optim import CosineWithRestarts
from Batch import create_masks
import dill as pickle
import time
import math
import numpy as np
import multiprocessing as mp
import random
import string
import sys
import os
import whoosh, glob, time, pickle
import whoosh.fields as wf
from whoosh.qparser import QueryParser
from whoosh import index
import threading
from whoosh import filedb
from whoosh.filedb.filestore import FileStorage
import numpy as np
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
import time
from InMemorySearch import *
from client import *
from pathlib import Path
from sklearn.feature_extraction.text import TfidfVectorizer
from logger import Logger
logger = Logger('./logs')

1.0.0


  from ._conv import register_converters as _register_converters


In [2]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=2

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=2


### Setting up the parameters. If you are running the notebook for the first time set load_vocab parameter to None. After the model starts running it will put the vocabulary in the vocab folder. If you want to use the model in the paper set relevance_training variable to 1, otherwise the model would be a simple transformer. 

In [3]:
parser = argparse.ArgumentParser()
#If we are working on small dataset
small = 0
#If we want to activate relevance based training
relevance_training = 1
dst = 'vocab'
if small==1:
    parser.add_argument('-src_data', type=str, default='data/italian_small.txt')
    parser.add_argument('-trg_data', type=str, default='data/english_small.txt')
    parser.add_argument('-trg_data_retrieval', type=str, default='data/english_retrieval.txt')

else:
    parser.add_argument('-src_data', type=str, default='data/italian.txt')
    parser.add_argument('-trg_data', type=str, default='data/english.txt')  
    parser.add_argument('-trg_data_retrieval', type=str, default='data/LATIMESTEXT2.txt')
    parser.add_argument('-rm_data', type=str, default='data/english_rm_dir.txt') ##

parser.add_argument('-src_lang', type=str, default='it')
parser.add_argument('-trg_lang', type=str, default='en')
parser.add_argument('-no_cuda', action='store_true')
parser.add_argument('-SGDR', action='store_true')
parser.add_argument('-epochs', type=int, default=10)
parser.add_argument('-d_model', type=int, default=200)
parser.add_argument('-n_layers', type=int, default=6)
parser.add_argument('-heads', type=int, default=8)
parser.add_argument('-dropout', type=int, default=0.1)
parser.add_argument('-batchsize', type=int, default=1000)
parser.add_argument('-printevery', type=int, default=5)
#if you are running for the first time set load_vocab to None. A vocabulary would be created in the vocab directory. 
parser.add_argument('-load_vocab', type=str, default='clir_it_en')

if relevance_training == 1:
    #if there exists a model_weights file 
    my_file = Path("weights/model_weights")
    if my_file.is_file():
        parser.add_argument('-load_weights', type=str, default='weights')
    else:
        parser.add_argument('-load_weights', type=str, default=None)
    parser.add_argument('-lr', type=int, default=0.01)
else: 
    my_file = Path("weights/model_weights")
    if my_file.is_file():
        parser.add_argument('-load_weights', type=str, default='weights')
    else:
        parser.add_argument('-load_weights', type=str, default=None)         
    parser.add_argument('-lr', type=int, default=0.0001)
    
parser.add_argument('-max_strlen', type=int, default=80)
parser.add_argument('-checkpoint', type=int, default=5)

opt = parser.parse_args(args=[])


In [4]:
def tokenizer(text):  # create a tokenizer function
    return text.split()
    
def create_fields(opt):    
    print("loading tokenizers...") 
    TRG = data.Field(lower=True, tokenize=tokenizer, init_token='<sos>', eos_token='<eos>')
    TRG_REL = data.Field(lower=True, tokenize=tokenizer, init_token='<sos>', eos_token='<eos>')    
    SRC = data.Field(lower=True, tokenize=tokenizer)   
    SRC = pickle.load(open(f'{opt.load_vocab}/SRC.pkl', 'rb'))
    TRG = pickle.load(open(f'{opt.load_vocab}/TRG.pkl', 'rb'))
    TRG_REL = pickle.load(open(f'{opt.load_vocab}/TRG.pkl', 'rb'))
    return(SRC, TRG, TRG_REL)


def create_dataset(opt, SRC, TRG, TRG_REL):
    """
    This function is used to construct the training batches from the input. 
    opt: parameters 
    SRC: source data for machine translation. 
    TRG: target data for machine translation. 
    TRG_REL: retrieval data for machine translation. This data is constructed apriori using TRG sentences as 
    queries to the retrieval corpus and retrieving a passage from the most relevant document. we call it relevance 
    data 
    """
    
    print("creating dataset and iterator... ")
    translation_data = [line for line in opt.trg_data] ### 
    relevance_data = [line for line in open(opt.rm_data)] ### relevance data is loaded from rm_data file.     
    
    raw_data = {'src' : [line.strip() for line in opt.src_data], 'trg': translation_data, 'trg_rel': relevance_data} ###  
    #raw data retrieval is the whole retrieval corpus. we use it to estimate tf-idf of all the words in the 
    #retrieval corpus. 
    
    raw_data_retrieval = {'trg': [line for line in opt.trg_data_retrieval]}
    df = pd.DataFrame(raw_data, columns=["src", "trg", "trg_rel"])
    
    #using the retrieval data to compute idf of every token in the retrieval corpus 
    vectorizer = TfidfVectorizer(use_idf=True)
    vectorizer.fit_transform(raw_data['trg'] + raw_data_retrieval['trg'])
    tokens = vectorizer.get_feature_names()
    idf_values = vectorizer.idf_
    opt.idf_dict = {}
    for i in range(len(tokens)):
        opt.idf_dict.setdefault(tokens[i],idf_values[i]) 
    
    #chopping off sentences with maximum string length parameter
    mask = (df['src'].str.count(' ') < opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
    df = df.loc[mask]
    
    #a dataframe is required for torchtext. we are creating a temporary one. each row of the dataframe would 
    #consist of source sentence, target sentence, and relevant passage retrieved using the target sentence. 
    df.to_csv("translate_transformer_temp.csv", index=False)
    
    data_fields = [('src', SRC), ('trg', TRG) , ('trg_rel', TRG_REL)]
    
    train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields)
    #creating training batches. 
    train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
                        repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn, train=True, shuffle=True)
    #the following file is not necessary anymore.   
    os.remove('translate_transformer_temp.csv')   
    
    
    if opt.load_vocab is None:
        #if we have not built the vocabulary, we will have to do it here. We are integrating both the retrieval 
        #copus as well as the translation corpus. 
        print("creating dataset for retrieval corpus... ")
        raw_data = {'trg': [line for line in opt.trg_data_retrieval]}
        df = pd.DataFrame(raw_data, columns=["trg"])
        mask = (df['trg'].str.count(' ') > 1)
        df = df.loc[mask]
        df.to_csv("translate_transformer_retrieval_temp.csv", index=False)
        data_fields = [('trg', TRG)]
        train_retrieval = data.TabularDataset('./translate_transformer_retrieval_temp.csv', format='csv', fields=data_fields)
        os.remove('translate_transformer_retrieval_temp.csv')    

        print("building vocabulary for both europarl and retrieval corpus")

        SRC.build_vocab(train)
        TRG.build_vocab(train, train_retrieval)

    opt.src_pad = SRC.vocab.stoi['<pad>']
    opt.trg_pad = TRG.vocab.stoi['<pad>']
    
    opt.train_len = get_len(train_iter)
    return train_iter



In [5]:
opt.device = 0 if opt.no_cuda is False else -1
if opt.device == 0:
    assert torch.cuda.is_available()
print(opt.device)
read_data(opt)

SRC, TRG, TRG_REL = create_fields(opt)
opt.train = create_dataset(opt, SRC, TRG, TRG_REL)

#TRG = create_retrieval_vocabulary(opt, TRG)

if opt.checkpoint > 0:
    print("model weights will be saved every %d minutes and at end of epoch to directory weights/"%(opt.checkpoint))
    
if opt.load_vocab is None:
    pickle.dump(SRC, open(f'{dst}/SRC.pkl', 'wb'))
    pickle.dump(TRG, open(f'{dst}/TRG.pkl', 'wb'))

0
loading tokenizers...
creating dataset and iterator... 


The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.


model weights will be saved every 5 minutes and at end of epoch to directory weights/


In [6]:
CONTEXT_SIZE = 2            
model = get_model(opt, len(SRC.vocab), len(TRG.vocab), CONTEXT_SIZE)
opt.optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)
opt.optimizer_wemb = torch.optim.Adam(model.parameters(), lr = 0.000001, betas=(0.9, 0.98), eps=1e-9)
if opt.SGDR == True:
    opt.sched = CosineWithRestarts(opt.optimizer, T_max=opt.train_len)
    
print(TRG.vocab.itos[0])
print(len(TRG.vocab))
opt.idf_dict["<sos>"] = 1
opt.idf_dict["<eos>"] = 1
class_weights = []
for i in range(len(TRG.vocab)):
    if TRG.vocab.itos[i] in opt.idf_dict:
        class_weights.append(opt.idf_dict[TRG.vocab.itos[i]])
    else:
        class_weights.append(0.0001)
class_weights = torch.FloatTensor(class_weights)

loading pretrained weights...
<unk>
240229


In [7]:
print(TRG.vocab.itos[3])

<eos>


In [None]:
from torch import nn 
import subprocess
opt.checkpoint = 5
nll_loss = nn.NLLLoss() # loss function
            
def train_model(model, opt, SRC, TRG):
    torch.cuda.empty_cache()
    best_mAP = 0.10
    print("training model...")
    opt.idf_dict["<sos>"] = 1
    opt.idf_dict["<eos>"] = 1
    
    #class weight indicates the weight of each term in calculating the loss function 
    #we have set it as the idf of the terms. 
    class_weights = []
    for i in range(len(TRG.vocab)):
        if TRG.vocab.itos[i] in opt.idf_dict:
            class_weights.append(opt.idf_dict[TRG.vocab.itos[i]])            
        else:
            class_weights.append(0.0001)
    
    class_weights = torch.FloatTensor(class_weights)
    model.train()
    class_weights = class_weights.cuda()
    
    start = time.time()
    #If we are checkpointing model after a certain period of time, we will use cptime to keep track of time
    if opt.checkpoint > 0:
        cptime = time.time()
        checkpointing_step = 0
        
    #the variable to keep track of global step 
    #print("number of batches: ".format(len(opt.train)))
    
    for epoch in range(opt.epochs):
        print("beginning epoch: {}".format(epoch))
        step = 0
        translation_loss = 0
        embedding_loss = 0
                
        for i, batch in enumerate(opt.train):
            #in each batch we have three inputs src, trg and trg_retrieval. This is the part where it is different 
            #from traditional machine translation. trg_retrieval comes from the retrieval corpus. please refer to the 
            #paper that for each sentence in trg we retrieve from a corpus and add it to the parallel data. 
            torch.cuda.empty_cache()
            start = time.clock()
            src = batch.src.transpose(0,1)
            trg = batch.trg.transpose(0,1)    
            trg_rel = batch.trg_rel.transpose(0,1) 
            #the length of trg is pretty small and that's why we are adding it three times and finally adding it with 
            #trg_rel. trg_rel comes from the retrieval corpus. This is why and how we construct tt. 
            tt = torch.cat((trg[: , :-1], trg[: , :-1], trg[: , :-1], trg_rel), 1).numpy()
            
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)            
            src_mask = src_mask.cuda()            
            trg_mask = trg_mask.cuda()
            src = src.cuda() 
            trg_input = trg_input.cuda()  
            data_context_final = []
            data_target_final = []
            #now creating the training batches for word embedding on the fly. 
            for element in tt: 
                #removing the index of some unnecessay tokens
                indices = [i for i, x in enumerate(element) if x == 0 or x == 1 or x == 2 or x ==3]
                element = np.delete(element, indices)
                np.random.shuffle(element)
                corpus_text = torch.tensor(element)
                #creating word embedding data. The pivot will be at index i. 
                for i in range(CONTEXT_SIZE, len(corpus_text) - CONTEXT_SIZE):
                    #at first create the context 
                    data_context = []
                    data_target = []
                    for j in range(CONTEXT_SIZE):
                        data_context.append(corpus_text[i - CONTEXT_SIZE + j])
                    for j in range(1, CONTEXT_SIZE + 1):
                        data_context.append(corpus_text[i + j])
                    #now create the pivot or the target 
                    data_target.append(corpus_text[i])
                    #add context and pivot to create a batch
                    data_context_final.append(torch.LongTensor(data_context))
                    data_target_final.append(torch.LongTensor(data_target))
                    
            #Now we are training word embedding. 
            offset = 1000
            number_of_batches = int(len(data_context_final) / offset)
            index = 0
            loss_wemb = 0
            for batch in range(number_of_batches):
                opt.optimizer_wemb.zero_grad()
                #creating a batch of 1000 context, target pairs 
                data_context_final_temp = torch.stack(data_context_final[index:index+offset]).cuda()
                data_target_final_temp = torch.stack(data_target_final[index:index+offset]).cuda()            
                preds_emb = model(src, trg_input, src_mask, trg_mask, data_context_final_temp)
                #computing word embedding loss for 1000 data points for the word embedding task  
                loss_wemb_temp = F.cross_entropy(preds_emb.view(-1, preds_emb.size(-1)), data_target_final_temp.contiguous().view(-1).cuda())
                #make word embedding loss less impactful 
                loss_wemb_temp/= 10
                loss_wemb+=loss_wemb_temp.item()
                index+=1
                loss_wemb_temp.backward()
                opt.optimizer_wemb.step()
                
            #Now using the model to translate src to trg_input. 
            preds = model(src, trg_input, src_mask, trg_mask, 2)
            #The original translations 
            ys = trg[:, 1:].contiguous().view(-1).cuda()
            opt.optimizer.zero_grad()
            #computing cross-entropy loss for the translation task 
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, weight=class_weights, ignore_index=opt.trg_pad)              
            loss.backward()

            translation_loss+= loss.item()
            embedding_loss+= loss_wemb/index

            opt.optimizer.step()
            
            if opt.SGDR == True: 
                opt.sched.step()
            
            print(str(step) + "\tstep loss nmt\t" + str(loss.item()) + "\tstep loss embedding\t" + str(loss_wemb/index))                
            step+=1
                
            #We are saving the model every five minutes. IR loss is costly to evaluate, hence we use this mechanism.
            if opt.checkpoint > 0 and ((time.time()-cptime)//60) // opt.checkpoint >= 1:
                #saving the model in a directory 
                torch.save(model.state_dict(), 'weights/model_weights')                            
                #loading the model to evaluate on IR task. 
                p = subprocess.Popen('CUDA_VISIBLE_DEVICES=3 python translate_validation.py', shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
                line = p.stdout.readlines()[0]
                #reading the evaluation from the IR task. specifically focusing on mAP. 
                mAP = float(line.decode("utf-8").strip().split()[1])
                print(str(checkpointing_step) + "\t"  + str(mAP) + "\tstep loss nmt\t" + str(translation_loss/step) + "\tstep loss embedding\t" + str(embedding_loss/step) + "\tcheckpoint info")                
                #if we have a better map compared to the current map, we save the model. 
                if mAP > best_mAP:
                    best_mAP = mAP
                    #The best model would be saved in weights directory if we get a better map on the validation set
                    torch.save(model.state_dict(), 'weights/model_weights_best_validation')
                cptime = time.time()
                info = { 'loss': loss.item(), 'map': mAP}
                for tag, value in info.items():
                    logger.scalar_summary(tag, value, checkpointing_step+1)
                checkpointing_step+=1
            
        print(str(opt.epoch) + "\tepoch loss nmt\t" + str(translation_loss/step) + "\tepoch loss embedding\t" + str(embedding_loss/step) + "\tepoch info")                
                
        #Finally saving a model after an epoch. 
        torch.save(model.state_dict(), 'weights/model_weights_' + str(epoch))

In [None]:
#torch.cuda.empty_cache()
model = model.cuda()
train_model(model, opt, SRC, TRG)

training model...
beginning epoch: 0
0	step loss nmt	3.9221818447113037	step loss embedding	1.2294775992631912
1	step loss nmt	16.530879974365234	step loss embedding	1.2403833184923445
2	step loss nmt	31.954986572265625	step loss embedding	1.2416691184043884
3	step loss nmt	25.41067123413086	step loss embedding	1.2426793575286865
4	step loss nmt	18.872697830200195	step loss embedding	1.2458011209964752
5	step loss nmt	18.58994483947754	step loss embedding	1.2513601660728455
6	step loss nmt	21.455183029174805	step loss embedding	1.2514421790838242
7	step loss nmt	20.14837646484375	step loss embedding	1.245383381843567
8	step loss nmt	19.965816497802734	step loss embedding	1.2517280578613281
9	step loss nmt	18.539390563964844	step loss embedding	1.2480195210530207
10	step loss nmt	17.99416160583496	step loss embedding	1.2518243619373866
11	step loss nmt	18.297880172729492	step loss embedding	1.248568868637085
12	step loss nmt	17.861736297607422	step loss embedding	1.2507413983345033
13	s

108	step loss nmt	9.96865177154541	step loss embedding	1.3064419407593577
109	step loss nmt	9.93358039855957	step loss embedding	1.3000219662984211
110	step loss nmt	10.18037223815918	step loss embedding	1.3111427545547485
111	step loss nmt	10.347564697265625	step loss embedding	1.3358052521944046
112	step loss nmt	10.064628601074219	step loss embedding	1.2856353861945016
113	step loss nmt	9.831003189086914	step loss embedding	1.2958697199821472
114	step loss nmt	10.234359741210938	step loss embedding	1.2918894290924072
115	step loss nmt	9.895061492919922	step loss embedding	1.2836664219697316
116	step loss nmt	11.070008277893066	step loss embedding	1.3028392493724823
117	step loss nmt	9.843363761901855	step loss embedding	1.3119183381398518
118	step loss nmt	9.878853797912598	step loss embedding	1.3277562686375208
119	step loss nmt	9.655867576599121	step loss embedding	1.3057878255844115
120	step loss nmt	10.056974411010742	step loss embedding	1.2892587184906006
121	step loss nmt	9.59

216	step loss nmt	9.24958324432373	step loss embedding	1.2648679993369363
217	step loss nmt	9.236075401306152	step loss embedding	1.2664123406777015
218	step loss nmt	9.618104934692383	step loss embedding	1.3248342672983806
219	step loss nmt	8.247587203979492	step loss embedding	1.3430165290832519
220	step loss nmt	8.887136459350586	step loss embedding	1.2844339098249162
221	step loss nmt	9.481388092041016	step loss embedding	1.269804798066616
222	step loss nmt	9.146309852600098	step loss embedding	1.2805055873147373
223	step loss nmt	9.503437995910645	step loss embedding	1.2905791997909546
224	step loss nmt	8.901814460754395	step loss embedding	1.2569333579805162
225	step loss nmt	9.160184860229492	step loss embedding	1.2715411828114436
226	step loss nmt	9.30960464477539	step loss embedding	1.3162663221359252
227	step loss nmt	8.710799217224121	step loss embedding	1.3486074341668024
228	step loss nmt	9.60727596282959	step loss embedding	1.2651828697749548
229	step loss nmt	8.831098556

324	step loss nmt	8.772494316101074	step loss embedding	1.3091284831364949
325	step loss nmt	8.87333869934082	step loss embedding	1.3382326066493988
326	step loss nmt	8.963475227355957	step loss embedding	1.3123536556959152
327	step loss nmt	8.73640251159668	step loss embedding	1.3226934969425201
328	step loss nmt	8.621338844299316	step loss embedding	1.297798901796341
329	step loss nmt	8.557950973510742	step loss embedding	1.2968035650253296
330	step loss nmt	8.673569679260254	step loss embedding	1.3587224781513214
331	step loss nmt	8.526082038879395	step loss embedding	1.2585987011591593
332	step loss nmt	8.859979629516602	step loss embedding	1.3613766431808472
333	step loss nmt	8.682827949523926	step loss embedding	1.3200665354728698
334	step loss nmt	8.684432029724121	step loss embedding	1.2904251151614718
335	step loss nmt	8.741538047790527	step loss embedding	1.265244537386401
336	step loss nmt	8.671205520629883	step loss embedding	1.2889096112478347
337	step loss nmt	8.894404411

432	step loss nmt	8.770748138427734	step loss embedding	1.3126161098480225
433	step loss nmt	8.430312156677246	step loss embedding	1.3369815826416016
434	step loss nmt	8.261602401733398	step loss embedding	1.3307399253050487
435	step loss nmt	8.231709480285645	step loss embedding	1.4090728878974914
436	step loss nmt	8.438840866088867	step loss embedding	1.250011682510376
437	step loss nmt	8.345796585083008	step loss embedding	1.3035880785721998
438	step loss nmt	8.45858097076416	step loss embedding	1.3122703954577446
439	step loss nmt	8.265209197998047	step loss embedding	1.423964023590088
440	step loss nmt	8.21714973449707	step loss embedding	1.2701641504581158
441	step loss nmt	8.561320304870605	step loss embedding	1.2910556684840808
442	step loss nmt	8.241265296936035	step loss embedding	1.2703839865597812
443	step loss nmt	8.433807373046875	step loss embedding	1.324491646554735
444	step loss nmt	8.524564743041992	step loss embedding	1.3784843921661376
445	step loss nmt	8.2643108367

541	step loss nmt	8.393452644348145	step loss embedding	1.3931622877717018
542	step loss nmt	8.167985916137695	step loss embedding	1.444692571957906
543	step loss nmt	8.561407089233398	step loss embedding	1.531277970834212
9	0.002	step loss nmt	9.750049327664515	step loss embedding	1.3090200010363835	checkpoint info
544	step loss nmt	8.262042045593262	step loss embedding	1.3979695558547973
545	step loss nmt	8.148853302001953	step loss embedding	1.390040103594462
546	step loss nmt	7.8985443115234375	step loss embedding	1.4012769209711176
547	step loss nmt	8.14887809753418	step loss embedding	1.370961274419512
548	step loss nmt	7.928107261657715	step loss embedding	1.3848712359155928
549	step loss nmt	8.02210807800293	step loss embedding	1.4420182779431343
550	step loss nmt	8.095084190368652	step loss embedding	1.406139850616455
551	step loss nmt	8.234695434570312	step loss embedding	1.4045137564341228
552	step loss nmt	8.328733444213867	step loss embedding	1.3476951718330383
553	step lo

647	step loss nmt	8.279932022094727	step loss embedding	1.4129955768585205
648	step loss nmt	8.536548614501953	step loss embedding	1.404591312011083
649	step loss nmt	8.188941955566406	step loss embedding	1.4586079359054565
650	step loss nmt	8.270215034484863	step loss embedding	1.4213817516962688
651	step loss nmt	7.8805317878723145	step loss embedding	1.3995669340265209
652	step loss nmt	8.029702186584473	step loss embedding	1.4208976200648717
653	step loss nmt	8.534364700317383	step loss embedding	1.4209596102054303
654	step loss nmt	7.982872486114502	step loss embedding	1.4565027554829915
655	step loss nmt	8.001521110534668	step loss embedding	1.3666782287450938
656	step loss nmt	8.258076667785645	step loss embedding	1.4237260073423386
657	step loss nmt	8.357898712158203	step loss embedding	1.4119636160986764
658	step loss nmt	8.17592716217041	step loss embedding	1.4295515843800135
659	step loss nmt	8.415766716003418	step loss embedding	1.4446991582711537
660	step loss nmt	8.289507

755	step loss nmt	8.15716552734375	step loss embedding	1.440795622373882
756	step loss nmt	7.933178424835205	step loss embedding	1.398111460160236
757	step loss nmt	8.295804977416992	step loss embedding	1.41898708542188
758	step loss nmt	8.32403564453125	step loss embedding	1.5023135571252733
759	step loss nmt	8.2039794921875	step loss embedding	1.357763673948205
760	step loss nmt	8.401471138000488	step loss embedding	1.4792663156986237
761	step loss nmt	8.393186569213867	step loss embedding	1.484387137673118
762	step loss nmt	8.499458312988281	step loss embedding	1.48756742477417
763	step loss nmt	8.359170913696289	step loss embedding	1.4869344711303711
764	step loss nmt	8.779352188110352	step loss embedding	1.6129263043403625
765	step loss nmt	8.13813304901123	step loss embedding	1.451699435710907
766	step loss nmt	8.185457229614258	step loss embedding	1.4456054599661576
767	step loss nmt	8.364718437194824	step loss embedding	1.376005719689762
768	step loss nmt	8.336222648620605	step

863	step loss nmt	8.018876075744629	step loss embedding	1.4495269702031062
864	step loss nmt	8.227340698242188	step loss embedding	1.4875139756636186
865	step loss nmt	8.411810874938965	step loss embedding	1.5504037290811539
866	step loss nmt	8.322837829589844	step loss embedding	1.4217397616459773
867	step loss nmt	8.955352783203125	step loss embedding	1.5021270513534546
868	step loss nmt	8.30423355102539	step loss embedding	1.5300192013382912
869	step loss nmt	7.984251022338867	step loss embedding	1.4658530511354144
870	step loss nmt	7.719394207000732	step loss embedding	1.5263317482812064
871	step loss nmt	8.383694648742676	step loss embedding	1.4675649404525757
872	step loss nmt	8.209501266479492	step loss embedding	1.599229633808136
873	step loss nmt	7.9609694480896	step loss embedding	1.4697182814280192
874	step loss nmt	7.967780113220215	step loss embedding	1.4449505439171424
875	step loss nmt	7.956146717071533	step loss embedding	1.4614281555016835
876	step loss nmt	7.859156608

971	step loss nmt	8.064481735229492	step loss embedding	1.5069416066010792
972	step loss nmt	8.153314590454102	step loss embedding	1.5496990463950417
973	step loss nmt	7.9908528327941895	step loss embedding	1.4924252911617881
974	step loss nmt	8.25925350189209	step loss embedding	1.5176180998484294
975	step loss nmt	8.119993209838867	step loss embedding	1.6175112837836856
976	step loss nmt	8.297161102294922	step loss embedding	1.5576132237911224
977	step loss nmt	7.892335891723633	step loss embedding	1.4957052624743918
978	step loss nmt	8.235607147216797	step loss embedding	1.5355372862382368
979	step loss nmt	7.770693302154541	step loss embedding	1.40798964848121
980	step loss nmt	8.135642051696777	step loss embedding	1.5195305517741613
981	step loss nmt	8.662096977233887	step loss embedding	1.495648219035222
982	step loss nmt	8.159808158874512	step loss embedding	1.6792666117350261
983	step loss nmt	8.30226993560791	step loss embedding	1.5704747200012208
984	step loss nmt	8.362898826

1078	step loss nmt	7.936362266540527	step loss embedding	1.5400889244946567
1079	step loss nmt	8.271265983581543	step loss embedding	1.4987137129432277
1080	step loss nmt	8.495806694030762	step loss embedding	1.6084717412789662
1081	step loss nmt	7.804699897766113	step loss embedding	1.7262080343146073
1082	step loss nmt	8.030440330505371	step loss embedding	1.5561238960786299
1083	step loss nmt	7.977550506591797	step loss embedding	1.5761230973636402
1084	step loss nmt	8.187849044799805	step loss embedding	1.5184214735031127
1085	step loss nmt	8.143819808959961	step loss embedding	1.5467039129950784
1086	step loss nmt	8.174529075622559	step loss embedding	1.6753055175145468
1087	step loss nmt	8.589444160461426	step loss embedding	1.542361332820012
1088	step loss nmt	8.256734848022461	step loss embedding	1.5584170487191942
1089	step loss nmt	7.64529275894165	step loss embedding	1.5064827757222312
1090	step loss nmt	8.034367561340332	step loss embedding	1.4877889251708984
1091	step loss

1184	step loss nmt	7.665289878845215	step loss embedding	1.652049101316012
1185	step loss nmt	8.15241813659668	step loss embedding	1.7701002870287215
1186	step loss nmt	8.179254531860352	step loss embedding	1.6599579794066293
1187	step loss nmt	8.067743301391602	step loss embedding	1.8743577553675725
1188	step loss nmt	8.028818130493164	step loss embedding	1.6269078354040782
1189	step loss nmt	8.215596199035645	step loss embedding	1.670122572353908
1190	step loss nmt	8.50704288482666	step loss embedding	1.5903350909550984
1191	step loss nmt	8.37370491027832	step loss embedding	1.6732035130262375
1192	step loss nmt	7.768993377685547	step loss embedding	1.5547379702329636
1193	step loss nmt	7.7145891189575195	step loss embedding	1.5977103521949367
1194	step loss nmt	8.165642738342285	step loss embedding	1.8077661440922663
1195	step loss nmt	7.9253153800964355	step loss embedding	1.543537667819432
1196	step loss nmt	8.259888648986816	step loss embedding	1.6721430165427071
1197	step loss n

1290	step loss nmt	7.767173767089844	step loss embedding	1.7785019212298923
1291	step loss nmt	7.8231892585754395	step loss embedding	1.6737689971923828
1292	step loss nmt	8.193333625793457	step loss embedding	1.616597511551597
1293	step loss nmt	8.317727088928223	step loss embedding	1.786789337793986
1294	step loss nmt	8.160167694091797	step loss embedding	1.7601914778351784
1295	step loss nmt	8.17238712310791	step loss embedding	1.7887331777148776
1296	step loss nmt	7.875824451446533	step loss embedding	1.6675771951675415
1297	step loss nmt	8.453267097473145	step loss embedding	1.6468824744224548
1298	step loss nmt	7.979579925537109	step loss embedding	1.6916144013404846
1299	step loss nmt	8.052931785583496	step loss embedding	1.7062218691173352
1300	step loss nmt	8.077630996704102	step loss embedding	1.6739073594411213
1301	step loss nmt	7.944411754608154	step loss embedding	1.7655051549275715
1302	step loss nmt	7.7929911613464355	step loss embedding	1.6811213748795646
1303	step los

1396	step loss nmt	7.82072639465332	step loss embedding	1.6990230977535248
1397	step loss nmt	8.542651176452637	step loss embedding	1.9213712871074677
1398	step loss nmt	7.9615654945373535	step loss embedding	1.9526656419038773
1399	step loss nmt	8.072347640991211	step loss embedding	1.6813047130902607
1400	step loss nmt	8.17883586883545	step loss embedding	1.7336328396430383
1401	step loss nmt	7.605926036834717	step loss embedding	1.730335159735246
1402	step loss nmt	8.487552642822266	step loss embedding	1.8692001178860664
1403	step loss nmt	8.281033515930176	step loss embedding	1.8241069167852402
1404	step loss nmt	8.158361434936523	step loss embedding	1.7373887697855632
1405	step loss nmt	7.900831699371338	step loss embedding	1.880316936969757
1406	step loss nmt	8.102339744567871	step loss embedding	1.779285579919815
1407	step loss nmt	8.185669898986816	step loss embedding	1.6874945759773254
1408	step loss nmt	8.146728515625	step loss embedding	1.8303149491548538
1409	step loss nmt	

1502	step loss nmt	7.562668323516846	step loss embedding	1.7666128448077612
1503	step loss nmt	7.98879337310791	step loss embedding	1.7766967795111916
1504	step loss nmt	7.775434494018555	step loss embedding	1.7947142124176025
1505	step loss nmt	8.06036376953125	step loss embedding	1.7888538922582353
1506	step loss nmt	8.2923583984375	step loss embedding	1.7175950780510902
1507	step loss nmt	7.440420627593994	step loss embedding	1.7607073850101895
1508	step loss nmt	8.216815948486328	step loss embedding	1.7841481897566054
1509	step loss nmt	7.673227310180664	step loss embedding	1.7846498250961305
1510	step loss nmt	7.953347682952881	step loss embedding	1.936900602446662
1511	step loss nmt	7.6263203620910645	step loss embedding	1.7721923391024272
1512	step loss nmt	7.966670036315918	step loss embedding	1.8074785619974136
1513	step loss nmt	8.049036026000977	step loss embedding	1.7448932917221733
1514	step loss nmt	7.901585102081299	step loss embedding	1.765007632119315
1515	step loss nm

1608	step loss nmt	7.53828763961792	step loss embedding	1.9781302783800208
1609	step loss nmt	8.182764053344727	step loss embedding	1.9905762275060017
1610	step loss nmt	8.405633926391602	step loss embedding	1.6839773654937744
1611	step loss nmt	8.443568229675293	step loss embedding	1.9595128695170085
1612	step loss nmt	7.461755752563477	step loss embedding	1.8649789293607075
1613	step loss nmt	8.30314826965332	step loss embedding	1.8762660653967607
1614	step loss nmt	7.7354912757873535	step loss embedding	1.7305629551410675
1615	step loss nmt	7.072229385375977	step loss embedding	1.7986708207008166
1616	step loss nmt	7.883539199829102	step loss embedding	1.8793521192338731
30	0.0041	step loss nmt	8.654146939202747	step loss embedding	1.512918952291552	checkpoint info
1617	step loss nmt	8.036482810974121	step loss embedding	1.8787778615951538
1618	step loss nmt	8.141866683959961	step loss embedding	2.0376440286636353
1619	step loss nmt	7.945870399475098	step loss embedding	1.9002162317

1714	step loss nmt	7.866786003112793	step loss embedding	2.241247038046519
1715	step loss nmt	7.8343377113342285	step loss embedding	1.8575426765850611
32	0.0024	step loss nmt	8.611480337061804	step loss embedding	1.5378752864738954	checkpoint info
1716	step loss nmt	8.108039855957031	step loss embedding	1.8016157058569102
1717	step loss nmt	7.9318437576293945	step loss embedding	2.0087896271755823
1718	step loss nmt	7.911164283752441	step loss embedding	1.9647653500239055
1719	step loss nmt	7.6989850997924805	step loss embedding	1.8483910759290059
1720	step loss nmt	8.285387992858887	step loss embedding	1.82847398519516
1721	step loss nmt	8.016563415527344	step loss embedding	2.098351175134832
1722	step loss nmt	7.845556735992432	step loss embedding	1.962427298227946
1723	step loss nmt	7.990872383117676	step loss embedding	1.9445192416508992
1724	step loss nmt	8.416983604431152	step loss embedding	1.9784597605466843
1725	step loss nmt	7.711630344390869	step loss embedding	2.1597232553

1819	step loss nmt	8.084994316101074	step loss embedding	2.074956867429945
1820	step loss nmt	8.500028610229492	step loss embedding	1.964247226715088
1821	step loss nmt	7.8025407791137695	step loss embedding	1.9677204026116266
1822	step loss nmt	7.775354862213135	step loss embedding	1.950178821881612
1823	step loss nmt	8.00097942352295	step loss embedding	1.9250901937484741
1824	step loss nmt	7.711450576782227	step loss embedding	1.8696451425552367
1825	step loss nmt	7.861330509185791	step loss embedding	1.9806032289158215
1826	step loss nmt	7.9644365310668945	step loss embedding	1.9725486993789674
1827	step loss nmt	8.379058837890625	step loss embedding	1.9553267359733582
1828	step loss nmt	8.901688575744629	step loss embedding	1.9988077998161315
1829	step loss nmt	7.9649529457092285	step loss embedding	1.9401243868328275
1830	step loss nmt	8.035148620605469	step loss embedding	2.128335169383458
1831	step loss nmt	7.861604690551758	step loss embedding	2.161456134584215
1832	step loss 

1927	step loss nmt	7.698558330535889	step loss embedding	1.900023102760315
1928	step loss nmt	7.8088297843933105	step loss embedding	1.9404382824897766
1929	step loss nmt	6.843717575073242	step loss embedding	1.925007505294604
1930	step loss nmt	7.776052951812744	step loss embedding	2.043723084709861
1931	step loss nmt	7.599963665008545	step loss embedding	2.0054979854159884
1932	step loss nmt	7.124505519866943	step loss embedding	1.991350921491782
36	0.0005	step loss nmt	8.530068987709713	step loss embedding	1.5912795961172728	checkpoint info
1933	step loss nmt	7.872013092041016	step loss embedding	2.135985881090164
1934	step loss nmt	7.806035041809082	step loss embedding	1.9499870075119867
1935	step loss nmt	7.989771842956543	step loss embedding	2.200360417366028
1936	step loss nmt	8.065712928771973	step loss embedding	2.135364214579264
1937	step loss nmt	7.177825450897217	step loss embedding	2.0491026043891907
1938	step loss nmt	7.685477256774902	step loss embedding	1.96337714791297

2032	step loss nmt	7.46804666519165	step loss embedding	1.9974570751190186
2033	step loss nmt	7.539393424987793	step loss embedding	2.061088032192654
2034	step loss nmt	8.327535629272461	step loss embedding	2.006659197807312
2035	step loss nmt	7.602295875549316	step loss embedding	2.334413014925443
2036	step loss nmt	7.789185047149658	step loss embedding	1.9281796142458916
2037	step loss nmt	7.547398090362549	step loss embedding	2.0157849490642548
2038	step loss nmt	7.94126033782959	step loss embedding	1.9935037331147627
2039	step loss nmt	7.589704513549805	step loss embedding	2.1524586379528046
2040	step loss nmt	7.873867034912109	step loss embedding	2.0133362114429474
2041	step loss nmt	8.021665573120117	step loss embedding	2.0729631980260215
2042	step loss nmt	7.653676986694336	step loss embedding	2.0091195901234946
2043	step loss nmt	7.985166072845459	step loss embedding	2.248969619924372
2044	step loss nmt	7.770963191986084	step loss embedding	1.961798995733261
2045	step loss nmt	

2139	step loss nmt	7.596896171569824	step loss embedding	2.197924256324768
2140	step loss nmt	7.286527156829834	step loss embedding	2.013569434483846
2141	step loss nmt	7.766252040863037	step loss embedding	2.088985482851664
2142	step loss nmt	8.197171211242676	step loss embedding	2.0138111764734443
2143	step loss nmt	7.490690231323242	step loss embedding	2.063269903785304
2144	step loss nmt	7.983999729156494	step loss embedding	2.255767128684304
2145	step loss nmt	8.203299522399902	step loss embedding	2.1230635378095837
2146	step loss nmt	8.938161849975586	step loss embedding	2.088732375038995
2147	step loss nmt	8.079096794128418	step loss embedding	2.1108581318574795
2148	step loss nmt	7.821846961975098	step loss embedding	2.1341744899749755
2149	step loss nmt	7.985294818878174	step loss embedding	2.232776509390937
2150	step loss nmt	7.5702948570251465	step loss embedding	2.0736790895462036
2151	step loss nmt	8.000778198242188	step loss embedding	2.194398456149631
2152	step loss nmt	

2246	step loss nmt	7.895325660705566	step loss embedding	2.1716105143229165
2247	step loss nmt	7.7174482345581055	step loss embedding	2.1816953145540676
2248	step loss nmt	8.110840797424316	step loss embedding	2.275731384754181
2249	step loss nmt	7.823238372802734	step loss embedding	2.1154610216617584
2250	step loss nmt	8.021939277648926	step loss embedding	2.083923101425171
2251	step loss nmt	7.833564758300781	step loss embedding	2.1811193943023683
2252	step loss nmt	8.507272720336914	step loss embedding	2.3047703829678623
2253	step loss nmt	8.016098022460938	step loss embedding	2.336564336504255
2254	step loss nmt	8.125518798828125	step loss embedding	2.1881739876487036
2255	step loss nmt	7.952013969421387	step loss embedding	2.065101639429728
2256	step loss nmt	8.4173583984375	step loss embedding	2.234973464693342
2257	step loss nmt	7.625447750091553	step loss embedding	2.185197591781616
2258	step loss nmt	7.747110366821289	step loss embedding	2.239203748703003
2259	step loss nmt	7

2352	step loss nmt	7.4588494300842285	step loss embedding	2.204971262386867
2353	step loss nmt	8.047195434570312	step loss embedding	2.19264946381251
2354	step loss nmt	7.496706962585449	step loss embedding	2.2004595597585044
2355	step loss nmt	5.800051212310791	step loss embedding	2.200084238052368
2356	step loss nmt	7.332099437713623	step loss embedding	2.17044338313016
2357	step loss nmt	8.180510520935059	step loss embedding	2.3041759729385376
2358	step loss nmt	7.879012584686279	step loss embedding	2.1348236448624553
2359	step loss nmt	7.746572017669678	step loss embedding	2.0508896020742564
2360	step loss nmt	8.202258110046387	step loss embedding	2.0793644189834595
2361	step loss nmt	7.7867584228515625	step loss embedding	2.170742188181196
2362	step loss nmt	7.579340934753418	step loss embedding	2.0058558305104572
2363	step loss nmt	7.764938831329346	step loss embedding	2.1259970029195148
2364	step loss nmt	7.571615695953369	step loss embedding	2.2089510122934977
2365	step loss nm

2458	step loss nmt	7.255510330200195	step loss embedding	2.316431482632955
2459	step loss nmt	7.6133036613464355	step loss embedding	2.4622477144002914
2460	step loss nmt	7.717666149139404	step loss embedding	2.178651054700216
2461	step loss nmt	7.9741339683532715	step loss embedding	2.196063533425331
2462	step loss nmt	8.122941017150879	step loss embedding	2.341081460316976
2463	step loss nmt	7.794125080108643	step loss embedding	2.2291197299957277
2464	step loss nmt	7.699281692504883	step loss embedding	2.31863109767437
2465	step loss nmt	8.013415336608887	step loss embedding	2.2258122126261393
2466	step loss nmt	7.613790512084961	step loss embedding	2.209941118955612
2467	step loss nmt	7.749589443206787	step loss embedding	2.4468262412331323
2468	step loss nmt	7.301813125610352	step loss embedding	2.342943640316234
2469	step loss nmt	7.729976177215576	step loss embedding	2.2120742052793503
2470	step loss nmt	8.10204029083252	step loss embedding	2.207317900657654
2471	step loss nmt	7

2564	step loss nmt	8.714007377624512	step loss embedding	2.1982717911402383
2565	step loss nmt	8.331901550292969	step loss embedding	2.2765616178512573
2566	step loss nmt	7.625314712524414	step loss embedding	2.0808005954908286
2567	step loss nmt	7.803764820098877	step loss embedding	2.1872385342915854
2568	step loss nmt	7.675093173980713	step loss embedding	2.2354053258895874
2569	step loss nmt	7.386745452880859	step loss embedding	2.2100587541406806
2570	step loss nmt	7.680424690246582	step loss embedding	2.269737958908081
2571	step loss nmt	7.521056175231934	step loss embedding	2.2602350314458213
2572	step loss nmt	7.922428607940674	step loss embedding	2.348035156726837
2573	step loss nmt	8.061807632446289	step loss embedding	2.332236797913261
2574	step loss nmt	7.83349084854126	step loss embedding	2.1914816796779633
2575	step loss nmt	7.858027458190918	step loss embedding	2.240982549531119
2576	step loss nmt	7.8176727294921875	step loss embedding	2.2862598385129655
2577	step loss n

In [None]:
torch.save(model.state_dict(), f'{load_vocab}/model_weights')

In [None]:
print(SRC.vocab.itos[4880])