In [1]:
import argparse
import pprint

import torch, gc
import torch.nn as nn
from torch import optim

from data_loader import DataLoader
import data_loader
import trainer

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

import timeit

from models.transformer import Transformer
import model_util as mu

using data from torchtext.legacy


In [2]:
def get_model(input_size, output_size,
    hidden_size=32,
    n_splits=8,
    n_layers=4,
    dropout=0.2,
    use_transformer=True):
	"""Build the Transformer translation model.

	Args:
		input_size: Source vocabulary size.
		output_size: Target vocabulary size.
		hidden_size: Hidden size, handed to Transformer as its second positional
			argument (the Transformer does not take a separate word_vec_size).
		n_splits: Number of heads in multi-head attention.
		n_layers: Number of blocks, used for both the encoder and the decoder.
		dropout: Dropout rate applied on each block.
		use_transformer: Kept for backward compatibility only. It is currently
			ignored: a Transformer is built whatever its value.

	Returns:
		A newly constructed Transformer instance.
	"""
	# The former if/else on `use_transformer` built the exact same Transformer in
	# both branches, so the duplicated construction is collapsed into one.
	return Transformer(
		input_size,				# Source vocabulary size
		hidden_size,			# Hidden size
		output_size,			# Target vocabulary size
		n_splits=n_splits,		# Number of heads in multi-head attention
		n_enc_blocks=n_layers,	# Number of encoder blocks
		n_dec_blocks=n_layers,	# Number of decoder blocks
		dropout_p=dropout,		# Dropout rate on each block
	)


def get_crit(output_size, pad_index):
	"""Create the summed NLL criterion that gives no weight to the PAD class.

	Args:
		output_size: Number of target classes; length of the class-weight vector.
		pad_index: Index of the PAD class, whose weight is set to zero.

	Returns:
		An nn.NLLLoss using the per-class weights and reduction='sum'.
	"""
	# Every class weighs 1 except PAD, which must not contribute to the loss.
	class_weights = torch.ones(output_size)
	class_weights[pad_index] = 0.0

	# NLL loss over log-probabilities is used here in place of cross-entropy.
	print('\n Loss function: Negative Log-Likelihood with log-probability (NLLLoss)')
	return nn.NLLLoss(reduction='sum', weight=class_weights)


def get_optimizer(model,
    use_adam=True,
    use_transformer=True,
    lr=0.0001,):
	"""Create the Adam optimizer for `model`.

	Args:
		model: Module whose parameters() are optimized.
		use_adam: When false, 'Optimizer: Adam' is printed and Adam is still
			built, with betas=(.9, .98).
		use_transformer: Only consulted when use_adam is true. True selects
			betas=(.9, .98); false keeps Adam's default betas (the RNN-based
			seq2seq case).
		lr: Learning rate.

	Returns:
		The constructed optim.Adam instance.
	"""
	if not use_adam:
		print('Optimizer: Adam')

	# Only the "Adam without Transformer" combination keeps the default betas.
	if use_adam and not use_transformer:
		return optim.Adam(model.parameters(), lr=lr)

	return optim.Adam(model.parameters(), lr=lr, betas=(.9, .98))

In [None]:
# --- Data -------------------------------------------------------------------
train_fn = 'corpus.shuf.train.tok.bpe'
valid_fn = 'corpus.shuf.valid.tok.bpe'
test_fn = 'corpus.shuf.test.tok.bpe'
lang = ('en', 'ko')         # handed to DataLoader as `exts`
max_length = 20             # handed to DataLoader as `max_length`
batch_size = 32

# --- Model ------------------------------------------------------------------
hidden_size = 64
n_layers = 4                # used for both encoder and decoder blocks
n_splits = 8                # number of attention heads
dropout = 0.0

# --- Optimisation -----------------------------------------------------------
lr = 0.0003
n_epochs = 10

# --- Experiment bookkeeping -------------------------------------------------
research_num = '01'
research_subject = 'merged_update'

In [3]:
# Build the project DataLoader; later cells read its train/valid/test iterators
# and its source/target vocabularies.
loader = DataLoader(
    train_fn=train_fn,
    valid_fn=valid_fn,
    test_fn=test_fn,
    exts=lang,
    batch_size=batch_size,
    max_length=max_length,  # Longer sequences will be excluded.
    device=-1,              # Lazy loading
    dsl=False,              # Turn off Dual-supervised Learning mode.
)

In [4]:
# Source / target vocabulary sizes, later passed to get_model and get_crit.
input_size = len(loader.src.vocab)
output_size = len(loader.tgt.vocab)

print('\ninput_size: ', input_size)
print('output_size: ', output_size)


input_size:  15884
output_size:  26204


In [None]:
# Instantiate the network from the vocabulary sizes and the configured hyper-parameters.
model = get_model(
    input_size,
    output_size,
    hidden_size=hidden_size,
    n_layers=n_layers,
    n_splits=n_splits,
    dropout=dropout,
    use_transformer=True,
)

In [5]:
# Loss criterion over the target vocabulary; the PAD class gets zero weight.
crit = get_crit(
    output_size=output_size,
    pad_index=data_loader.PAD,
)


 Loss function: Negative Log-Likelihood with log-probability (NLLLoss)


In [None]:
# if model_weight is not None:
    # model.load_state_dict(model_weight)

# Select GPU 0 when CUDA is available; -1 means no GPU was selected.
device_num = 0 if torch.cuda.is_available() else -1
print('\nUsing device number: 0' if device_num == 0 else '\nUsing device number: -1')

# Run the garbage collector and release cached CUDA memory before moving the model.
gc.collect()
torch.cuda.empty_cache()

# Move the model and the criterion to the GPU only when one was selected.
if device_num >= 0:
    model.cuda(device_num)
    crit.cuda(device_num)

In [None]:
# Optimizer over the model parameters, using the configured learning rate.
optimizer = get_optimizer(
    model=model,
    lr=lr,
)

In [6]:
# if opt_weight is not None and config.use_adam:
    # optimizer.load_state_dict(opt_weight)

# No learning-rate scheduler is configured; this None is what gets forwarded
# as `lr_schedular` to trainer.train and tester.test.
lr_schedular = None

In [7]:
# Run naming: `overall_title` is the shared prefix, `title` names this particular run.
overall_title = 'local1'
title = overall_title + '_08'

# NOTE(review): `timestamp` is computed but not referenced in the visible cells.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# TensorBoard writer logging under ./tensorboard/<overall_title>/tests.
writer = SummaryWriter('./tensorboard/'+overall_title+'/tests')

In [None]:
# Time the whole training run with a wall-clock timer.
start_time = timeit.default_timer()

trainer.train(
    model,
    crit,
    optimizer,
    train_loader=loader.train_iter,
    valid_loader=loader.valid_iter,
    src_vocab=loader.src.vocab,
    tgt_vocab=loader.tgt.vocab,
    # Use the epoch count from the hyper-parameter cell rather than a hard-coded
    # literal, so that changing `n_epochs` there actually changes the run.
    n_epochs=n_epochs,
    lr_schedular=lr_schedular,
    writer=writer,
    title=title,
)

# Despite its name, `end_time` holds the elapsed training time in minutes.
end_time = (timeit.default_timer() - start_time) / 60.0

In [None]:
# Display the elapsed training time, in minutes (timer difference divided by 60).
end_time

In [None]:
# Save the trained model under the run names via the project helper.
# NOTE(review): where and in what format it is stored is decided inside
# model_util.saveModel, which is not visible here.
mu.saveModel(overall_title, title, model)
# mu.graphModel(train_dataloader, model, writer, device)

In [8]:
# Replace the in-memory `model` with the one the project helper returns for these
# run names (presumably the checkpoint written by mu.saveModel — verify).
# NOTE(review): the recorded output shows 256-dim embeddings and dropout p=0.2,
# while the configuration cell sets hidden_size=64 and dropout=0.0 — confirm
# which checkpoint is actually being loaded here.
model = mu.getModel(overall_title, title)
print(model)

Transformer(
  (emb_enc): Embedding(15884, 256)
  (emb_dec): Embedding(26204, 256)
  (emb_dropout): Dropout(p=0.2, inplace=False)
  (encoder): MySequential(
    (0): EncoderBlock(
      (attn): MultiHead(
        (Q_linear): Linear(in_features=256, out_features=256, bias=False)
        (K_linear): Linear(in_features=256, out_features=256, bias=False)
        (V_linear): Linear(in_features=256, out_features=256, bias=False)
        (linear): Linear(in_features=256, out_features=256, bias=False)
        (attn): Attention(
          (softmax): Softmax(dim=-1)
        )
      )
      (attn_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attn_dropout): Dropout(p=0.2, inplace=False)
      (fc): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1024, out_features=256, bias=True)
      )
      (fc_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (fc_dropout): Dropout(p=0.2, in

In [9]:
import tester

In [10]:
# Evaluate the model on the test iterator; the two values returned by
# tester.test are unpacked as `loss` and `acc`.
# NOTE(review): `lr_schedular` is None at this point; what tester.test does with
# it is not visible here.
loss, acc = tester.test(
    model,
    crit,
    test_loader=loader.test_iter,
    src_vocab=loader.src.vocab,
    tgt_vocab=loader.tgt.vocab,
    lr_schedular=lr_schedular,
)


Using device number: 0


In [11]:
# Show the test loss and accuracy, one value per line.
print(loss, acc, sep='\n')

2.914247512817383
47.109375
