In [1]:
import argparse
import time
import torch
print(torch.__version__)
from Models import get_model
from Process import *
import torch.nn.functional as F
from Optim import CosineWithRestarts
from Batch import create_masks
import dill as pickle
import time
import math
import numpy as np
import multiprocessing as mp
import random
import string
import sys
import os
import whoosh, glob, time, pickle
import whoosh.fields as wf
from whoosh.qparser import QueryParser
from whoosh import index
import threading
from whoosh import filedb
from whoosh.filedb.filestore import FileStorage
import numpy as np
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
import time
from InMemorySearch import *
from client import *
from pathlib import Path

1.0.0


In [2]:
# Sanity-check the super client first: list the hosts it knows about and run a
# handful of sample queries through distributed query expansion.
query_list = [
    "european crime records",
    "crime records",
    "european crime",
    "european records",
    "crime crimes violent sheriff enforcement re criminals stresak bill strikes",
    "LA times corpus",
    "crime records",
]
super_client = SuperClient()
print(super_client.hosts)
final_result = super_client.query_expansion_distributed(query_list)
print(len(final_result))

['10.141.0.104', '10.141.0.146', '10.141.0.134', '10.141.0.120', '10.141.0.121']
7


In [3]:
# Number CUDA devices by PCI bus ID and make only device 3 visible to this kernel.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=3


In [4]:
# Training configuration expressed as an argparse parser. It is parsed with an
# empty argument list at the end, so inside the notebook every option takes
# its default value.
parser = argparse.ArgumentParser()
# Set to 1 to use the small sample files instead of the full corpora.
small = 0
# Set to 1 to activate relevance-based training (this also selects the larger
# default learning rate below).
relevance_training = 0

if small == 1:
    parser.add_argument('-src_data', type=str, default='data/italian_small.txt')
    parser.add_argument('-trg_data', type=str, default='data/english_small.txt')
    parser.add_argument('-trg_data_retrieval', type=str, default='data/english_retrieval.txt')
else:
    parser.add_argument('-src_data', type=str, default='data/italian.txt')
    parser.add_argument('-trg_data', type=str, default='data/english.txt')
    parser.add_argument('-trg_data_retrieval', type=str, default='data/LATIMESTEXT2.txt')

parser.add_argument('-src_lang', type=str, default='it')
parser.add_argument('-trg_lang', type=str, default='en')
parser.add_argument('-no_cuda', action='store_true')
parser.add_argument('-SGDR', action='store_true')
parser.add_argument('-epochs', type=int, default=1)
parser.add_argument('-d_model', type=int, default=200)
parser.add_argument('-n_layers', type=int, default=6)
parser.add_argument('-heads', type=int, default=8)
# Dropout is a fraction: with type=int any value given on the command line
# (e.g. "-dropout 0.2") would be rejected, so it is parsed as float.
parser.add_argument('-dropout', type=float, default=0.1)
parser.add_argument('-batchsize', type=int, default=1000)
parser.add_argument('-printevery', type=int, default=50)
parser.add_argument('-load_vocab', type=str, default='clir_it_en')

# Resume from the weights/ directory only when a checkpoint file is present there.
my_file = Path("weights/model_weights")
parser.add_argument('-load_weights', type=str,
                    default='weights' if my_file.is_file() else None)
# The learning rate is fractional as well, hence float; relevance training
# uses a larger default step size.
parser.add_argument('-lr', type=float,
                    default=0.01 if relevance_training == 1 else 0.0001)

parser.add_argument('-create_valset', action='store_true')
parser.add_argument('-max_strlen', type=int, default=80)
parser.add_argument('-floyd', action='store_true')
parser.add_argument('-checkpoint', type=int, default=5)

# args=[] keeps argparse away from the notebook kernel's own sys.argv.
opt = parser.parse_args(args=[])


In [5]:
def tokenizer(text):
    """Tokenize ``text`` by splitting it on runs of whitespace."""
    tokens = text.split()
    return tokens
    
def create_fields(opt):
    """Return the source and target fields used to build the dataset.

    When ``opt.load_vocab`` names a directory, the fields are unpickled from
    ``SRC.pkl`` and ``TRG.pkl`` inside it. When it is ``None``, fresh fields
    are created instead so that a vocabulary can be built from the data later
    (previously this case crashed while trying to open ``None/SRC.pkl``).

    Args:
        opt: options object; only ``opt.load_vocab`` is read.

    Returns:
        Tuple ``(SRC, TRG)``.
    """
    print("loading tokenizers...")
    if opt.load_vocab is None:
        TRG = data.Field(lower=True, tokenize=tokenizer, init_token='<sos>', eos_token='<eos>')
        SRC = data.Field(lower=True, tokenize=tokenizer)
        return (SRC, TRG)

    # NOTE: unpickling can execute arbitrary code; only point load_vocab at
    # directories whose contents are trusted.
    with open(f'{opt.load_vocab}/SRC.pkl', 'rb') as src_file:
        SRC = pickle.load(src_file)
    with open(f'{opt.load_vocab}/TRG.pkl', 'rb') as trg_file:
        TRG = pickle.load(trg_file)
    return (SRC, TRG)

#this function will consider both europarl and CLEF
def _frame_to_tabular_dataset(df, csv_path, data_fields):
    """Build a TabularDataset from ``df`` via a scratch CSV that is always removed."""
    try:
        df.to_csv(csv_path, index=False)
        return data.TabularDataset(csv_path, format='csv', fields=data_fields)
    finally:
        # Clean up even when writing or parsing fails, so a failed run does
        # not leave a stale CSV behind for the next one.
        if os.path.exists(csv_path):
            os.remove(csv_path)


def create_dataset(opt, SRC, TRG):
    """Create the training iterator over the parallel source/target data.

    Sentence pairs are kept only when both sides contain fewer than
    ``opt.max_strlen`` spaces. When ``opt.load_vocab`` is ``None`` the
    vocabularies are built here: the source vocabulary from the parallel data,
    the target vocabulary from the parallel data plus the retrieval corpus
    (lines containing more than one space).

    Args:
        opt: options object. Reads ``src_data``, ``trg_data``,
            ``trg_data_retrieval``, ``max_strlen``, ``batchsize``, ``device``
            and ``load_vocab``; sets ``src_pad``, ``trg_pad`` and ``train_len``.
        SRC: source field.
        TRG: target field.

    Returns:
        The training iterator.
    """
    print("creating dataset and iterator... ")

    # NOTE(review): iterating opt.src_data / opt.trg_data is assumed to yield
    # one sentence per item (presumably populated by read_data) — confirm.
    raw_data = {'src': list(opt.src_data), 'trg': list(opt.trg_data)}
    df = pd.DataFrame(raw_data, columns=["src", "trg"])

    short_enough = (df['src'].str.count(' ') < opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
    df = df.loc[short_enough]

    train = _frame_to_tabular_dataset(
        df, './translate_transformer_temp.csv', [('src', SRC), ('trg', TRG)])

    train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn, train=True, shuffle=True)

    if opt.load_vocab is None:
        print("creating dataset for retrieval corpus... ")
        retrieval_df = pd.DataFrame({'trg': list(opt.trg_data_retrieval)}, columns=["trg"])
        retrieval_df = retrieval_df.loc[retrieval_df['trg'].str.count(' ') > 1]
        train_retrieval = _frame_to_tabular_dataset(
            retrieval_df, './translate_transformer_retrieval_temp.csv', [('trg', TRG)])

        print("building vocabulary for both europarl and retrieval corpus")

        SRC.build_vocab(train)
        TRG.build_vocab(train, train_retrieval)

    opt.src_pad = SRC.vocab.stoi['<pad>']
    opt.trg_pad = TRG.vocab.stoi['<pad>']
    opt.train_len = get_len(train_iter)
    return train_iter



In [6]:
# Device index handed to the data iterator: 0 when CUDA is allowed, -1 otherwise.
opt.device = 0 if opt.no_cuda is False else -1
if opt.device == 0:
    assert torch.cuda.is_available()
print(opt.device)
read_data(opt)

SRC, TRG = create_fields(opt)
opt.train = create_dataset(opt, SRC, TRG)
# Directory that receives freshly built vocabulary pickles ('' = working directory).
dst = ''
#TRG = create_retrieval_vocabulary(opt, TRG)
model = get_model(opt, len(SRC.vocab), len(TRG.vocab))

opt.optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)

if opt.SGDR:
    opt.sched = CosineWithRestarts(opt.optimizer, T_max=opt.train_len)

if opt.checkpoint > 0:
    print("model weights will be saved every %d minutes and at end of epoch to directory weights/"%(opt.checkpoint))

if opt.load_vocab is None:
    # os.path.join keeps an empty dst relative to the working directory; the
    # earlier f'{dst}/SRC.pkl' turned it into a path at the filesystem root.
    with open(os.path.join(dst, 'SRC.pkl'), 'wb') as src_file:
        pickle.dump(SRC, src_file)
    with open(os.path.join(dst, 'TRG.pkl'), 'wb') as trg_file:
        pickle.dump(TRG, trg_file)

0
loading tokenizers...
creating dataset and iterator... 


The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.


loading pretrained weights...
model weights will be saved every 5 minutes and at end of epoch to directory weights/


In [7]:
# Module-level client; train_model() reads this global for query expansion.
super_client = SuperClient()

In [8]:
def _progress_line(elapsed_min, epoch, pct, avg_loss):
    """Format the one-line progress bar (20 cells, one cell per 5 percent)."""
    filled = pct // 5
    return "   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % (
        elapsed_min, epoch, '#' * filled, ' ' * (20 - filled), pct, avg_loss)


def train_model(model, opt, SRC, TRG):
    """Train ``model`` for ``opt.epochs`` epochs over the batches in ``opt.train``.

    Besides its arguments the function reads two module-level names:
    ``relevance_training`` (1 adds an embedding-distance loss on top of the
    cross-entropy loss) and ``super_client`` (expands the target strings in
    that mode).

    Args:
        model: network called as ``model(src, trg_input, src_mask, trg_mask)``;
            in relevance mode ``model.decoder.embed`` is used as well.
        opt: options object. Reads ``epochs``, ``train``, ``train_len``,
            ``optimizer``, ``sched`` (when ``SGDR`` is set), ``checkpoint``,
            ``printevery``, ``floyd`` and ``trg_pad``.
        SRC: source field (not used here; kept for the existing call signature).
        TRG: target field; its vocab converts between ids and tokens in
            relevance mode.

    Side effects:
        Saves ``model.state_dict()`` to ``weights/model_weights`` at the start
        of each epoch and every ``opt.checkpoint`` minutes (when > 0), and
        prints progress to stdout.
    """
    torch.cuda.empty_cache()
    # NOTE(review): this instance is never used below; the construction is kept
    # in case it has side effects — confirm and remove if it does not.
    inmem = WhooshInMemorySearch()
    print("training model...")
    model.train()
    # Run all tensors on the device the model already lives on instead of
    # unconditionally calling .cuda().
    device = next(model.parameters()).device
    start = time.time()
    if opt.checkpoint > 0:
        cptime = time.time()
    # Stays NaN until the first full reporting window, so the end-of-epoch
    # summary no longer fails when fewer than opt.printevery batches ran.
    avg_loss = float('nan')

    for epoch in range(opt.epochs):
        total_loss = 0
        if opt.floyd is False:
            print("   %dm: epoch %d [%s]  %d%%  loss = %s" %
                  ((time.time() - start)//60, epoch + 1, ' ' * 20, 0, '...'), end='\r')

        if opt.checkpoint > 0:
            torch.save(model.state_dict(), 'weights/model_weights')

        for i, batch in enumerate(opt.train):
            torch.cuda.empty_cache()
            # The per-batch `start = time.clock()` was removed: it overwrote
            # the wall-clock start used for the elapsed-minutes display, and
            # time.clock no longer exists in Python 3.8+.
            # Swap the first two axes — presumably to (batch, seq); TODO confirm.
            src = batch.src.transpose(0, 1)
            trg = batch.trg.transpose(0, 1)
            if relevance_training == 1:
                trg_strings = [' '.join([TRG.vocab.itos[ind] for ind in ex]) for ex in trg]
                trg_strings_rm = super_client.query_expansion_distributed(trg_strings)

                # NOTE(review): assumes every expanded entry is a sequence of
                # tokens of equal length (LongTensor needs a rectangular
                # list) — verify against the client.
                trg_strings_rm_id = torch.LongTensor(
                    [[TRG.vocab.stoi[token] for token in sentence]
                     for sentence in trg_strings_rm]).to(device)

                # Sum the token embeddings over dim 1: one vector per example.
                embed_trg = model.decoder.embed(trg_strings_rm_id).sum(1)

            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, opt)
            src_mask = src_mask.to(device)
            trg_mask = trg_mask.to(device)
            src = src.to(device)
            trg_input = trg_input.to(device)
            preds = model(src, trg_input, src_mask, trg_mask)

            # Embed the most likely predicted token at every position and sum
            # them, mirroring what was done for the expanded targets.
            if relevance_training == 1:
                _, ix = F.softmax(preds, dim=-1).detach().topk(1)
                preds_token_ids = ix.view(ix.size(0), -1)
                embed_pred = model.decoder.embed(preds_token_ids).sum(1)

            ys = trg[:, 1:].contiguous().view(-1).to(device)

            opt.optimizer.zero_grad()

            if relevance_training == 1:
                loss1 = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
                loss2 = ((embed_trg - embed_pred) ** 2).mean()
                print("batch loss nmt\t" + str(loss1.item()) + "\tbatch loss relevance\t" + str(loss2.item()))
                batch_loss = loss1.item() + loss2.item()
                # One backward pass over the summed loss accumulates the same
                # gradients as two passes without keeping the graph alive.
                (loss1 + loss2).backward()
            else:
                loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
                batch_loss = loss.item()
                loss.backward()

            opt.optimizer.step()

            if opt.SGDR:
                opt.sched.step()
            # batch_loss is defined in both modes; the previous loss1/loss2
            # reference raised NameError when relevance training was off.
            total_loss += batch_loss

            if (i + 1) % opt.printevery == 0:
                p = int(100 * (i + 1) / opt.train_len)
                avg_loss = total_loss / opt.printevery
                line = _progress_line((time.time() - start)//60, epoch + 1, p, avg_loss)
                if opt.floyd is False:
                    # Overwrite the same console line on interactive terminals.
                    print(line, end='\r')
                else:
                    print(line)
                total_loss = 0

            if opt.checkpoint > 0 and ((time.time()-cptime)//60) // opt.checkpoint >= 1:
                torch.save(model.state_dict(), 'weights/model_weights')
                cptime = time.time()

        print("%dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, loss = %.03f" %
              ((time.time() - start)//60, epoch + 1, '#' * 20, '', 100, avg_loss, epoch + 1, avg_loss))


In [9]:
#torch.cuda.empty_cache()
# Switch on relevance-based training for this run.
# NOTE(review): opt.lr and the Adam optimizer were created earlier, while
# relevance_training was still 0, so this run keeps the 0.0001 learning rate
# rather than the 0.01 relevance default — confirm this is intended.
relevance_training = 1
# Move the model to the GPU before training.
model = model.cuda()
train_model(model, opt, SRC, TRG)

training model...
batch loss nmt	3.3910601139068604	batch loss relevance	0.001179350889287889
batch loss nmt	3.0724680423736572	batch loss relevance	0.0008094001677818596
batch loss nmt	4.594939708709717	batch loss relevance	0.0011161633301526308
batch loss nmt	5.234333515167236	batch loss relevance	0.0010094906901940703
batch loss nmt	3.8060078620910645	batch loss relevance	0.0011968103935942054
batch loss nmt	3.247288942337036	batch loss relevance	0.0012827636674046516
batch loss nmt	3.505147933959961	batch loss relevance	0.0011120610870420933
batch loss nmt	4.531473159790039	batch loss relevance	0.0006579987821169198
batch loss nmt	3.927861213684082	batch loss relevance	0.001077174092642963
batch loss nmt	3.511997938156128	batch loss relevance	0.0009162729838863015
batch loss nmt	3.5104730129241943	batch loss relevance	0.0010013490682467818
batch loss nmt	3.3216309547424316	batch loss relevance	0.0011692022671923041
batch loss nmt	3.335780143737793	batch loss relevance	0.00129651650

batch loss nmt	3.77309250831604	batch loss relevance	0.0010114996694028378
batch loss nmt	4.187644004821777	batch loss relevance	0.001288336468860507
batch loss nmt	3.607799768447876	batch loss relevance	0.0012973634293302894
batch loss nmt	3.7129287719726562	batch loss relevance	0.001068801968358457
batch loss nmt	3.99247670173645	batch loss relevance	0.001211467431858182
batch loss nmt	3.52329421043396	batch loss relevance	0.0009779214160516858
batch loss nmt	4.660335063934326	batch loss relevance	0.0010085196699947119
batch loss nmt	3.7223405838012695	batch loss relevance	0.0010498896008357406
batch loss nmt	3.7080397605895996	batch loss relevance	0.0009707325953058898
batch loss nmt	2.552656650543213	batch loss relevance	0.000947442022152245
batch loss nmt	3.2609405517578125	batch loss relevance	0.0011436794884502888
batch loss nmt	3.8715507984161377	batch loss relevance	0.0012464092578738928
batch loss nmt	3.2597594261169434	batch loss relevance	0.001089209457859397
batch loss nmt

batch loss nmt	3.916616439819336	batch loss relevance	0.0010815797140821815
batch loss nmt	2.956608533859253	batch loss relevance	0.001271490822546184
batch loss nmt	3.2485358715057373	batch loss relevance	0.0009092307300306857
batch loss nmt	5.3271636962890625	batch loss relevance	0.0008951313211582601
batch loss nmt	3.921205997467041	batch loss relevance	0.001148649607785046
batch loss nmt	4.175078868865967	batch loss relevance	0.001282929559238255
batch loss nmt	3.473825693130493	batch loss relevance	0.000978829455561936
batch loss nmt	3.271925210952759	batch loss relevance	0.0011589991627261043
batch loss nmt	3.895723819732666	batch loss relevance	0.001215469092130661
batch loss nmt	3.3283004760742188	batch loss relevance	0.0010154631454497576
batch loss nmt	2.3892605304718018	batch loss relevance	0.0007606518338434398
batch loss nmt	3.9296040534973145	batch loss relevance	0.0010820127790793777
batch loss nmt	3.375784397125244	batch loss relevance	0.0009171876590698957
batch loss n

batch loss nmt	3.434131383895874	batch loss relevance	0.001101725036278367
batch loss nmt	3.371295213699341	batch loss relevance	0.0010491220746189356
batch loss nmt	3.0639970302581787	batch loss relevance	0.0007825753418728709
batch loss nmt	3.508134603500366	batch loss relevance	0.0008919449173845351
batch loss nmt	3.9266021251678467	batch loss relevance	0.0010746612679213285
batch loss nmt	3.4207589626312256	batch loss relevance	0.0011934656649827957
batch loss nmt	2.602370500564575	batch loss relevance	0.0007537811761721969
batch loss nmt	3.696828603744507	batch loss relevance	0.0012311340542510152
batch loss nmt	3.4184460639953613	batch loss relevance	0.0011694120476022363
batch loss nmt	4.117323398590088	batch loss relevance	0.0012014230014756322
batch loss nmt	3.443277359008789	batch loss relevance	0.0010439546313136816
batch loss nmt	3.0759878158569336	batch loss relevance	0.0011171289952471852
batch loss nmt	4.273647308349609	batch loss relevance	0.001235686126165092
batch los

batch loss nmt	3.3501813411712646	batch loss relevance	0.0010094150202348828
batch loss nmt	3.655226707458496	batch loss relevance	0.0011737587628886104
batch loss nmt	4.398533821105957	batch loss relevance	0.001104828086681664
batch loss nmt	3.6219351291656494	batch loss relevance	0.0010510438587516546
batch loss nmt	4.000452518463135	batch loss relevance	0.001125519396737218
batch loss nmt	3.450496196746826	batch loss relevance	0.0011298557510599494
batch loss nmt	4.014549732208252	batch loss relevance	0.0009687395649962127
batch loss nmt	3.1225223541259766	batch loss relevance	0.0011117263929918408
batch loss nmt	2.9937846660614014	batch loss relevance	0.0009498040308244526
batch loss nmt	3.500291585922241	batch loss relevance	0.0010741642909124494
batch loss nmt	3.617668628692627	batch loss relevance	0.0008513002539984882
batch loss nmt	3.2518739700317383	batch loss relevance	0.000814013066701591
batch loss nmt	3.4443349838256836	batch loss relevance	0.0008819758077152073
batch los

batch loss nmt	3.2850611209869385	batch loss relevance	0.0008649654919281602
batch loss nmt	4.749572277069092	batch loss relevance	0.00113286345731467
batch loss nmt	3.3607540130615234	batch loss relevance	0.0010453517315909266
batch loss nmt	3.341212511062622	batch loss relevance	0.0011175594991073012
batch loss nmt	3.3006882667541504	batch loss relevance	0.0010221737902611494
batch loss nmt	3.2338407039642334	batch loss relevance	0.0012224962702021003
batch loss nmt	3.5007705688476562	batch loss relevance	0.000978245516307652
batch loss nmt	3.441756248474121	batch loss relevance	0.0010811160318553448
batch loss nmt	3.258087158203125	batch loss relevance	0.0010499297641217709
batch loss nmt	3.938767433166504	batch loss relevance	0.0010282793082296848
batch loss nmt	2.9457530975341797	batch loss relevance	0.0007913286099210382
batch loss nmt	3.592405080795288	batch loss relevance	0.0009995202999562025
batch loss nmt	3.6032285690307617	batch loss relevance	0.001006580307148397
batch los

batch loss nmt	3.8433010578155518	batch loss relevance	0.001036760164424777
batch loss nmt	3.221762180328369	batch loss relevance	0.0010226966114714742
batch loss nmt	4.306225299835205	batch loss relevance	0.0009579447214491665
batch loss nmt	3.714390277862549	batch loss relevance	0.0009655322646722198
batch loss nmt	4.9110307693481445	batch loss relevance	0.0015229651471599936
batch loss nmt	3.4438843727111816	batch loss relevance	0.0010113917523995042
batch loss nmt	3.3890650272369385	batch loss relevance	0.0010376219870522618
batch loss nmt	3.846619129180908	batch loss relevance	0.001236703828908503
batch loss nmt	3.359321117401123	batch loss relevance	0.0011988348560407758
batch loss nmt	3.3098552227020264	batch loss relevance	0.001184906461276114
batch loss nmt	3.334054946899414	batch loss relevance	0.0011431622551754117
batch loss nmt	3.9990713596343994	batch loss relevance	0.0011410469887778163
batch loss nmt	3.4776744842529297	batch loss relevance	0.001178214093670249
batch los

batch loss nmt	4.087344646453857	batch loss relevance	0.0013444494688883424
batch loss nmt	3.4526517391204834	batch loss relevance	0.0011506406590342522
batch loss nmt	3.597912073135376	batch loss relevance	0.0012453872477635741
batch loss nmt	3.5420777797698975	batch loss relevance	0.0009873275412246585
batch loss nmt	3.442775011062622	batch loss relevance	0.001113839796744287
batch loss nmt	3.280423879623413	batch loss relevance	0.0010476326569914818
batch loss nmt	5.552959442138672	batch loss relevance	0.0010595674393698573
batch loss nmt	3.81675124168396	batch loss relevance	0.0008955063531175256
batch loss nmt	3.235121488571167	batch loss relevance	0.0008146936306729913
batch loss nmt	2.7223312854766846	batch loss relevance	0.0007770402007736266
batch loss nmt	3.1471896171569824	batch loss relevance	0.0010273174848407507
batch loss nmt	2.9530186653137207	batch loss relevance	0.0008648690418340266
batch loss nmt	5.730668067932129	batch loss relevance	0.0010244554141536355
batch los

Process Process-8589:
Process Process-8587:
Process Process-8590:
Process Process-8586:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/smsarwar/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/smsarwar/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/smsarwar/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/smsarwar/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/smsarwar/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/smsarwar/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/smsarwar/anaconda3/lib

KeyboardInterrupt: 

In [None]:
# Save the trained weights into the vocabulary directory. The setting lives on
# `opt`; the bare name `load_vocab` is not defined in any cell of this notebook
# and would raise NameError.
torch.save(model.state_dict(), f'{opt.load_vocab}/model_weights')

# x = torch.randn(4,6,1)
# print (x)
# x = x.view(4, -1)
# print (x)
# # print(x)
# # probs, ix = x[:, :].data.topk(1)
# # print(probs)
# # print(ix)

# x = torch.randn(3,2)
# print(x)
# y = torch.randn(3,2)
# print(y)
# print (((x - y)**2).mean())
# #F.cosine_embedding_loss(x, y, reduction='mean')

In [None]:
# Look up which source-vocabulary token is stored at index 4880.
print(SRC.vocab.itos[4880])