In [1]:
import argparse
import pprint

import torch, gc
import torch.nn as nn
from torch import optim

from data_loader import DataLoader
import data_loader
import trainer
import tester
from models.transformer import Transformer
import model_util as mu

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

import timeit

using data from torchtext.legacy


In [2]:
def get_model(input_size, output_size,
              hidden_size=32,
              n_splits=8,
              n_layers=4,
              dropout=0.2,
              use_transformer=True):
    """Construct the Transformer translation model.

    Args:
        input_size: Source vocabulary size.
        output_size: Target vocabulary size.
        hidden_size: Hidden size handed to Transformer (no separate
            word-vector size is passed).
        n_splits: Number of heads in multi-head attention.
        n_layers: Block count, used for both the encoder and the decoder.
        dropout: Dropout rate on each block.
        use_transformer: Accepted for interface compatibility only. A
            Transformer is built whether this is True or False.

    Returns:
        The newly constructed Transformer instance.
    """
    # NOTE(review): `use_transformer` selects nothing here; if an RNN-based
    # seq2seq variant is ever added, this is where the branch belongs.
    return Transformer(
        input_size,             # Source vocabulary size
        hidden_size,            # Transformer doesn't need word_vec_size
        output_size,            # Target vocabulary size
        n_splits=n_splits,      # Number of heads in multi-head attention
        n_enc_blocks=n_layers,  # Number of encoder blocks
        n_dec_blocks=n_layers,  # Number of decoder blocks
        dropout_p=dropout,      # Dropout rate on each block
    )


def get_crit(output_size, pad_index):
    """Build the summed NLL criterion that gives PAD targets zero weight.

    Args:
        output_size: Number of target classes (length of the weight vector).
        pad_index: Class index of the PAD token; its weight is set to 0.

    Returns:
        An nn.NLLLoss with per-class weights and reduction='sum'. NLLLoss
        takes log-probabilities as input, which is why it can stand in for
        cross-entropy here.
    """
    # Every class starts at weight 1; zeroing the PAD entry keeps padded
    # positions from contributing to the loss.
    class_weights = torch.ones(output_size)
    class_weights[pad_index] = 0.0

    print('\n Loss function: Negative Log-Likelihood with log-probability (NLLLoss)')
    return nn.NLLLoss(weight=class_weights, reduction='sum')


def get_optimizer(model,
                  use_adam=True,
                  use_transformer=True,
                  lr=0.0001,):
    """Create the Adam optimizer for `model`.

    Args:
        model: Module whose `parameters()` are optimized.
        use_adam: When False, 'Optimizer: Adam' is printed and Adam with
            betas (.9, .98) is still returned; no other optimizer class is
            ever built here.
        use_transformer: Only consulted when `use_adam` is True. True uses
            betas (.9, .98); False (RNN-based seq2seq) uses Adam's defaults.
        lr: Learning rate.

    Returns:
        A torch.optim.Adam instance.
    """
    transformer_betas = (.9, .98)

    if not use_adam:
        # NOTE(review): this fallback still builds Adam and is the only path
        # that prints the optimizer name. Confirm whether a different
        # optimizer was intended.
        print('Optimizer: Adam')
        return optim.Adam(model.parameters(), lr=lr, betas=transformer_betas)

    if use_transformer:
        return optim.Adam(model.parameters(), lr=lr, betas=transformer_betas)

    # case of rnn based seq2seq: keep Adam's default betas
    return optim.Adam(model.parameters(), lr=lr)

In [3]:
# --- Experiment bookkeeping: combined into the run title and TensorBoard path ---
research_subject = 'local_medium1'
research_num = '01'

# --- Corpus ---
# presumably (source, target) language extensions; passed to DataLoader as `exts`
lang = ('en', 'ko')
train_fn = 'corpus.shuf.train.tok.bpe'
valid_fn = 'corpus.shuf.valid.tok.bpe'
test_fn = 'corpus.shuf.test.tok.bpe'
max_length = 20   # passed to DataLoader; longer sequences are excluded
batch_size = 64

# --- Model ---
hidden_size = 128
n_layers = 4      # used for both encoder and decoder block counts
n_splits = 8      # number of attention heads
dropout = 0.0

# --- Optimization ---
lr = 0.0003
n_epochs = 50

In [4]:
# Later cells read loader.train_iter / valid_iter / test_iter and the
# src / tgt vocabularies from this object.
loader = DataLoader(
    train_fn=train_fn, valid_fn=valid_fn, test_fn=test_fn,
    exts=lang,
    batch_size=batch_size,
    max_length=max_length,  # Longer sequences will be excluded.
    device=-1,              # Lazy loading
    dsl=False,              # Turn off Dual-supervised Learning mode.
)

In [5]:
# Vocabulary sizes become the model's input and output dimensions.
input_size = len(loader.src.vocab)
output_size = len(loader.tgt.vocab)

print('\ninput_size: ', input_size)
print('output_size: ', output_size)


input_size:  69459
output_size:  154233


In [6]:
# Instantiate the network from the configuration cell's hyper-parameters.
model = get_model(
    input_size, output_size,
    hidden_size=hidden_size, n_splits=n_splits, n_layers=n_layers,
    dropout=dropout,
    use_transformer=True,
)

In [7]:
# Loss over the target vocabulary, with the PAD class weighted to zero.
crit = get_crit(output_size=output_size, pad_index=data_loader.PAD)


 Loss function: Negative Log-Likelihood with log-probability (NLLLoss)


In [8]:
# (Disabled) restore previously saved weights before training:
# if model_weight is not None:
#     model.load_state_dict(model_weight)

# Use GPU 0 when CUDA is available; -1 is this notebook's marker for CPU-only.
device_num = 0 if torch.cuda.is_available() else -1
print('\nUsing device number: 0' if device_num == 0 else '\nUsing device number: -1')

# Collect unreferenced Python objects first, then release cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

# A non-negative device number means there is a GPU to move the model and loss onto.
if device_num >= 0:
    model.cuda(device_num)
    crit.cuda(device_num)


Using device number: 0


In [9]:
# use_adam / use_transformer keep their defaults (True), so this is Adam with betas (.9, .98).
optimizer = get_optimizer(model=model, lr=lr)

In [10]:
# (Disabled) restore optimizer state from a checkpoint:
# if opt_weight is not None and config.use_adam:
    # optimizer.load_state_dict(opt_weight)

# No learning-rate scheduler is used in this run. The spelling `lr_schedular`
# matches the keyword passed to trainer.train and tester.test below.
lr_schedular = None

In [11]:
# Run title is "<subject>_<number>"; it is passed to trainer.train and the save/load helpers.
subject_title = research_subject
title = '_'.join((subject_title, research_num))

# NOTE(review): `timestamp` is not used by any visible cell, so every run of a
# subject logs into the same TensorBoard directory. Confirm that is intended.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('../tensorboard/' + subject_title + '/tests')

In [12]:
# Wall-clock reference for timing the whole training run.
start_time = timeit.default_timer()

# Train for n_epochs on the train iterator, with the validation iterator,
# vocabularies, TensorBoard writer and run title passed through to the trainer.
trainer.train(
    model,
    crit,
    optimizer,
    train_loader=loader.train_iter,
    valid_loader=loader.valid_iter,
    src_vocab=loader.src.vocab,
    tgt_vocab=loader.tgt.vocab,
    n_epochs=n_epochs,
    lr_schedular=lr_schedular,  # None in this notebook
    writer=writer,
    title=title,
)

# Despite its name, `end_time` holds the elapsed training time in minutes
# (timer difference in seconds divided by 60).
end_time = (timeit.default_timer() - start_time) / 60.0

Start training...
 Epoch  |  Train Loss  | Train Acc  | Val Loss | Val Acc | Elapsed
--------------------------------------------------------------------------------
   1    |   6.657451   | 39.054159  | 4.462898 | 34.38  | 53.11 
   2    |   3.889955   | 61.082702  | 3.308266 | 40.39  | 53.34 
   3    |   3.057791   | 68.476278  | 2.771558 | 43.05  | 53.33 
   4    |   2.636499   | 72.185069  | 2.464527 | 44.22  | 52.74 
   5    |   2.400443   | 74.336833  | 2.337725 | 45.08  | 52.88 
   6    |   2.283756   | 75.533704  | 2.341477 | 45.47  | 53.42 
   7    |   2.241941   | 76.201400  | 2.309805 | 45.47  | 54.02 
   8    |   2.229728   | 76.609265  | 2.298989 | 45.47  | 53.71 
   9    |   2.233934   | 76.819539  | 2.306580 | 45.47  | 52.62 
  10    |   2.246141   | 76.878759  | 2.319358 | 45.55  | 53.42 
  11    |   2.263491   | 76.851923  | 2.335628 | 45.55  | 53.69 
  12    |   2.282192   | 76.785795  | 2.355156 | 45.55  | 54.42 
  13    |   2.302610   | 76.696454  | 2.375609 | 45.47

In [13]:
# `end_time` is the elapsed training time in minutes (seconds / 60.0 in the training cell).
print('training time taken: ', end_time)

training time taken:  44.12219217093334


In [14]:
# Persist the trained model under the subject / run title via the project helper.
mu.saveModel(subject_title, title, model)
# (Disabled) graph export to TensorBoard; `train_dataloader` and `device` are not defined in the visible cells.
# mu.graphModel(train_dataloader, model, writer, device)

In [15]:
# Rebind `model` to what the helper returns for the same subject / run title
# used when saving. Presumably this reloads the saved checkpoint — confirm in model_util.
model = mu.getModel(subject_title, title)

In [16]:
# Evaluate on the test iterator with the same criterion and vocabularies used
# for training; the tester returns a (loss, accuracy) pair.
test_loss, test_acc = tester.test(
    model,
    crit,
    test_loader=loader.test_iter,
    src_vocab=loader.src.vocab,
    tgt_vocab=loader.tgt.vocab,
    lr_schedular=lr_schedular,  # None in this notebook
)


Using device number: 0


In [17]:
# Report the test metrics returned by tester.test.
print('test loss: ', test_loss)
print('test_acc: ', test_acc)

test loss:  3.044724225997925
test_acc:  47.17261904761905
