## Chapter 12 - Example Translation Transformer

Subclasses nn.Transformer for a German-to-English translation task.
Much of the driver code, including the code for training, evaluation, and displaying results, is credited to Ben Trevett's
"Attention Is All You Need" notebook:
  https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb

### Data and Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy
import numpy as np

from einops import rearrange

import random
import math
import time

### Prepare the Data

In [None]:
# Fixed seed shared by every random number generator seeded below.
SEED = 42

# Seed Python's, NumPy's and PyTorch's (CPU and CUDA) generators with the same value.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# Request deterministic cuDNN algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
! python -m spacy download en
! python -m spacy download de

In [None]:
# Load the German and English spaCy pipelines; their tokenizers are used by
# tokenize_de / tokenize_en.
# NOTE(review): whether the 'de' / 'en' shortcut names resolve depends on the
# installed spaCy version — confirm in your environment.
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

In [None]:
def tokenize_de(text):
    """
    Splits a German string into a list of token strings using the spaCy tokenizer
    """
    pieces = []
    for token in spacy_de.tokenizer(text):
        pieces.append(token.text)
    return pieces

def tokenize_en(text):
    """
    Splits an English string into a list of token strings using the spaCy tokenizer
    """
    pieces = []
    for token in spacy_en.tokenizer(text):
        pieces.append(token.text)
    return pieces

### TODO: torchtext.Field is deprecated.

torchtext is undergoing a lot of changes, and Field will be removed in the next release:
https://github.com/pytorch/text/issues/664
Refactor with torchtext.experimental (once released?) or manually create a custom torch Dataset.

In [None]:
# Source (German) field: tokenized with tokenize_de, lower-cased, wrapped in
# <sos>/<eos> markers, and batched with the batch dimension first.
SRC = Field(tokenize = tokenize_de, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True, 
            batch_first = True)

# Target (English) field: same settings, but tokenized with tokenize_en.
TRG = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True, 
            batch_first = True)

In [None]:
# Load the Multi30k train / validation / test splits. The '.de' side is
# processed with SRC and the '.en' side with TRG (same order in both tuples).
train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), 
                                                    fields = (SRC, TRG))

In [None]:
# Build both vocabularies from the training split only.
# min_freq = 2: presumably tokens seen fewer than twice are left out of the
# vocabulary — confirm against the torchtext Field.build_vocab documentation.
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Number of sentence pairs per batch, shared by all three iterators.
BATCH_SIZE = 128

# TODO: refactor. BucketIterator paradigm is going away.

# One iterator per split; batches are created on `device`.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
     batch_size = BATCH_SIZE,
     device = device)

### Build the TranslationTransformer model from PyTorch nn.Transformer modules

In [None]:
# https://pytorch.org/tutorials/beginner/transformer_tutorial

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence of embeddings.

    A table of shape ``(max_len, 1, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold the sine and
    odd feature columns hold the cosine of ``position * frequency``, where the
    frequencies decrease geometrically across the feature axis.

    Args:
        d_model: size of the feature (last) dimension; may be even or odd.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions in the precomputed table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so only the first d_model // 2 frequencies are used for the cosines.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so that the table
        # broadcasts over the second axis of the input in forward().
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions, then apply dropout.

        Args:
            x: tensor whose first axis indexes positions and whose last axis
                has size ``d_model``. ``x.size(0)`` must not exceed ``max_len``,
                otherwise the sliced table no longer broadcasts against ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [None]:
from typing import Any, Optional, Tuple

from torch import Tensor

class CustomDecoderLayer(nn.TransformerDecoderLayer):
    """Decoder layer that also returns its encoder-decoder attention weights.

    The computation is self-attention, encoder-decoder attention and the
    feed-forward block, each followed by a residual connection and layer
    normalization. In addition to the layer output, the attention weights
    produced by ``self.multihead_attn`` are returned.
    """

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Run the layer and return ``(output, attention_weights)``.

        Args:
            tgt: decoder input sequence.
            memory: encoder output attended to by the decoder.
            tgt_mask: optional attention mask for the self-attention.
            memory_mask: optional attention mask for the encoder-decoder attention.
            tgt_key_padding_mask: optional padding mask for ``tgt`` keys.
            memory_key_padding_mask: optional padding mask for ``memory`` keys.

        Returns:
            A tuple of the layer output (same shape as ``tgt``) and the weights
            returned by the encoder-decoder attention module.
            NOTE(review): whether those weights are averaged over heads depends
            on the nn.MultiheadAttention defaults of the installed PyTorch — confirm.
        """
        # Self-attention: its weights are never used, so do not request them.
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask,
                              need_weights=False)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Encoder-decoder attention: keep the weights so they can be visualised.
        tgt2, attention_weights = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                                      key_padding_mask=memory_key_padding_mask,
                                                      need_weights=True)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # Position-wise feed-forward block.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt, attention_weights
    
    
class CustomDecoder(nn.TransformerDecoder):
    """Decoder stack that caches the attention weights returned by each layer.

    Every layer is expected to return an ``(output, attention_weights)`` tuple.
    After each forward pass, ``self.attention_weights`` holds one entry per
    layer, in layer order.
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__(decoder_layer, num_layers, norm)
        # Defined up front so the attribute can be read before the first forward pass.
        self.attention_weights = []

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                tgt_is_causal: Optional[bool] = None, memory_is_causal: bool = False) -> Tensor:
        """Same as TransformerDecoder except caches the attention weights from each decoder layer.

        Args:
            tgt: decoder input sequence.
            memory: encoder output.
            tgt_mask: optional attention mask for the target sequence.
            memory_mask: optional attention mask for the memory sequence.
            tgt_key_padding_mask: optional padding mask for ``tgt`` keys.
            memory_key_padding_mask: optional padding mask for ``memory`` keys.
            tgt_is_causal: accepted so that callers which pass this keyword
                (as nn.Transformer.forward is believed to do in newer PyTorch
                releases — confirm for your version) do not fail; it is not
                forwarded to the layers.
            memory_is_causal: accepted for the same reason; not forwarded.

        Returns:
            The output of the last layer, normalized with ``self.norm`` when set.
        """
        collected = []

        output = tgt
        for mod in self.layers:
            output, attention = mod(output, memory, tgt_mask=tgt_mask,
                                    memory_mask=memory_mask,
                                    tgt_key_padding_mask=tgt_key_padding_mask,
                                    memory_key_padding_mask=memory_key_padding_mask)

            # save the attention weights from this decoder layer
            collected.append(attention)

        # Replace the cache only once the whole stack has run.
        self.attention_weights = collected

        if self.norm is not None:
            output = self.norm(output)

        return output

In [None]:

class TranslationTransformer(nn.Transformer):
    """nn.Transformer with token embeddings, positional encoding and an output projection.

    The decoder is a CustomDecoder built from CustomDecoderLayer, so the
    encoder-decoder attention weights of every decoder layer are available as
    ``self.decoder.attention_weights`` after a forward pass.

    Args:
        device: device given at construction time; stored as ``self.device``.
            Masks are created on the device of the input tensors, so the model
            keeps working after being moved with ``.to(...)``.
        src_vocab_size: number of rows in the source embedding table.
        src_pad_idx: index of the padding token in the source vocabulary.
        tgt_vocab_size: number of rows in the target embedding table and size
            of the output projection.
        tgt_pad_idx: index of the padding token in the target vocabulary.
        max_sequence_length: number of positions in the positional-encoding table.
        d_model, nhead, num_encoder_layers, num_decoder_layers,
        dim_feedforward, dropout, activation: forwarded to nn.Transformer and
            to the custom decoder layers.
    """

    def __init__(self, device: str, src_vocab_size: int, src_pad_idx: int, 
                 tgt_vocab_size: int, tgt_pad_idx: int, max_sequence_length: int = 100,
                 d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048,
                 dropout: float = 0.1, activation: str = "relu"):

        decoder_layer = CustomDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
        decoder_norm = nn.LayerNorm(d_model)
        decoder = CustomDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        # `activation` is forwarded as well so the encoder layers use the same
        # activation as the custom decoder layers.
        super().__init__(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers,
                         num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward,
                         dropout=dropout, activation=activation, custom_decoder=decoder)

        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        self.device = device

        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)

        self.pos_enc = PositionalEncoding(d_model, dropout, max_sequence_length)
        self.linear = nn.Linear(d_model, tgt_vocab_size)

    def init_weights(self):
        """Apply Xavier-uniform initialization to every submodule weight with more than one dimension."""
        def _init_weights(m):
            weight = getattr(m, 'weight', None)
            if weight is not None and weight.dim() > 1:
                nn.init.xavier_uniform_(weight.data)
        self.apply(_init_weights)

    def _make_key_padding_mask(self, t, pad_idx):
        """Return a boolean mask that is True where ``t`` equals ``pad_idx``.

        The comparison result already lives on the device of ``t``.
        """
        return t == pad_idx

    def prepare_src(self, src, src_pad_idx):
        """Embed a batch of source token ids.

        Args:
            src: integer tensor of shape (N, S).
            src_pad_idx: padding index used to build the key padding mask.

        Returns:
            ``(src, src_key_padding_mask)`` where ``src`` is the scaled,
            position-encoded embedding of shape (S, N, d_model) and the mask
            has shape (N, S).
        """
        src_key_padding_mask = self._make_key_padding_mask(src, src_pad_idx)
        src = rearrange(src, 'N S -> S N')
        src = self.pos_enc(self.src_emb(src) * math.sqrt(self.d_model))

        return src, src_key_padding_mask

    def prepare_tgt(self, tgt, tgt_pad_idx):
        """Embed a batch of target token ids and build the decoder masks.

        Args:
            tgt: integer tensor of shape (N, T).
            tgt_pad_idx: padding index used to build the key padding mask.

        Returns:
            ``(tgt, tgt_key_padding_mask, tgt_mask)`` where ``tgt`` is the
            scaled, position-encoded embedding of shape (T, N, d_model), the
            padding mask has shape (N, T), and ``tgt_mask`` is the square
            subsequent mask for ``T`` positions.
        """
        tgt_key_padding_mask = self._make_key_padding_mask(tgt, tgt_pad_idx)
        tgt = rearrange(tgt, 'N T -> T N')
        # Create the mask on the same device as the inputs.
        tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0]).to(tgt.device)
        tgt = self.pos_enc(self.tgt_emb(tgt) * math.sqrt(self.d_model))

        return tgt, tgt_key_padding_mask, tgt_mask

    def forward(self, src, tgt):
        """Compute target-vocabulary scores for every target position.

        Args:
            src: integer tensor of source token ids, shape (N, S).
            tgt: integer tensor of target token ids, shape (N, T).

        Returns:
            Tensor of shape (N, T, tgt_vocab_size) produced by the final linear layer.
        """
        src, src_key_padding_mask = self.prepare_src(src, self.src_pad_idx)

        tgt, tgt_key_padding_mask, tgt_mask = self.prepare_tgt(tgt, self.tgt_pad_idx)

        # The encoder output keeps the source length, so the source padding
        # mask is reused as the memory padding mask.
        output = super().forward(src, tgt, tgt_mask=tgt_mask,
                                 src_key_padding_mask=src_key_padding_mask,
                                 tgt_key_padding_mask=tgt_key_padding_mask,
                                 memory_key_padding_mask=src_key_padding_mask)

        output = rearrange(output, 'T N E -> N T E')

        return self.linear(output)

### Initialize the model

In [None]:
# Vocabulary indices of the padding token for the source and target fields.
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]

# Build the model with default hyperparameters and move it to `device`.
model = TranslationTransformer(device=device,
                               src_vocab_size=len(SRC.vocab), src_pad_idx=SRC_PAD_IDX,
                               tgt_vocab_size=len(TRG.vocab), tgt_pad_idx=TRG_PAD_IDX).to(device)
# Re-initialize every weight with more than one dimension (Xavier uniform).
model.init_weights()

#### Quick sanity check that forward pass of the model works

In [None]:
# Random token ids, moved to the same device as the model so the check also
# runs on CPU-only machines.
src = torch.randint(1, 100, (10, 5)).to(device)
tgt = torch.randint(1, 100, (10, 7)).to(device)

model.eval()
with torch.no_grad():
    output = model(src, tgt)

# Expected shape: (batch, target length, target vocabulary size)
print(output.shape)

###   Setup training parameters

In [None]:
# Step size for the Adam optimizer.
LEARNING_RATE = 0.0001

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Target positions equal to the padding index are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

### train and evaluate functions

In [None]:
def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch and return the mean loss per batch.

    Each batch must expose ``src`` and ``trg`` tensors with the batch axis
    first. The model receives the target without its last token and is scored
    against the target without its first token. Gradients are clipped to a
    total norm of ``clip`` before every optimizer step.
    """
    model.train()

    batch_losses = []

    for batch in iterator:
        source, target = batch.src, batch.trg

        optimizer.zero_grad()

        logits = model(source, target[:, :-1])
        vocab_size = logits.shape[-1]

        # Flatten batch and time so the criterion sees one row per token.
        flat_logits = logits.contiguous().view(-1, vocab_size)
        flat_target = target[:, 1:].contiguous().view(-1)

        batch_loss = criterion(flat_logits, flat_target)
        batch_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(iterator)

In [None]:
def evaluate(model, iterator, criterion):
    """Return the mean loss per batch over ``iterator`` without updating the model.

    The model is switched to eval mode and gradients are disabled. Inputs and
    targets are shifted exactly as in training: the model sees the target
    without its last token and is scored against the target without its first.
    """
    model.eval()

    batch_losses = []

    with torch.no_grad():
        for batch in iterator:
            source, target = batch.src, batch.trg

            logits = model(source, target[:, :-1])
            vocab_size = logits.shape[-1]

            # Flatten batch and time so the criterion sees one row per token.
            flat_logits = logits.contiguous().view(-1, vocab_size)
            flat_target = target[:, 1:].contiguous().view(-1)

            batch_losses.append(criterion(flat_logits, flat_target).item())

    return sum(batch_losses) / len(iterator)

### Do the training

In [None]:
def epoch_time(start_time, end_time):
    """Split the elapsed time between two timestamps (in seconds) into whole minutes and whole seconds."""
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover = duration - (whole_minutes * 60)
    return whole_minutes, int(leftover)

In [None]:
# File the model's state dict is saved to and loaded from.
BEST_MODEL_FILE = 'best_model.pytorch'

In [None]:
# Number of passes over the training data and the gradient-norm clipping threshold.
N_EPOCHS = 15
CLIP = 1

# Lowest validation loss seen so far; any real loss is smaller than infinity.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Checkpoint only when the validation loss improves.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), BEST_MODEL_FILE)
    
    # Perplexity (PPL) is reported as exp(loss).
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

In [None]:
# Restore the checkpoint with the lowest validation loss. map_location remaps
# the saved tensors onto the current device, so a checkpoint written on a GPU
# also loads on a CPU-only machine.
model.load_state_dict(torch.load(BEST_MODEL_FILE, map_location=device))

test_loss = evaluate(model, test_iterator, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

## Inference


In [None]:
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):
    """Greedily translate one sentence.

    Args:
        sentence: a raw German string (tokenized with tokenize_de) or an
            already tokenized list of strings. Tokens are lower-cased.
        src_field: source field; provides the vocabulary and the init, eos and
            pad tokens for the input.
        trg_field: target field; provides the vocabulary and the init, eos and
            pad tokens for the output.
        model: model exposing ``prepare_src``, ``prepare_tgt``, ``encoder``,
            ``decoder`` and ``linear``. It is switched to eval mode.
        device: device the index tensors are moved to.
        max_len: maximum number of tokens to generate.

    Returns:
        ``(translation, attention_weights)``: the generated target tokens
        (ending with the eos token only if it was predicted within ``max_len``
        steps) and ``model.decoder.attention_weights`` from the last decoding step.
    """
    model.eval()

    if isinstance(sentence, str):
        # Reuse the already loaded tokenizer instead of reloading the spaCy
        # model on every call.
        tokens = [token.lower() for token in tokenize_de(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    tokens = [src_field.init_token] + tokens + [src_field.eos_token]

    src_indexes = [src_field.vocab.stoi[token] for token in tokens]

    # Take the special-token indices from the fields that were passed in
    # rather than from notebook-level globals.
    src_pad_idx = src_field.vocab.stoi[src_field.pad_token]
    trg_pad_idx = trg_field.vocab.stoi[trg_field.pad_token]
    eos_idx = trg_field.vocab.stoi[trg_field.eos_token]

    src = torch.LongTensor(src_indexes).unsqueeze(0).to(device)

    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]

    with torch.no_grad():
        # The source is encoded once and reused at every decoding step.
        src, src_key_padding_mask = model.prepare_src(src, src_pad_idx)
        enc_src = model.encoder(src, src_key_padding_mask=src_key_padding_mask)

        for _ in range(max_len):

            tgt = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)

            tgt, tgt_key_padding_mask, tgt_mask = model.prepare_tgt(tgt, trg_pad_idx)

            output = model.decoder(tgt, enc_src, tgt_mask=tgt_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=src_key_padding_mask)

            output = rearrange(output, 'T N E -> N T E')
            output = model.linear(output)

            # Greedy choice: most likely token at the last target position.
            pred_token = output.argmax(2)[:, -1].item()

            trg_indexes.append(pred_token)

            if pred_token == eos_idx:
                break

    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    # Drop the leading init token.
    translation = trg_tokens[1:]

    return translation, model.decoder.attention_weights

In [None]:
def display_attention(sentence, translation, attention_weights):
    """Plot one attention heat map per entry of ``attention_weights``.

    Args:
        sentence: list of source tokens (without the <sos>/<eos> markers);
            used, lower-cased and wrapped in '<sos>' / '<eos>', as x-axis labels.
        translation: list of generated target tokens, used as y-axis labels.
        attention_weights: sequence of tensors, one per subplot; each is
            squeezed on its first axis and drawn as a matrix.
            Assumes each entry is 2-D after the squeeze (batch size 1) — TODO confirm.
    """
    n_attention = len(attention_weights)

    n_cols = 2
    # Ceiling of n_attention / n_cols.
    n_rows = n_attention // n_cols + n_attention % n_cols

    x_labels = ['<sos>'] + [t.lower() for t in sentence] + ['<eos>']

    fig = plt.figure(figsize=(15,25))

    for i, layer_attention in enumerate(attention_weights):

        ax = fig.add_subplot(n_rows, n_cols, i+1)

        attention = layer_attention.squeeze(0).cpu().detach().numpy()

        ax.matshow(attention, cmap='gist_yarg')

        ax.tick_params(labelsize=12)

        # Fix the tick positions first, then label them, so every label is
        # tied to an explicit column / row index instead of depending on the
        # ticks an automatic locator happens to generate.
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=45)
        ax.set_yticks(range(len(translation)))
        ax.set_yticklabels(translation)

    plt.show()
    plt.close()

### Validation data inference example

In [None]:
# Pick one validation example and show its tokenized source and reference target.
example_idx = 25

src = vars(valid_data.examples[example_idx])['src']
trg = vars(valid_data.examples[example_idx])['trg']

print(f'src = {src}')
print(f'trg = {trg}')

In [None]:
# Translate the selected example; `attention` holds the decoder attention weights.
translation, attention = translate_sentence(src, SRC, TRG, model, device)
print(f'translation = {translation}')

In [None]:
# Plot the attention weights returned for the translation above.
display_attention(src, translation, attention)

### Test data inference example

In [None]:
# Pick one test example and show its tokenized source and reference target.
example_idx = 10

src = vars(test_data.examples[example_idx])['src']
trg = vars(test_data.examples[example_idx])['trg']

print(f'src = {src}')
print(f'trg = {trg}')

In [None]:
# Translate the selected example; `attention` holds the decoder attention weights.
translation, attention = translate_sentence(src, SRC, TRG, model, device)
print(f'translation = {translation}')

In [None]:
# Plot the attention weights returned for the translation above.
display_attention(src, translation, attention)

### Calculate BLEU score

In [None]:
from torchtext.data.metrics import bleu_score

def calculate_bleu(data, src_field, trg_field, model, device, max_len = 50):
    """Translate every example in ``data`` and return the corpus BLEU score.

    Args:
        data: iterable of examples exposing ``src`` and ``trg`` token lists.
        src_field, trg_field, model, device, max_len: passed through to
            translate_sentence.

    Returns:
        The value returned by ``bleu_score`` for the predicted translations
        against the single reference of each example.
    """
    trgs = []
    pred_trgs = []

    for datum in data:

        src = vars(datum)['src']
        trg = vars(datum)['trg']

        pred_trg, _ = translate_sentence(src, src_field, trg_field, model, device, max_len)

        # Strip the trailing <eos> token only when it is actually there: a
        # translation cut off at max_len ends with a real token that must be kept.
        if pred_trg and pred_trg[-1] == trg_field.eos_token:
            pred_trg = pred_trg[:-1]

        pred_trgs.append(pred_trg)
        # bleu_score expects a list of references per candidate.
        trgs.append([trg])

    return bleu_score(pred_trgs, trgs)

In [None]:
# Store the result under its own name: assigning it to `bleu_score` would
# replace the imported torchtext function and break any later call to
# calculate_bleu.
test_bleu = calculate_bleu(test_data, SRC, TRG, model, device)

print(f'BLEU score = {test_bleu*100:.2f}')