In [1]:
import torch
from torch import optim
from functools import partial
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import sys
import os
import torch.nn as nn

In [2]:
# Helper-module directory, resolved relative to the notebook's working directory.
path_to_helper_files = os.path.join(os.pardir, 'py_files')

In [3]:
# Make the helper modules imported in the next cell resolvable. The guard keeps
# this cell idempotent: re-running it no longer appends duplicate sys.path entries.
if path_to_helper_files not in sys.path:
    sys.path.append(path_to_helper_files)

In [4]:
import global_variables
import dataset_helper
import nnet_models
import train_utilities

# Shared target device for `.to(device)` calls below (presumably a torch.device defined in global_variables).
device = global_variables.device

In [5]:
# --- Reproducibility ---
# Seed torch's global RNG (CPU and all CUDA devices) before any dataset,
# DataLoader shuffling, or weight initialisation happens, so initialisation and
# shuffling repeat across Restart-&-Run-All. cuDNN RNN kernels may still be
# nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

# --- Batching ---
MAX_LEN = 48      # passed as MAX_LEN to dataset_helper.vocab_collate_func for training batches
batchSize = 128   # training batch size

# --- Encoder (source = Vietnamese) ---
source_embed_dim = 512
source_hidden_size = 512
source_rnn_layers = 2
source_rnn_type  = 'lstm'

# --- Decoder (target = English) ---
target_embed_dim= 512
# NOTE(review): 2x the encoder hidden size - presumably to match a bidirectional
# encoder's concatenated state; confirm in nnet_models.
target_hidden_size = 1024
target_rnn_layers = 2

attention = True  # forwarded to AttentionDecoderRNN(attention=...)


In [6]:
# Vietnamese-English parallel text (`.tok` files); the dev split serves as validation.
_IWSLT_VI_EN_DIR = '../Data/iwslt-vi-en/'

vi_train_path = _IWSLT_VI_EN_DIR + 'train.tok.vi'
en_train_path = _IWSLT_VI_EN_DIR + 'train.tok.en'

vi_val_path = _IWSLT_VI_EN_DIR + 'dev.tok.vi'
en_val_path = _IWSLT_VI_EN_DIR + 'dev.tok.en'

In [7]:
# Directory handed to LanguagePair as lang_obj_path (relative to the notebook's working directory).
saved_language_model_dir = os.path.join(os.pardir, 'lang_obj')

In [8]:
# Vietnamese -> English sentence pairs. Both splits share the same
# language-object directory. The dev split is flagged val=True, presumably so it
# reuses the training vocabulary instead of building its own; confirm in
# dataset_helper.LanguagePair.
_vi_en_common = dict(source_name='vi', target_name='en',
                     lang_obj_path=saved_language_model_dir)

dataset_dict = {
    'train': dataset_helper.LanguagePair(source_path=vi_train_path,
                                         target_path=en_train_path,
                                         **_vi_en_common),
    'val': dataset_helper.LanguagePair(source_path=vi_val_path,
                                       target_path=en_val_path,
                                       val=True,
                                       **_vi_en_common),
}

In [9]:
# Training batches are reshuffled every epoch and collated by
# dataset_helper.vocab_collate_func (MAX_LEN bound via functools.partial).
# Validation runs one sentence per batch through vocab_collate_func_val and is
# NOT shuffled:
#   - evaluation order does not need randomising;
#   - a shuffled loader without its own generator draws a fresh seed from
#     torch's global RNG on every pass, which perturbs the training-side
#     random stream and hurts reproducibility.
# NOTE(review): assumes train_model's val loss / BLEU are order-independent
# aggregates over the whole split - confirm in train_utilities.
dataloader_dict = {
    'train': DataLoader(dataset_dict['train'], batch_size=batchSize,
                        collate_fn=partial(dataset_helper.vocab_collate_func, MAX_LEN=MAX_LEN),
                        shuffle=True, num_workers=0),
    'val': DataLoader(dataset_dict['val'], batch_size=1,
                      collate_fn=dataset_helper.vocab_collate_func_val,
                      shuffle=False, num_workers=0),
}

In [10]:
# Source-side (Vietnamese) RNN encoder. The vocabulary size comes from the
# training split's source language object.
encoder = nnet_models.EncoderRNN(
    dataset_dict['train'].source_lang_obj.n_words,
    embed_dim=source_embed_dim,
    hidden_size=source_hidden_size,
    rnn_layers=source_rnn_layers,
    rnn_type=source_rnn_type,
)
encoder = encoder.to(device)
                                 

In [11]:
# Target-side (English) decoder. The vocabulary size comes from the training
# split's target language object; `attention` is the flag from the config cell.
decoder = nnet_models.AttentionDecoderRNN(
    dataset_dict['train'].target_lang_obj.n_words,
    embed_dim=target_embed_dim,
    hidden_size=target_hidden_size,
    n_layers=target_rnn_layers,
    attention=attention,
)
decoder = decoder.to(device)

In [12]:
# encoder_optimizer = optim.Adam(encoder.parameters(), lr = 7e-5)
# decoder_optimizer = optim.Adam(decoder.parameters(), lr = 7e-5)

# enc_scheduler = ReduceLROnPlateau(encoder_optimizer, min_lr=1e-5,factor = 0.5,  patience=0)

# dec_scheduler = ReduceLROnPlateau(decoder_optimizer, min_lr=1e-5,factor = 0.5,  patience=0)

In [13]:
def _sgd_with_plateau(params):
    """Build a Nesterov-SGD optimizer and its ReduceLROnPlateau scheduler.

    The scheduler keeps the PyTorch default reduction factor (0.1) and cuts
    the LR after a single non-improving step (patience=0), never below 1e-4.
    """
    opt = optim.SGD(params, lr=0.25, momentum=0.99, nesterov=True)
    return opt, ReduceLROnPlateau(opt, min_lr=1e-4, patience=0)


encoder_optimizer, enc_scheduler = _sgd_with_plateau(encoder.parameters())
decoder_optimizer, dec_scheduler = _sgd_with_plateau(decoder.parameters())

In [14]:
# Negative log-likelihood over decoder outputs (presumably log-probabilities
# from a log_softmax; confirm in nnet_models). Targets equal to PAD_IDX are
# ignored and contribute no gradient.
criterion = nn.NLLLoss(
    ignore_index=global_variables.PAD_IDX,
)

In [15]:
# Stage 1: train encoder + decoder with the Nesterov-SGD optimizers and
# schedulers built above.
# NOTE(review): num_epochs=10, but the captured log stops after epoch 5. Either
# train_model stops early or the run was interrupted; confirm.
# NOTE(review): whether `enc`/`dec` are the same objects as `encoder`/`decoder`
# or separate (e.g. best-checkpoint) copies is not visible here; confirm in
# train_utilities.train_model. Meaning of `rm` is likewise not visible here.
enc, dec, loss_hist, acc_hist = train_utilities.train_model(encoder_optimizer, decoder_optimizer, 
                                            encoder, decoder, criterion,
                                            "attention", dataloader_dict, dataset_dict['train'].target_lang_obj, 
                                            num_epochs = 10, rm = 0.95,
                                            enc_scheduler = enc_scheduler, dec_scheduler = dec_scheduler)
# Keep stage-1 histories under their own names. The stage-2 cell rebinds
# loss_hist/acc_hist, which would otherwise discard them.
sgd_loss_hist, sgd_acc_hist = loss_hist, acc_hist

epoch 0 train loss = 5.581000069238563, time = 1769.0952169895172
epoch 0 val loss = 6.343554078380956, time = 72.39613056182861
validation BLEU =  4.951975678850327
epoch 1 train loss = 4.009828028602194, time = 1768.7664515972137
epoch 1 val loss = 5.943167093603605, time = 72.17258620262146
validation BLEU =  11.239757751534276
epoch 2 train loss = 3.339941236507774, time = 1767.2253789901733
epoch 2 val loss = 5.616295772084632, time = 71.92485451698303
validation BLEU =  15.861213783504821
epoch 3 train loss = 2.9764972614213963, time = 1767.8095848560333
epoch 3 val loss = 5.558829465146508, time = 71.71112585067749
validation BLEU =  17.68784112220196
epoch 4 train loss = 2.740610705207905, time = 1768.2375764846802
epoch 4 val loss = 5.508859243808015, time = 71.44663572311401
validation BLEU =  19.05345164055228
epoch 5 train loss = 2.608559530072781, time = 1764.1676487922668
epoch 5 val loss = 5.379563514583918, time = 71.13837504386902
validation BLEU =  18.883192254789833


In [16]:
# Checkpoint the stage-1 (SGD) encoder weights returned by train_model.
# NOTE(review): the "_10" suffix suggests 10 epochs, but the captured stage-1
# log stops after epoch 5; confirm before relying on the filename.
torch.save(obj=enc.state_dict(), f='encoder_vi_to_eng_sgd_10.pth')

In [17]:
# Checkpoint the stage-1 (SGD) decoder weights returned by train_model.
# NOTE(review): "_10" suffix vs. 6 logged epochs; see the encoder checkpoint.
torch.save(obj=dec.state_dict(), f='decoder_vi_to_eng_sgd_10.pth')

In [18]:
def _adam_with_plateau(params):
    """Build an Adam optimizer (lr=3e-4) and a ReduceLROnPlateau scheduler.

    The scheduler halves the LR after a single non-improving step
    (factor=0.5, patience=0), never going below 1e-5.
    """
    opt = optim.Adam(params, lr=3e-4)
    return opt, ReduceLROnPlateau(opt, min_lr=1e-5, factor=0.5, patience=0)


# Fresh optimizer state for stage 2; rebinds the stage-1 optimizer/scheduler names.
encoder_optimizer, enc_scheduler = _adam_with_plateau(encoder.parameters())
decoder_optimizer, dec_scheduler = _adam_with_plateau(decoder.parameters())

In [19]:
# Stage 2: train the same `encoder`/`decoder` objects that were passed to
# stage 1, now with the Adam optimizers/schedulers built above. This rebinds
# enc, dec, loss_hist and acc_hist.
_target_lang = dataset_dict['train'].target_lang_obj
enc, dec, loss_hist, acc_hist = train_utilities.train_model(
    encoder_optimizer, decoder_optimizer,
    encoder, decoder, criterion,
    "attention", dataloader_dict, _target_lang,
    num_epochs=10, rm=0.95,
    enc_scheduler=enc_scheduler, dec_scheduler=dec_scheduler,
)

epoch 0 train loss = 2.15302320082086, time = 1762.218641281128
epoch 0 val loss = 5.616881485766958, time = 71.21776175498962
validation BLEU =  24.26047395236699
epoch 1 train loss = 1.9873490552324429, time = 1763.511214017868
epoch 1 val loss = 5.57612200590491, time = 71.29489159584045
validation BLEU =  22.75959101665562
epoch 2 train loss = 1.8730450260824139, time = 1760.0679602622986
epoch 2 val loss = 5.7860374703815305, time = 71.2922670841217
validation BLEU =  24.50001153354797
epoch 3 train loss = 1.6666425151214643, time = 1763.1278405189514
epoch 3 val loss = 5.753166545247101, time = 71.59293985366821
validation BLEU =  24.27148288802906
epoch 4 train loss = 1.604448801665523, time = 1763.908590555191
epoch 4 val loss = 5.701294795803914, time = 71.21982789039612
validation BLEU =  23.548231030750493
epoch 5 train loss = 1.4770591713760801, time = 1764.2957026958466
epoch 5 val loss = 5.801347525379024, time = 71.28975677490234
validation BLEU =  23.742872031083518
epo

In [20]:
# Checkpoint the stage-2 (Adam) encoder and decoder weights returned by train_model.
for _module, _filename in ((enc, 'encoder_vi_to_eng_adam_10.pth'),
                           (dec, 'decoder_vi_to_eng_adam_10.pth')):
    torch.save(_module.state_dict(), _filename)