In [None]:
import argparse
import pprint

import torch, gc
import torch.nn as nn
from torch import optim

from data_loader import DataLoader
import data_loader
import trainer

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

import timeit

from models.transformer import Transformer
import model_util as mu

In [None]:
def get_model(input_size, output_size,
              hidden_size=32,
              n_splits=8,
              n_layers=4,
              dropout=0.0,
              use_transformer=True):
    """Build the translation model.

    Args:
        input_size: source vocabulary size.
        output_size: target vocabulary size.
        hidden_size: hidden size of the Transformer (it takes no separate
            word-vector size).
        n_splits: number of heads in multi-head attention.
        n_layers: number of blocks used for both the encoder and the decoder.
        dropout: dropout rate applied on each block (passed as ``dropout_p``).
        use_transformer: kept for backward compatibility. No RNN-based
            seq2seq model is available in this notebook, so ``False`` still
            builds a Transformer (a notice is printed in that case).

    Returns:
        A ``Transformer`` instance.
    """
    if not use_transformer:
        # NOTE(review): the non-Transformer branch used to be an exact copy of
        # the Transformer branch; make the fallback visible instead of silent.
        print('\n RNN-based seq2seq is not available; building a Transformer instead.')

    return Transformer(
        input_size,                 # Source vocabulary size
        hidden_size,                # Transformer doesn't need word_vec_size
        output_size,                # Target vocabulary size
        n_splits=n_splits,          # Number of heads in multi-head attention
        n_enc_blocks=n_layers,      # Number of encoder blocks
        n_dec_blocks=n_layers,      # Number of decoder blocks
        dropout_p=dropout,          # Dropout rate on each block
    )


def get_crit(output_size, pad_index):
    """Return a summed negative log-likelihood loss that ignores PAD.

    NLLLoss is used instead of CrossEntropyLoss because the model's outputs
    are meant to be log-probabilities already.

    Args:
        output_size: number of target classes (target vocabulary size).
        pad_index: index of the PAD token; its class weight is set to zero.

    Returns:
        ``nn.NLLLoss`` with per-class weights and ``reduction='sum'``.
    """
    class_weights = torch.ones(output_size)   # every token counts once...
    class_weights[pad_index] = 0.             # ...except PAD, which counts zero

    print('\n Loss function: Negative Log-Likelihood with log-probability (NLLLoss)')
    return nn.NLLLoss(reduction='sum', weight=class_weights)


def get_optimizer(model,
                  use_adam=True,
                  use_transformer=True,
                  lr=0.0001,):
    """Create the optimizer for ``model``.

    Args:
        model: module whose ``parameters()`` are optimized.
        use_adam: use Adam when True; plain SGD when False.
        use_transformer: only relevant with Adam. True uses betas=(.9, .98)
            for the Transformer; False (RNN-based seq2seq) uses Adam's
            default betas.
        lr: learning rate.

    Returns:
        A ``torch.optim`` optimizer over ``model.parameters()``.
    """
    params = model.parameters()

    if not use_adam:
        # use_adam=False must actually switch optimizer; previously this branch
        # printed 'Optimizer: Adam' and built Adam anyway.
        print('Optimizer: SGD')
        return optim.SGD(params, lr=lr)

    print('Optimizer: Adam')
    if use_transformer:
        # beta2 = .98 matches the Adam settings reported for the original Transformer.
        return optim.Adam(params, lr=lr, betas=(.9, .98))
    # Case of RNN-based seq2seq: Adam's default betas.
    return optim.Adam(params, lr=lr)

In [None]:
# Build train/validation data from the BPE-tokenized, shuffled corpus files.
# NOTE(review): argument semantics below come from the original comments;
# confirm against data_loader.DataLoader.
loader = DataLoader(
        'corpus.shuf.train.tok.bpe',
        'corpus.shuf.val.tok.bpe',
        ('en', 'ko'),                           # Source and target language.
        batch_size=8,
        device=-1,                              # Lazy loading (presumably keeps data off the GPU).
        max_length=50,                          # Longer sequences will be excluded.
        dsl=False,                              # Turn off Dual-supervised Learning mode.
    )

In [None]:
# The vocabulary sizes fix the model's input (source) and output (target) dimensions.
input_size = len(loader.src.vocab)
output_size = len(loader.tgt.vocab)
print('\ninput_size: ', input_size)
print('output_size: ', output_size)

In [None]:
# Build the model with get_model's default hyper-parameters.
model = get_model(input_size=input_size, output_size=output_size)
print('\n', model)

In [None]:
# PAD positions get zero weight in the loss.
crit = get_crit(output_size=output_size, pad_index=data_loader.PAD)

In [None]:
# NOTE(review): checkpoint restoring is disabled; `model_weight` is not
# defined anywhere in this notebook.
# if model_weight is not None:
#     model.load_state_dict(model_weight)

# Use GPU 0 when CUDA is available, otherwise -1 (CPU).
device_num = 0 if torch.cuda.is_available() else -1
print(f'\nUsing device number: {device_num}')

# Free unreferenced Python objects and cached CUDA memory before moving the model.
gc.collect()
torch.cuda.empty_cache()

if device_num >= 0:
    # The criterion holds the PAD-masking weight tensor, so it has to move too.
    model.cuda(device_num)
    crit.cuda(device_num)

In [None]:
# Built after the model is moved to its device, as torch.optim recommends.
optimizer = get_optimizer(model=model)

In [None]:
# NOTE(review): optimizer-state restoring is disabled; `opt_weight` and
# `config` are not defined in this notebook.
# if opt_weight is not None and config.use_adam:
    # optimizer.load_state_dict(opt_weight)

# No learning-rate scheduler for this run. The spelling `lr_schedular`
# matches the keyword that trainer.train is called with.
lr_schedular = None

In [None]:
# Experiment naming and TensorBoard logging.
overall_title = 'local'

# A per-run timestamp gives each run its own TensorBoard directory, instead
# of mixing the events of every run under the same path.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('./tensorboard/' + overall_title + '/tests/' + timestamp)

# Run title passed to the trainer and to mu.saveModel / mu.getModel.
title = overall_title + '_01'

In [None]:
# Train the model and time the whole run.
start_time = timeit.default_timer()

trainer.train(
    model,
    crit,
    optimizer,
    train_loader=loader.train_iter,
    valid_loader=loader.valid_iter,
    src_vocab=loader.src.vocab,
    tgt_vocab=loader.tgt.vocab,
    n_epochs=20,
    lr_schedular=lr_schedular,
    writer=writer,
    title=title,
)

# Despite its name, `end_time` holds the elapsed training time in minutes.
end_time = (timeit.default_timer() - start_time) / 60.0
print(f'\nTraining time: {end_time:.2f} min')

# Push any buffered TensorBoard events to disk.
writer.flush()

In [None]:
# Save the trained model with model_util.saveModel, identified by the
# experiment title and the run title.
mu.saveModel(overall_title, title, model)
# NOTE(review): disabled — `train_dataloader` and `device` are not defined in this notebook.
# mu.graphModel(train_dataloader, model, writer, device)

In [None]:
# Reload the model saved under the same titles (presumably to check the
# save/load round trip).
# NOTE(review): this rebinds `model`, replacing the in-memory trained model
# with whatever mu.getModel returns.
model = mu.getModel(overall_title, title)
print(model)