# Transformer for neural machine translation
Ref: https://andrewpeng.dev/transformer-pytorch/  
Ref: https://spaces.ac.cn/archives/6933  
Ref: https://github.com/graykode/nlp-tutorial  

In [1]:
import torch
from torchtext.datasets import Multi30k
from torchtext.data import BucketIterator, Field
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

import os
import time
import copy
from torchsummaryX import summary

import model
import utils

## Loading data

In [2]:
# Mini-batch size used by all three iterators below.
BSZ = 8

# Both fields share the same preprocessing (spaCy tokenizer, lower-casing,
# <sos>/<eos> markers); only the tokenizer language differs.
_FIELD_OPTIONS = dict(
    tokenize='spacy',
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
)
SRC = Field(tokenizer_language='de', **_FIELD_OPTIONS)  # German side
TRG = Field(tokenizer_language='en', **_FIELD_OPTIONS)  # English side

# Multi30k German/English pairs, split into train / validation / test.
train_data, val_data, test_data = Multi30k.splits(
    exts=('.de', '.en'),
    fields=(SRC, TRG),
)

# One bucketing iterator per split, all using the same batch size.
train_iter, val_iter, test_iter = BucketIterator.splits(
    (train_data, val_data, test_data),
    batch_size=BSZ,
)

# Vocabularies are built from the training split only; min_freq=4 is the
# frequency cut-off handed to torchtext's vocabulary builder.
for _field in (SRC, TRG):
    _field.build_vocab(train_data, min_freq=4)

In [3]:
# batch = next(iter(train_iter))

# for batch_idx, batch in enumerate(train_iter):
#     src = batch.src.transpose(0,1)[0:2]
#     trg = batch.trg.transpose(0,1)[0:2]
#     src = [' '.join(utils.itos(idx_seq, SRC)) for idx_seq in src]
#     trg = [' '.join(utils.itos(idx_seq, TRG)) for idx_seq in trg]
#     print(src)
#     print(trg)
    
#     if batch_idx == 0:
#         break

# Report the number of batches in the train, validation and test iterators,
# one value per line, in that order.
for _split_iter in (train_iter, val_iter, test_iter):
    print(len(_split_iter))

# print(len(TRG.vocab))
# print(TRG.vocab.stoi[' '])
# print(TRG.vocab.itos[0])

# print(SRC.vocab.stoi['<sos>'])
# print(TRG.vocab.stoi['<sos>'])

3625
127
125


## Define model

In [4]:
# Model and optimisation hyper-parameters.
D_MODEL = 512
N_HEAD = 1
NUM_ENC_LAYERS = 1
NUM_DEC_LAYERS = 1
# NOTE(review): name looks like a typo for "feedforward"; kept as-is because
# other cells may reference it. Presumably the feed-forward hidden width — confirm
# against model.Transformer.
DIM_FEEDWORD = 64
DROPOUT = 0.5
ACTIVATION = 'relu'
N_EPOCH = 40
LR = 0.001

# Prefer GPU 2 (the original hard-coded choice), but fall back to GPU 0 when
# the machine exposes fewer devices, and to the CPU when CUDA is unavailable.
PREFERRED_GPU = 2
if torch.cuda.is_available():
    _gpu_index = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    device = torch.device(f'cuda:{_gpu_index}')
else:
    device = torch.device('cpu')

# The seventh positional argument previously repeated NUM_ENC_LAYERS, leaving
# NUM_DEC_LAYERS unused. NOTE(review): assumes model.Transformer takes the
# decoder layer count in that position — confirm against its signature.
net = model.Transformer(device, len(SRC.vocab), len(TRG.vocab), D_MODEL, N_HEAD, NUM_ENC_LAYERS,
                        NUM_DEC_LAYERS, DIM_FEEDWORD, DROPOUT, ACTIVATION).to(device)
optimizer = optim.Adam(net.parameters(), lr=LR)

# The model's output vocabulary is TRG's, so the padding index to ignore in
# the loss must be looked up in the target vocabulary, not the source one.
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi['<pad>'])

# Run name encoding the main hyper-parameters; used as the TensorBoard log
# sub-directory under 'log/'.
path = (f'bsz:{BSZ}-lr:{LR}-epoch:{N_EPOCH}-d_model:{D_MODEL}-nhead:{N_HEAD}-nlayer:{NUM_ENC_LAYERS}'
        f'-nhid:{DIM_FEEDWORD}-activation:{ACTIVATION}')
writer = SummaryWriter(os.path.join('log/', path))

# Best validation BLEU seen so far and a snapshot of the matching weights;
# initialised from the untrained network.
best_val_bleu = 0.0
best_val_model = copy.deepcopy(net.state_dict())

In [5]:
# print(net)
# summary(net, torch.zeros((10,1), dtype = torch.long), torch.zeros((10,1), dtype = torch.long))

## Train and evaluate

In [6]:
# Train for N_EPOCH epochs, validating after each one and remembering the
# weights that achieved the highest validation BLEU.
for epoch in range(N_EPOCH):
    model.train(
        net, train_iter, criterion, optimizer, TRG, epoch, writer, device,
    )

    val_loss, val_bleu = model.evaluate(
        net, val_iter, criterion, TRG, device,
    )
    print(f'epoch: {epoch} | val loss: {val_loss:.3f} | val bleu: {val_bleu: .3f}')

    # Only a strict improvement replaces the stored snapshot.
    if not val_bleu > best_val_bleu:
        continue

    best_val_model = copy.deepcopy(net.state_dict())
    best_val_bleu = val_bleu

epoch: 0 | val loss: 5.059 | val bleu:  0.000
epoch: 1 | val loss: 4.887 | val bleu:  0.000
epoch: 2 | val loss: 4.833 | val bleu:  0.000
epoch: 3 | val loss: 4.769 | val bleu:  0.000
epoch: 4 | val loss: 4.829 | val bleu:  0.000
epoch: 5 | val loss: 4.791 | val bleu:  0.000
epoch: 6 | val loss: 4.743 | val bleu:  0.000
epoch: 7 | val loss: 4.797 | val bleu:  0.000
epoch: 8 | val loss: 4.775 | val bleu:  0.000
epoch: 9 | val loss: 4.765 | val bleu:  0.000
epoch: 10 | val loss: 4.737 | val bleu:  0.000
epoch: 11 | val loss: 4.735 | val bleu:  0.000
epoch: 12 | val loss: 4.717 | val bleu:  0.000
epoch: 13 | val loss: 4.725 | val bleu:  0.000
epoch: 14 | val loss: 4.718 | val bleu:  0.000
epoch: 15 | val loss: 4.649 | val bleu:  0.001
epoch: 16 | val loss: 4.595 | val bleu:  0.001
epoch: 17 | val loss: 4.591 | val bleu:  0.001


KeyboardInterrupt: 

## Test

In [None]:
# Restore the snapshot with the best validation BLEU before scoring on the
# held-out test split.
net.load_state_dict(best_val_model)

test_loss, test_bleu = model.test(
    net,
    test_iter,
    criterion,
    TRG,
    device,
)
print(f'test loss: {test_loss:.3f} | test bleu: {test_bleu: .3f}')