# Machine Translation with Transformer from Scratch

In [1]:
import torch
import torchtext
import torch.nn as nn
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.data import functional
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm
import sys
import random
import spacy 
import math
import copy
import numpy as np

### Prepare Data

In [2]:
de_tokenizer = get_tokenizer('spacy', 'de_core_news_lg')
en_tokenizer = get_tokenizer('spacy', 'en_core_web_lg')

def de_yield_tokens(train_iter):
    for de_text, _ in train_iter:
        yield de_tokenizer(de_text[:-1].lower())
        
def en_yield_tokens(train_iter):
    for _, en_text in train_iter:
        yield en_tokenizer(en_text[:-1].lower())
        
special_tokens = ["<pad>", "<unk>", "<sos>", "<eos>"]

train_iter = Multi30k(split=("train"))
source_vocab = build_vocab_from_iterator(de_yield_tokens(train_iter), 
                                         min_freq = 1,
                                         specials = special_tokens)
source_vocab.set_default_index(source_vocab["<unk>"])

train_iter = Multi30k(split=("train"))
target_vocab = build_vocab_from_iterator(en_yield_tokens(train_iter),
                                        min_freq = 1,
                                        specials = special_tokens)
target_vocab.set_default_index(target_vocab["<unk>"])

In [3]:
# Load data
train_iter, valid_iter, test_iter = Multi30k()

en_text_pipeline = lambda x:  target_vocab(["<sos>"] + en_tokenizer(x)[:-1] + ["<eos>"]) 
de_text_pipeline = lambda x: source_vocab(["<sos>"] + de_tokenizer(x)[:-1] + ["<eos>"])

BATCH_SIZE = 100

def collate_batch_input(batch):
    source_list, target_list = [], []
    for source_text, target_text in batch:  
        text_seq = de_text_pipeline(source_text.lower()) 
        source_list.append(torch.tensor(text_seq, dtype=torch.int64))
        text_seq = en_text_pipeline(target_text.lower())
        target_list.append(torch.tensor(text_seq[:-1], dtype=torch.int64))
               
    source_tensor = pad_sequence(source_list, batch_first=True, padding_value=source_vocab["<pad>"])
    target_tensor = pad_sequence(target_list, batch_first=True, padding_value=target_vocab["<pad>"])
    
    return source_tensor, target_tensor, 


train_dataset = functional.to_map_style_dataset(train_iter)
valid_dataset = functional.to_map_style_dataset(valid_iter)
test_dataset = functional.to_map_style_dataset(test_iter)

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE,
                          shuffle = True, collate_fn = collate_batch_input)
valid_loader = DataLoader(valid_dataset, batch_size = BATCH_SIZE,
                          shuffle=False, collate_fn = collate_batch_input)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE,
                          shuffle=False, collate_fn = collate_batch_input)

## Model Architecture

In [74]:
class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many 
    other models.
    """
    def __init__(self, encoder, decoder, src_embed, trg_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.generator = generator
        
    def forward(self, src, trg, src_mask, trg_mask):
        "Take in and process masked src and target sequences."
        dec_out = self.decode(trg, self.encode(src, src_mask), trg_mask, src_mask)
        return self.generator(dec_out)
    
    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)
    
    def decode(self, trg, enc_src, trg_mask, src_mask):
        return self.decoder(self.trg_embed(trg), enc_src, trg_mask, src_mask)


<img src="img/transformer1.png"/>

### Embeddings

In [75]:
class Embeddings(nn.Module):
    """ Combining word embedding and positional embedding"""
    def __init__(self, vocab, d_model, pad_index):
        super(Embeddings, self).__init__()
        self.embedding = nn.Embedding(vocab, d_model, padding_idx = pad_index)
        self.d_model = d_model

    def forward(self, x):
        e = self.embedding(x)
        pos_emb = torch.zeros(e.size(0), e.size(1), e.size(2))
        position = torch.arange(0, x.size(1)).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2) *
                             -(math.log(10000.0) / self.d_model))
        pos_emb[:, :, 0::2] = torch.sin(position * div_term)
        pos_emb[:, :, 1::2] = torch.cos(position * div_term)
        return  e + pos_emb

         

In [76]:
class Generator(nn.Module):
    "Define standard linear + softmax generation step."
    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)

## Encoder and Decoder Stacks

### Encoder

In [77]:
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [78]:
class Encoder(nn.Module):
    """
    Stack of N EncoderLayers
    """
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        
    def forward(self, src, src_mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            src = layer(src, src_mask)
        return src

In [79]:
class EncoderBlock(nn.Module):
    """
    Encoder is made up of self-attn and feed forward.
    """
    def __init__(self, size, self_attn, feed_forward):
        super(EncoderBlock, self).__init__()
        self.layer_norm_1 = nn.LayerNorm(size)
        self.layer_norm_2 = nn.LayerNorm(size)
        self.self_attn = self_attn
        self.feedforward = feed_forward
        
    def forward(self, src, src_mask):
        out = self.layer_norm_1(src + self.self_attn(src, src, src, src_mask))
        return self.layer_norm_2(out + self.feedforward(out))

### Attention

In [80]:
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model):
        super(MultiHeadAttention, self).__init__()
        self.Q_linear = nn.Linear(d_model, d_model)
        self.K_linear = nn.Linear(d_model, d_model)
        self.V_linear = nn.Linear(d_model, d_model)
        self.lin_out = nn.Linear(d_model, d_model)
        self.h = h
        self.d_model = d_model
        self.d_k = d_model // h
        
    def forward(self, query, key, value, mask=None):  
        Q = self.Q_linear(query)
        K = self.Q_linear(key)
        V = self.Q_linear(value)
        
        n_batches = Q.size(0)
    
        # Splitting in attention heads
        Q = Q.view(n_batches, -1, self.h, self.d_k)
        K = K.view(n_batches, -1, self.h, self.d_k)
        V = V.view(n_batches, -1, self.h, self.d_k)
        Q = Q.transpose(1,2)
        K = K.transpose(1,2)
        V = V.transpose(1,2)
        x = self.compute_attention(Q, K, V, mask=mask)
        # concatenating attention heads
        x = x.transpose(1,2).contiguous().view(n_batches, -1, self.h*self.d_k)
        return self.lin_out(x)
        
    @staticmethod
    def compute_attention(query, key, value, mask):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2,-1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask==0, -np.inf)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, value)

https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853

In [81]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x):
        return self.w_2(F.relu(self.w_1(x)))

### Decoder

In [82]:
class Decoder(nn.Module):
    """
    Stack of N DecoderBlocks
    """
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            trg = layer(trg, enc_src, trg_mask, src_mask)
        return trg

In [83]:
class DecoderBlock(nn.Module):
    """
    Decoder is made up of self-attn, src_attn and feed forward.
    """
    def __init__(self, size, self_attn, src_attn, feed_forward):
        super(DecoderBlock, self).__init__()
        self.layer_norm_1 = nn.LayerNorm(size)
        self.layer_norm_2 = nn.LayerNorm(size)
        self.layer_norm_3 = nn.LayerNorm(size)
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feedforward = feed_forward
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        out = self.layer_norm_1(trg + self.self_attn(trg, trg, trg, trg_mask))
        out = self.layer_norm_2(out + self.src_attn(trg, enc_src, enc_src, src_mask))
        return self.layer_norm_3(out + self.feedforward(out))

In [84]:
def make_model(d_model, h, src_vocab_size, trg_vocab_size, src_pad_idx, 
               trg_pad_idx, n_layers):
    src_embedding = Embeddings(src_vocab_size, d_model, src_pad_index)
    trg_embedding = Embeddings(trg_vocab_size, d_model, trg_pad_idx)
    generator = Generator(d_model, len(target_vocab))
    attention = MultiHeadAttention(h, d_model)
    ff = PositionwiseFeedForward(64, 128)
    enc_block = EncoderBlock(d_model, attention, ff)
    encoder = Encoder(enc_block, n_layers)
    dec_block = DecoderBlock(d_model, attention, attention, ff)
    decoder = Decoder(dec_block, n_layers)
    transformer = EncoderDecoder(encoder, decoder, src_embedding, trg_embedding, generator)
    return transformer

def make_src_mask(src_pad_index, src):        
        #src = [batch size, src len]
        
        src_mask = (src != src_pad_index).unsqueeze(1).unsqueeze(2)
        #src_mask = [batch size, 1, 1, src len]

        return src_mask
    
def make_trg_mask(trg_pad_index, trg):        
        #trg = [batch size, trg len]
        
        trg_pad_mask = (trg != trg_pad_index).unsqueeze(1).unsqueeze(2)       
        #trg_pad_mask = [batch size, 1, 1, trg len]
        
        trg_len = trg.shape[1]      
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).bool()     
        #trg_sub_mask = [trg len, trg len]
            
        trg_mask = trg_pad_mask & trg_sub_mask    
        #trg_mask = [batch size, 1, trg len, trg len]

In [85]:
a = iter(train_loader)
a = next(a)
d_model = 64
h = 4
src_vocab_size = len(source_vocab)
trg_vocab_size = len(target_vocab)
src_pad_index = source_vocab["<pad>"]
trg_pad_index = target_vocab["<pad>"]
n_layers = 3

transformer = make_model(d_model, h, src_vocab_size, trg_vocab_size, 
                         src_pad_index, trg_pad_index, n_layers)

In [86]:
def count_parameters(transformer):
    return sum(p.numel() for p in transformer.parameters() if p.requires_grad)

print(f'The model has {count_parameters(transformer):,} trainable parameters')

The model has 2,659,587 trainable parameters


In [87]:
src = a[0]
trg = a[1]
src_mask = make_src_mask(src_pad_index, src)
trg_mask = make_trg_mask(trg_pad_index, trg)

### Train Model

In [88]:
LEARNING_RATE = 0.0005

optimizer = torch.optim.Adam(transformer.parameters(), lr = LEARNING_RATE)

In [89]:
criterion = nn.CrossEntropyLoss(ignore_index = trg_pad_index)

In [90]:
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    
    for i, (src, trg) in enumerate(iterator):
        
        src_mask = make_src_mask(src_pad_index, src)
        trg_mask = make_trg_mask(trg_pad_index, trg[:,:-1])
    
        optimizer.zero_grad()
        
        output = model(src, trg[:,:-1], src_mask, trg_mask)
                
        #output = [batch size, trg len - 1, output dim]
        #trg = [batch size, trg len]
            
        output_dim = output.shape[-1]
            
        output = output.contiguous().view(-1, output_dim)
        trg = trg[:,1:].contiguous().view(-1)
                
        #output = [batch size * trg len - 1, output dim]
        #trg = [batch size * trg len - 1]
            
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [98]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, (src, trg) in enumerate(iterator):

            src_mask = make_src_mask(src_pad_index, src)
            trg_mask = make_trg_mask(trg_pad_index, trg[:,:-1])

            output = model(src, trg[:,:-1], src_mask, trg_mask)
            
            #output = [batch size, trg len - 1, output dim]
            #trg = [batch size, trg len]
            
            output_dim = output.shape[-1]
            
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)
            
            #output = [batch size * trg len - 1, output dim]
            #trg = [batch size * trg len - 1]
            
            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [99]:
import time

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [102]:
N_EPOCHS = 10
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(transformer, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(transformer, valid_loader, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(transformer.state_dict(), 'tut6-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Epoch: 01 | Time: 8m 6s
	Train Loss: 3.017 | Train PPL:  20.434
	 Val. Loss: 2.875 |  Val. PPL:  17.728
Epoch: 02 | Time: 8m 1s
	Train Loss: 2.654 | Train PPL:  14.214
	 Val. Loss: 2.571 |  Val. PPL:  13.085
Epoch: 03 | Time: 8m 53s
	Train Loss: 2.327 | Train PPL:  10.249
	 Val. Loss: 2.289 |  Val. PPL:   9.868
Epoch: 04 | Time: 8m 4s
	Train Loss: 2.016 | Train PPL:   7.510
	 Val. Loss: 2.014 |  Val. PPL:   7.493
Epoch: 05 | Time: 8m 33s
	Train Loss: 1.722 | Train PPL:   5.598
	 Val. Loss: 1.761 |  Val. PPL:   5.819
Epoch: 06 | Time: 8m 44s
	Train Loss: 1.461 | Train PPL:   4.312
	 Val. Loss: 1.552 |  Val. PPL:   4.722
Epoch: 07 | Time: 8m 34s
	Train Loss: 1.236 | Train PPL:   3.441
	 Val. Loss: 1.376 |  Val. PPL:   3.958
Epoch: 08 | Time: 9m 12s
	Train Loss: 1.043 | Train PPL:   2.838
	 Val. Loss: 1.210 |  Val. PPL:   3.352
Epoch: 09 | Time: 8m 50s
	Train Loss: 0.880 | Train PPL:   2.411
	 Val. Loss: 1.096 |  Val. PPL:   2.991
Epoch: 10 | Time: 9m 9s
	Train Loss: 0.744 | Train PPL:   

#### Reference
1. https://github.com/bentrevett
2. https://jalammar.github.io/illustrated-transformer/
3. https://medium.com/analytics-vidhya/masking-in-transformers-self-attention-mechanism-bad3c9ec235c