__Description__:

    The code follows the architecture of "Neural Machine Translation by
    Jointly Learning to Align and Translate" (Bahdanau et al.): a
    bidirectional GRU encoder, a unidirectional GRU decoder, and Bahdanau
    (additive) attention.

In [1]:
import torch
from torchtext.data import Field, BucketIterator
from torchtext.datasets import Multi30k
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import time 
import os
from torchsummaryX import summary
from torchtext.data.metrics import bleu_score

import NMT

## Utils

In [2]:
def count_bleu(outputs, trg, TRG):
    '''
        BLEU of the greedy (argmax) predictions for one batch.

        outputs: [T, N, E] scores; the argmax over the last axis is looked
                 up in TRG.vocab.itos
        trg: [T, N] gold token indices
        TRG: target field; only TRG.vocab.stoi and TRG.vocab.itos are used

        Positions where trg is '<pad>' are removed from both the prediction
        and the reference, so each candidate has the same length as its
        reference. Every batch column is scored as its own sentence, which
        keeps n-grams from being matched across sentence boundaries.
    '''
    pad_idx = TRG.vocab.stoi['<pad>']
    itos = TRG.vocab.itos

    # Greedy decode, then switch to batch-major so each row is one sentence.
    predicted = outputs.permute(1, 0, 2).max(-1)[1].tolist()
    gold = trg.permute(1, 0).tolist()

    candidates = []
    references = []
    for pred_row, gold_row in zip(predicted, gold):
        # Keep only the time steps that carry a real target token.
        kept = [(p, g) for p, g in zip(pred_row, gold_row) if g != pad_idx]
        candidates.append([itos[p] for p, _ in kept])
        # bleu_score expects a list of references per candidate.
        references.append([[itos[g] for _, g in kept]])
    return bleu_score(candidates, references)


def init_weights(m: nn.Module):
    '''
        Re-initialize every parameter reachable from `m` in place:
        parameters whose name contains 'weight' are drawn from
        N(mean=0, std=0.01), all remaining parameters are set to 0.
    '''
    for param_name, tensor in m.named_parameters():
        values = tensor.data
        if 'weight' not in param_name:
            # Anything that is not a weight (e.g. a bias) is zeroed.
            nn.init.constant_(values, 0)
            continue
        nn.init.normal_(values, mean=0, std=0.01)

## Load dataset

In [3]:
BSZ = 128

# Prefer the second GPU as before, but fall back to the first one when the
# machine exposes a single CUDA device, and to the CPU when there is none.
# A hard-coded "cuda:1" is an invalid device on single-GPU machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Source (German) field: spaCy tokenization, lower-cased, wrapped in
# <sos> ... <eos>.
SRC = Field(tokenize='spacy',
            tokenizer_language='de',
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)

# Target (English) field, configured the same way as the source field.
TRG = Field(tokenize='spacy',
            tokenizer_language='en',
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)

# Multi30k German -> English splits, stored under ./.data
train_data, dev_data, test_data = Multi30k.splits(
    exts=('.de', '.en'),
    fields=(SRC, TRG), root='.data', train='train',
    validation='val', test='test2016')

# Vocabularies are built from the training split only, with min_freq=2.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, dev_data, test_data),
    batch_size=BSZ, device=device)

In [4]:
# def itos(idx_seq, field):
#     return [field.vocab.itos[idx] for idx in idx_seq]

# batch = next(iter(train_iterator))
# src = batch.src.transpose(0,1)
# trg = batch.trg.transpose(0,1)
# src = [' '.join(itos(idx_seq, SRC)) for idx_seq in src]
# trg = [' '.join(itos(idx_seq, TRG)) for idx_seq in trg]
# print(src)
# print(trg)

# Report the number of batches in the train, validation and test iterators.
for split_iterator in (train_iterator, valid_iterator, test_iterator):
    print(len(split_iterator))

227
8
8


## Define model

In [5]:
# Vocabulary sizes drive the encoder input and decoder output dimensions.
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)

# Model and optimization hyper-parameters.
ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
ENC_HID_DIM = 64
DEC_HID_DIM = 64
ATTN_DIM = 8
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
N_EPOCHS = 29
CLIP = 1      # max gradient norm handed to train()
LR = 1e-3


enc = NMT.Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
atten = NMT.Attention(ENC_HID_DIM, DEC_HID_DIM, ATTN_DIM)
dec = NMT.Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, atten)
model = NMT.Seq2Seq(enc, dec).to(device)
model.apply(init_weights)
optimizer = optim.Adam(model.parameters(), lr=LR)

# Padding positions in the target do not contribute to the loss.
PAD_IDX = TRG.vocab.stoi['<pad>']
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)

# Run name used as the TensorBoard log sub-directory. `nhid` now records the
# hidden size (ENC_HID_DIM); it previously repeated the embedding size.
path = (f'bsz:{BSZ}-nepoch:{N_EPOCHS}-lr:{LR}-nemd:{ENC_EMB_DIM}-nhid:{ENC_HID_DIM}'
        f'-nattn:{ATTN_DIM}-dropout:{ENC_DROPOUT}-clip:{CLIP}')
writer = SummaryWriter(os.path.join('./log', path))

## Define train

In [6]:
def train(model, criterion, iterator, optimizer, clip, epoch, TRG, writer, device,
          log_interval=27):
    '''
        Run one training pass over `iterator` and log running metrics.

        model: called as model(src, trg, device)
        criterion: loss applied to the flattened outputs and targets
        iterator: yields batches exposing `.src` and `.trg`
        optimizer: stepped once per batch
        clip: max gradient norm passed to clip_grad_norm_
        epoch: epoch index, used to compute the global logging step
        TRG: target field, forwarded to count_bleu
        writer: receives the 'train loss' and 'train BLEU' scalars
        device: device the batch tensors are moved to
        log_interval: number of batches averaged per logged point
                      (default 27, the previously hard-coded value)

        Batches left over after the last full window are not logged.
        Raises ValueError if log_interval is smaller than 1.
    '''
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')

    model.train()
    running_loss = 0
    running_bleu = 0
    step_offset = epoch * len(iterator)
    for batch_idx, batch in enumerate(iterator):
        src = batch.src.to(device)
        trg = batch.trg.to(device)
        optimizer.zero_grad()
        outputs = model(src, trg, device)

        # The first time step is excluded from both loss and BLEU
        # (presumably the <sos> position -- confirm against NMT.Seq2Seq).
        pred = outputs[1:, ...]
        gold = trg[1:, :]

        loss = criterion(pred.view(-1, outputs.shape[-1]), gold.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.item()
        running_bleu += count_bleu(pred, gold, TRG)

        # Emit the window average on the last batch of every window.
        if batch_idx % log_interval == log_interval - 1:
            global_step = step_offset + batch_idx
            writer.add_scalar('train loss',
                              running_loss / log_interval,
                              global_step)

            writer.add_scalar('train BLEU',
                              running_bleu / log_interval,
                              global_step)

            running_bleu = 0
            running_loss = 0

## Define evaluate

In [7]:
def evaluate(model, criterion, iterator, epoch, TRG, writer, device,
             log_interval=1):
    '''
        Run one evaluation pass over `iterator` without gradient tracking
        and log running metrics.

        model: called as model(src, trg, device, 0)
        criterion: loss applied to the flattened outputs and targets
        iterator: yields batches exposing `.src` and `.trg`
        epoch: epoch index, used to compute the global logging step
        TRG: target field, forwarded to count_bleu
        writer: receives the 'test loss' and 'test BLEU' scalars
        device: device the batch tensors are moved to
        log_interval: number of batches averaged per logged point
                      (default 1, i.e. every batch is logged as before)

        Batches left over after the last full window are not logged.
        Raises ValueError if log_interval is smaller than 1.
    '''
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')

    model.eval()
    running_loss = 0
    running_bleu = 0
    step_offset = epoch * len(iterator)
    with torch.no_grad():
        for batch_idx, batch in enumerate(iterator):
            src = batch.src.to(device)
            trg = batch.trg.to(device)
            # set teacher forcing ratio = 0
            outputs = model(src, trg, device, 0)

            # The first time step is excluded from both loss and BLEU
            # (presumably the <sos> position -- confirm against NMT.Seq2Seq).
            pred = outputs[1:, ...]
            gold = trg[1:, :]

            loss = criterion(pred.view(-1, outputs.shape[-1]), gold.view(-1))
            running_loss += loss.item()
            running_bleu += count_bleu(pred, gold, TRG)

            # Emit the window average on the last batch of every window.
            if batch_idx % log_interval == log_interval - 1:
                global_step = step_offset + batch_idx
                writer.add_scalar('test loss',
                                  running_loss / log_interval,
                                  global_step)

                writer.add_scalar('test BLEU',
                                  running_bleu / log_interval,
                                  global_step)

                running_loss = 0
                running_bleu = 0

## Train and evaluate

In [8]:
# One training pass followed by one pass over the validation iterator per
# epoch. `%time` is an IPython magic that prints the CPU and wall time of the
# call, so this cell only runs inside IPython/Jupyter.
for epoch in range(N_EPOCHS):
    %time train(model, criterion, train_iterator, optimizer, CLIP, epoch, TRG, writer, device)
    %time evaluate(model, criterion, valid_iterator, epoch, TRG, writer, device)
    # Progress marker: index of the epoch that just finished.
    print(epoch)

CPU times: user 1min 32s, sys: 1min, total: 2min 33s
Wall time: 58.2 s
CPU times: user 2.76 s, sys: 1.14 s, total: 3.9 s
Wall time: 1.29 s
0
CPU times: user 1min 31s, sys: 1min 3s, total: 2min 35s
Wall time: 57 s
CPU times: user 2.38 s, sys: 1.35 s, total: 3.73 s
Wall time: 1.12 s
1
CPU times: user 1min 36s, sys: 56.5 s, total: 2min 32s
Wall time: 59.2 s
CPU times: user 2.64 s, sys: 1.17 s, total: 3.81 s
Wall time: 1.13 s
2
CPU times: user 1min 35s, sys: 59.4 s, total: 2min 35s
Wall time: 1min
CPU times: user 2.15 s, sys: 1.64 s, total: 3.78 s
Wall time: 1.05 s
3
CPU times: user 1min 34s, sys: 1min 2s, total: 2min 37s
Wall time: 1min
CPU times: user 2.08 s, sys: 1.7 s, total: 3.78 s
Wall time: 1.13 s
4
CPU times: user 1min 36s, sys: 1min, total: 2min 37s
Wall time: 1min 1s
CPU times: user 2.61 s, sys: 1.2 s, total: 3.81 s
Wall time: 1.21 s
5
CPU times: user 1min 38s, sys: 1min 1s, total: 2min 39s
Wall time: 1min 3s
CPU times: user 2.53 s, sys: 1.26 s, total: 3.79 s
Wall time: 1.13 s
6
