In [13]:
'''
	Transformer for neural machine translation
	Ref: https://andrewpeng.dev/transformer-pytorch/
	Ref: https://spaces.ac.cn/archives/6933
	Ref: https://github.com/graykode/nlp-tutorial

'''

import torch
from torchtext.datasets import Multi30k
from torchtext.data import BucketIterator, Field
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

import os
import time
import copy
from torchsummaryX import summary

import model
import utils

## Loading data

In [14]:
# Source ('de') and target ('en') fields: identical settings except for the
# spaCy tokenizer language. Both lower-case the text and wrap each sentence
# in <sos> / <eos> markers.
SRC = Field(
    tokenize='spacy', tokenizer_language='de',
    init_token='<sos>', eos_token='<eos>', lower=True,
)
TRG = Field(
    tokenize='spacy', tokenizer_language='en',
    init_token='<sos>', eos_token='<eos>', lower=True,
)

# Multi30k '.de' / '.en' sentence pairs, split into train / validation / test.
train_data, val_data, test_data = Multi30k.splits(
    exts=('.de', '.en'),
    fields=(SRC, TRG),
)

# One bucketing iterator per split, 32 examples per batch.
train_iter, val_iter, test_iter = BucketIterator.splits(
    (train_data, val_data, test_data),
    batch_size=32,
)

# Vocabularies are built from the training split only, with min_freq=4.
for _field in (SRC, TRG):
    _field.build_vocab(train_data, min_freq=4)

In [12]:
# batch = next(iter(train_iter))

# for batch_idx, batch in enumerate(train_iter):
#     src = batch.src.transpose(0,1)[0:2]
#     trg = batch.trg.transpose(0,1)[0:2]
#     src = [' '.join(utils.itos(idx_seq, SRC)) for idx_seq in src]
#     trg = [' '.join(utils.itos(idx_seq, TRG)) for idx_seq in trg]
#     print(src)
#     print(trg)
    
#     if batch_idx == 0:
#         break

# print(len(train_iter))
# print(len(val_iter))
# print(len(test_iter))

# print(len(TRG.vocab))
# print(TRG.vocab.stoi[' '])
# print(TRG.vocab.itos[0])

## Define model

In [15]:
# --- Hyper-parameters -------------------------------------------------------
D_MODEL = 512
N_HEAD = 1
NUM_ENC_LAYERS = 1
NUM_DEC_LAYERS = 1
DIM_FEEDWORD = 64
DROPOUT = 0.5
ACTIVATION = 'relu'
N_EPOCH = 10

# Preferred GPU. Fall back to the first GPU when this index does not exist on
# the current machine, and to the CPU when CUDA is unavailable, so the notebook
# does not fail on hosts with fewer than GPU_INDEX + 1 devices.
GPU_INDEX = 3
if torch.cuda.is_available():
    device = torch.device(
        f'cuda:{GPU_INDEX}' if GPU_INDEX < torch.cuda.device_count() else 'cuda:0')
else:
    device = torch.device('cpu')

# The decoder depth is passed as NUM_DEC_LAYERS (previously NUM_ENC_LAYERS was
# passed twice and NUM_DEC_LAYERS was never used).
# NOTE(review): assumes the 7th positional argument of model.Transformer is the
# number of decoder layers -- confirm against model.py.
net = model.Transformer(
    device, len(SRC.vocab), len(TRG.vocab),
    D_MODEL, N_HEAD, NUM_ENC_LAYERS, NUM_DEC_LAYERS,
    DIM_FEEDWORD, DROPOUT, ACTIVATION,
).to(device)

# Adam with its default settings.
optimizer = optim.Adam(net.parameters())

# Padding positions are excluded from the loss. The index is looked up in the
# target vocabulary because TRG is the field handed to the train/evaluate
# helpers alongside this criterion.
# NOTE(review): presumably the loss is computed over target tokens -- confirm in model.py.
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi['<pad>'])

# TensorBoard run directory: log/<local timestamp of this cell's execution>.
writer = SummaryWriter(os.path.join('log/', time.strftime('%Y-%m-%d %H:%M:%S')))

In [16]:

# print(net)
# summary(net, torch.zeros((10,1), dtype = torch.long), torch.zeros((10,1), dtype = torch.long))

## Train and evaluate

In [17]:
# Keep a copy of the weights from the epoch with the highest validation BLEU;
# start from the untrained weights so there is always something to restore.
best_val_bleu = 0.0
best_val_model = copy.deepcopy(net.state_dict())

for epoch in range(N_EPOCH):
    # One training pass, then score the current weights on the validation split.
    model.train(net, train_iter, criterion, optimizer, TRG, epoch, writer, device)
    val_loss, val_bleu = model.evaluate(net, val_iter, criterion, TRG, device)
    print(f'val loss: {val_loss:.3f} | val bleu: {val_bleu: .3f}')

    # Only a strict improvement replaces the stored snapshot.
    if not val_bleu > best_val_bleu:
        continue
    best_val_bleu = val_bleu
    best_val_model = copy.deepcopy(net.state_dict())

# Restore the best snapshot before scoring the held-out test split.
net.load_state_dict(best_val_model)
# NOTE(review): the run recorded below printed ten validation lines and then
# stopped with "TypeError: greedy_decoder() missing 1 required positional
# argument: 'trg'", so the failure appears to come from inside model.test --
# the greedy_decoder call there needs fixing before this cell can finish.
test_loss, test_bleu = model.test(net, test_iter, criterion, TRG, device)
print(f'test loss: {test_loss:.3f} | test bleu: {test_bleu: .3f}')

val loss: 4.041 | val bleu:  0.024
val loss: 3.908 | val bleu:  0.032
val loss: 3.835 | val bleu:  0.016
val loss: 3.778 | val bleu:  0.039
val loss: 3.770 | val bleu:  0.026
val loss: 3.753 | val bleu:  0.012
val loss: 3.715 | val bleu:  0.036
val loss: 3.696 | val bleu:  0.026
val loss: 3.674 | val bleu:  0.034
val loss: 3.682 | val bleu:  0.037


TypeError: greedy_decoder() missing 1 required positional argument: 'trg'