This is an implementation of the Transformer from the original paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762). The task in this implementation is English-to-Dutch translation.

**Input embedding**

In [None]:
import torch
import torch.nn as nn
import math

In [None]:
class InputEmbedding(nn.Module):
    """Embedding lookup whose output is multiplied by ``sqrt(d_model)``.

    Args:
        d_model: Size of each embedding vector (the paper uses 512).
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Return the scaled embeddings for the token ids in ``x``."""
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(x)
        return looked_up * scale

**Positional encoding**

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table is computed once in ``__init__`` and stored as the buffer
    ``pos_encoding`` with shape ``(1, seq_len, d_model)``; it is not a
    parameter and is never trained.

    Args:
        d_model: Size of the last dimension of the inputs (even or odd).
        seq_len: Number of positions the table covers.
        dropout_probability: Dropout probability applied after the addition.
    """

    def __init__(self, d_model : int, seq_len : int, dropout_probability : float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        # To avoid overfitting
        self.dropout = nn.Dropout(dropout_probability)

        pos_encoding = torch.zeros(seq_len, d_model)  # (seq_len, d_model)
        positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # One frequency per even column: shape (ceil(d_model / 2),)
        division_term = torch.exp(torch.arange(0, d_model, 2).float() *
                                  (-math.log(10000.0) / d_model))
        angles = positions * division_term  # (seq_len, ceil(d_model / 2))

        # Sine for even indices
        pos_encoding[:, 0::2] = torch.sin(angles)
        # Cosine for odd indices. With an odd d_model there is one odd column
        # fewer than even columns, so the last frequency is dropped here.
        pos_encoding[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Add batch dimension
        pos_encoding = pos_encoding.unsqueeze(0)
        # A buffer follows the module across devices and is saved with it,
        # but it is not returned by parameters().
        self.register_buffer('pos_encoding', pos_encoding)

    def forward(self, x):
        """Return ``dropout(x + pos_encoding[:, :x.shape[1]])``.

        ``x`` is indexed as ``(batch, length, d_model)``.

        Raises:
            ValueError: If ``x`` is longer than the ``seq_len`` the table was
                built for.
        """
        length = x.shape[1]
        if length > self.seq_len:
            raise ValueError(
                f"Input length {length} exceeds the positional table size {self.seq_len}")
        # Buffers never require gradients, so no requires_grad_ call is needed.
        x = x + self.pos_encoding[:, :length, :]
        return self.dropout(x)
        

**Layer normalization**

In [None]:
class LayerNormalization(nn.Module):
    """Normalise the last dimension, then apply one scalar gain and one scalar shift.

    Args:
        eps: Constant added to the standard deviation so the division stays
            finite when the standard deviation is zero or very small.
    """

    def __init__(self, eps : float = 10**-6):
        super().__init__()
        self.eps = eps
        # Learnable additive shift (a single scalar)
        self.bias = nn.Parameter(torch.zeros(1))
        # Learnable multiplicative gain (a single scalar)
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` over the last axis."""
        # keepdim=True keeps a trailing axis of size 1 so both statistics
        # broadcast against x.
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        centered = x - mean
        denominator = std + self.eps
        return self.alpha * centered / denominator + self.bias

**Feed forward**

In [None]:
class FeedForward(nn.Module):
    """Two linear layers with ReLU and dropout in between.

    Args:
        d_model: Input and output feature size.
        d_inner_later: Hidden feature size (2048 in the paper).
        dropout_prob: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model : int, d_inner_later : int, dropout_prob : float):
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)
        self.lin_layer_1 = nn.Linear(d_model, d_inner_later)
        self.lin_layer_2 = nn.Linear(d_inner_later, d_model)

    def forward(self, x):
        """Map ``(..., d_model)`` to ``(..., d_inner_later)`` and back to ``(..., d_model)``."""
        hidden = torch.relu(self.lin_layer_1(x))
        hidden = self.dropout(hidden)
        return self.lin_layer_2(hidden)

**Multi-Head Attention**

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of heads; must divide ``d_model``.
        dropout_prob: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model : int, num_heads : int, dropout_prob : float):
        super().__init__()
        # Raised explicitly (not asserted) so the check survives `python -O`.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout_prob)

        self.d_k = d_model // num_heads # Dimension of vector in each head
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    @staticmethod
    def attention(Q_k, K_k, V_k, mask=None, dropout: nn.Dropout = None):
        """Return ``(softmax(Q K^T / sqrt(d_k)) V, attention_weights)``.

        Positions where ``mask == 0`` receive a large negative score before
        the softmax. ``dropout``, when given, is applied to the weights.
        """
        d_k = Q_k.shape[-1]
        # shape : (batch, h, seq_len, d_k) -> (batch, h, seq_len, seq_len)
        attention_scores = (Q_k @ K_k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # -1e9 is not representable in half precision; clamp the fill
            # value to the most negative finite number of the score dtype.
            fill_value = max(-1e9, torch.finfo(attention_scores.dtype).min)
            attention_scores.masked_fill_(mask == 0, fill_value)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # shape : (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
        return attention_scores @ V_k, attention_scores

    def _split_heads(self, projected):
        """Reshape ``(batch, seq_len, d_model)`` to ``(batch, num_heads, seq_len, d_k)``."""
        batch, length = projected.shape[0], projected.shape[1]
        return projected.view(batch, length, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` over ``K``/``V``.

        The attention weights of the last call are kept in
        ``self.attention_scores``.
        """
        # shape: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        query = self._split_heads(self.W_q(Q))
        key = self._split_heads(self.W_k(K))
        value = self._split_heads(self.W_v(V))

        x, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)

        # Concat
        # shape : (batch, num_heads, seq_len, d_k)
        # -> (batch, seq_len, num_heads, d_k)
        # -> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.d_model)

        # shape : (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        return self.W_o(x)

**Residuals**

In [None]:
class Residual(nn.Module):
    """Residual connection around a sub-layer, followed by layer normalisation.

    Computes ``LayerNorm(x + Dropout(sublayer(x)))``, the arrangement used in
    "Attention Is All You Need": dropout is applied to the sub-layer output
    before it is added to the sub-layer input, and the sum is normalised.

    Args:
        dropout_prob: Dropout probability applied to the sub-layer output.
    """

    def __init__(self, dropout_prob : float):
        super().__init__()
        self.dropout = nn.Dropout(dropout_prob)
        self.layer_norm = LayerNormalization()

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (a callable taking ``x``) inside the residual connection."""
        # Normalising the sum (rather than only the sub-layer output) keeps
        # the scale of the residual stream bounded from layer to layer.
        return self.layer_norm(x + self.dropout(sublayer(x)))

**Encoder Block**

In [None]:
# The encoder is a stack of these blocks
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention, then feed-forward, each wrapped in a ``Residual``.

    Args:
        self_attention_block: Attention module called with the input as
            query, key and value.
        feed_forward_block: Feed-forward module applied after attention.
        dropout_prob: Dropout probability passed to both ``Residual`` wrappers.
    """

    def __init__(self, self_attention_block: MultiHeadAttention, 
                 feed_forward_block: FeedForward, dropout_prob : float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residuals = nn.ModuleList([Residual(dropout_prob) for _ in range(2)])

    def forward(self, x, src_mask):
        """Run both sub-layers on ``x``; ``src_mask`` is forwarded to the attention."""

        def self_attend(inputs):
            return self.self_attention_block(inputs, inputs, inputs, src_mask)

        attention_residual, feed_forward_residual = self.residuals
        attended = attention_residual(x, self_attend)
        return feed_forward_residual(attended, self.feed_forward_block)

**Encoder**

In [None]:
class Encoder(nn.Module):
    """Apply a list of encoder layers one after another.

    Args:
        layers: Modules called as ``layer(x, mask)``, in order.
    """

    def __init__(self, layers : nn.ModuleList):
        super().__init__()
        self.layers = layers

    def forward(self, x, mask):
        """Feed ``x`` through every layer, passing the same ``mask`` to each."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden

**Decoder Block**

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped in its own ``Residual``.

    Args:
        self_attention_block: Attention over the decoder input itself.
        cross_attention_block: Attention whose keys and values are the
            encoder output.
        feed_forward_block: Feed-forward module applied last.
        dropout_prob: Dropout probability passed to the ``Residual`` wrappers.
    """

    def __init__(self, self_attention_block: MultiHeadAttention, 
                 cross_attention_block: MultiHeadAttention, 
                 feed_forward_block: FeedForward, dropout_prob : float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residuals = nn.ModuleList([Residual(dropout_prob) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Run the three sub-layers.

        ``tgt_mask`` is used by the self-attention and ``src_mask`` by the
        cross-attention.
        """

        def self_attend(inputs):
            return self.self_attention_block(inputs, inputs, inputs, tgt_mask)

        def cross_attend(inputs):
            return self.cross_attention_block(inputs, encoder_output,
                                              encoder_output, src_mask)

        self_residual, cross_residual, feed_forward_residual = self.residuals
        hidden = self_residual(x, self_attend)
        hidden = cross_residual(hidden, cross_attend)
        return feed_forward_residual(hidden, self.feed_forward_block)

**Decoder**

In [None]:
class Decoder(nn.Module):
    """Apply a list of decoder layers one after another.

    Args:
        layers: Modules called as
            ``layer(x, encoder_output, src_mask, tgt_mask)``, in order.
    """

    def __init__(self, layers : nn.ModuleList):
        super().__init__()
        self.layers = layers

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Feed ``x`` through every layer with the same encoder output and masks."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return hidden

**Linear Classifier**

In [None]:
class LinearClassifier(nn.Module):
    """Project decoder features onto the vocabulary and return log-probabilities.

    Args:
        d_model: Size of the incoming feature vectors.
        vocab_size: Number of output classes.
    """

    def __init__(self, d_model : int, vocab_size : int):
        super().__init__()
        self.lin_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return ``log_softmax`` over the last (vocabulary) axis.

        The axis is given explicitly: without ``dim`` PyTorch guesses it from
        the rank and picks axis 0 (the batch) for a 3-D input. Log-probabilities
        are returned instead of probabilities because cross-entropy applies
        log-softmax itself; doing so on log-probabilities leaves them unchanged,
        while doing so on probabilities flattens the distribution. The arg-max
        is the same either way.
        """
        logits = self.lin_layer(x)
        return torch.log_softmax(logits, dim=-1)

**Transformer**

In [None]:
class Transformer(nn.Module):
    """Container wiring embeddings, positional encodings, encoder, decoder and output layer.

    There is no ``forward``; callers use ``encode``, ``decode`` and
    ``project`` as separate steps.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, 
                 src_embedding: InputEmbedding, tgt_embedding: InputEmbedding,
                 src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, 
                 lin_classifier: LinearClassifier):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embedding = src_embedding
        self.tgt_embedding = tgt_embedding
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.lin_classifier = lin_classifier

    def encode(self, src, src_mask):
        """Embed ``src``, add positions and run the encoder."""
        embedded = self.src_pos(self.src_embedding(src))
        return self.encoder(embedded, src_mask)

    def decode(self, tgt, encoder_output, tgt_mask, src_mask):
        """Embed ``tgt``, add positions and run the decoder.

        Note the argument order: ``tgt_mask`` comes before ``src_mask`` here,
        while the wrapped decoder takes them the other way round.
        """
        embedded = self.tgt_pos(self.tgt_embedding(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Apply the output classifier to ``x``."""
        return self.lin_classifier(x)

In [None]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int,
                      d_model: int = 512, n_layers: int =  6, num_heads: int = 8,
                      dropout_prob: float = 0.1, d_inner_later_ff: int = 2048) -> Transformer:
    """Assemble a ``Transformer`` and Xavier-initialise its weight matrices.

    Args:
        src_vocab_size: Source vocabulary size.
        tgt_vocab_size: Target vocabulary size (also the classifier output size).
        src_seq_len: Length of the source positional-encoding table.
        tgt_seq_len: Length of the target positional-encoding table.
        d_model: Model feature size.
        n_layers: Number of encoder layers and of decoder layers.
        num_heads: Attention heads per attention module.
        dropout_prob: Dropout probability used throughout.
        d_inner_later_ff: Hidden size of the feed-forward sub-layers.

    Returns:
        The initialised model.
    """
    src_embedding = InputEmbedding(d_model, src_vocab_size)
    tgt_embedding = InputEmbedding(d_model, tgt_vocab_size)

    src_pos = PositionalEncoding(d_model, src_seq_len, dropout_prob)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout_prob)

    # The sub-modules are created inside the comprehensions so every layer
    # owns its own weights. Creating them once and reusing the instances
    # would tie the parameters of all layers together.
    encoder_layers = nn.ModuleList([
        EncoderBlock(
            MultiHeadAttention(d_model, num_heads, dropout_prob),
            FeedForward(d_model, d_inner_later_ff, dropout_prob),
            dropout_prob)
        for _ in range(n_layers)])

    decoder_layers = nn.ModuleList([
        DecoderBlock(
            MultiHeadAttention(d_model, num_heads, dropout_prob),
            MultiHeadAttention(d_model, num_heads, dropout_prob),
            FeedForward(d_model, d_inner_later_ff, dropout_prob),
            dropout_prob)
        for _ in range(n_layers)])

    encoder = Encoder(encoder_layers)
    decoder = Decoder(decoder_layers)

    lin_classifier = LinearClassifier(d_model, tgt_vocab_size)

    t = Transformer(encoder, decoder, src_embedding, tgt_embedding, src_pos, tgt_pos, lin_classifier)

    # Initialize starting parameters (only tensors with more than one
    # dimension, i.e. weight matrices; biases and scalars keep their defaults)
    for p in t.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return t

def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the model from ``config``.

    ``config["seq_len"]`` is used for both the source and the target
    sequence length and ``config["d_model"]`` for the feature size; every
    other hyper-parameter keeps the default of ``build_transformer``.
    """
    return build_transformer(
        src_vocab_size=vocab_src_len,
        tgt_vocab_size=vocab_tgt_len,
        src_seq_len=config["seq_len"],
        tgt_seq_len=config['seq_len'],
        d_model=config['d_model'],
    )

**Tokenize**

In [None]:
import os
from torch.utils.data import Dataset, DataLoader, random_split

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

In [None]:
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` text of every item's ``'translation'`` entry in ``ds``."""
    yield from (example['translation'][lang] for example in ds)

def get_or_build_tokenizer(config, ds, lang):
    """Load the tokenizer for ``lang`` from disk, training and saving it first if needed.

    The file path is ``config['tokenizer_file']`` formatted with ``lang``.
    An existing file is loaded as is; ``ds`` is only read when the file is
    missing.
    """
    tokenizer_path = config['tokenizer_file'].format(lang)
    if os.path.exists(tokenizer_path):
        return Tokenizer.from_file(tokenizer_path)

    # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    special_tokens = ["[UNK]", "[PAD]", "[SOS]", "[EOS]"]
    trainer = WordLevelTrainer(special_tokens=special_tokens, min_frequency=2)
    sentences = get_all_sentences(ds, lang)
    tokenizer.train_from_iterator(sentences, trainer=trainer)
    tokenizer.save(tokenizer_path)
    return tokenizer

**Dataset**

In [None]:
def causal_mask(size):
    """Return a ``(1, size, size)`` boolean mask that is True on and below the diagonal.

    Row ``i`` is True for columns ``0..i`` and False for columns after ``i``.
    """
    strictly_upper = torch.triu(torch.ones((1, size, size)), diagonal=1)
    return strictly_upper.type(torch.int).eq(0)

class BilingualDataset(Dataset):
    """Turn raw translation pairs into fixed-length tensors and masks.

    Args:
        ds: Indexable collection whose items hold a ``'translation'`` dict
            keyed by language.
        tokenizer_src: Tokenizer used for the source text.
        tokenizer_tgt: Tokenizer used for the target text; the special token
            ids ([SOS], [EOS], [PAD]) are also taken from it.
        src_lang: Key of the source text inside ``'translation'``.
        tgt_lang: Key of the target text inside ``'translation'``.
        seq_len: Length every returned sequence is padded to.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # Special tokens, each stored as a 1-element int64 tensor
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the tensors for pair ``idx``.

        Returns:
            A dict with the raw texts, ``encoder_input`` ([SOS] tokens [EOS]
            padding), ``decoder_input`` ([SOS] tokens padding), ``target``
            (tokens [EOS] padding) and the two attention masks.

        Raises:
            ValueError: If either sentence does not fit into ``seq_len``
                together with its special tokens.
        """
        src_data = self.ds[idx]
        src_text = src_data['translation'][self.src_lang]
        tgt_text = src_data['translation'][self.tgt_lang]

        # Split sentence and then give every word its id number
        # input_tokens : list(int)
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Here -2 is for [SOS] and [EOS]
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        # Here -1 is for [SOS] (decoder input) or [EOS] (target)
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence too long")

        enc_ids = torch.tensor(enc_input_tokens, dtype=torch.int64)
        dec_ids = torch.tensor(dec_input_tokens, dtype=torch.int64)

        # Build the padding directly as a tensor instead of going through a
        # Python list holding one 1-element tensor per padded position.
        pad_id = int(self.pad_token)
        enc_padding = torch.full((enc_num_padding_tokens,), pad_id, dtype=torch.int64)
        dec_padding = torch.full((dec_num_padding_tokens,), pad_id, dtype=torch.int64)

        # Concat all the tokens
        encoder_input = torch.cat([self.sos_token, enc_ids, self.eos_token, enc_padding], dim=0)
        decoder_input = torch.cat([self.sos_token, dec_ids, dec_padding], dim=0)
        target = torch.cat([dec_ids, self.eos_token, dec_padding], dim=0)

        # size : (1, 1, seq_len); 1 where the position is not padding
        encoder_mask = (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()
        # size: (1, 1, seq_len) & (1, seq_len, seq_len) -> (1, seq_len, seq_len)
        # A position may be attended only if it is not padding and not in the future.
        decoder_mask = (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0))

        return {
            "src_text": src_text,
            "tgt_text": tgt_text,
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "target": target,  # (seq_len)
            "encoder_mask": encoder_mask, # (1, 1, seq_len)
            "decoder_mask": decoder_mask, # (1, seq_len, seq_len)
        }

In [None]:
def get_ds(config):
    """Load the corpus, build both tokenizers and return the data loaders.

    Reads ``lang_src``, ``lang_tgt``, ``limit_ds``, ``seq_len`` and
    ``batch_size`` from ``config``.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
        The validation loader uses a batch size of 1.
    """
    from itertools import islice

    ds_raw = load_dataset('opus_books', f"{config['lang_src']}-{config['lang_tgt']}", split='train')

    limit = config["limit_ds"]
    if limit is not None:
        # Take only the first `limit` rows instead of materialising the
        # whole dataset as a list and slicing it afterwards.
        ds_raw = list(islice(ds_raw, limit))

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Keep 90% for training, 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Find the maximum length of each sentence in the source and target sentence
    max_len_src = 0
    max_len_tgt = 0
    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))
    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    # The dataset adds [SOS] and [EOS] to the source and one special token to
    # the target; a longer pair raises when it is fetched, so say so up front.
    if max_len_src + 2 > config['seq_len'] or max_len_tgt + 1 > config['seq_len']:
        print(f"Warning: some sentences do not fit into seq_len={config['seq_len']} "
              "and will raise 'Sentence too long' when loaded")

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

**Configuration**

In [None]:
# Settings read by the data, model, validation and training functions of this notebook.
config = {
    "batch_size": 8,  # batch size of the training DataLoader
    "num_epochs": 2,  # number of passes over the training data
    "lr": 10**-4,  # learning rate given to Adam
    "seq_len": 750,  # padded length of every sequence; also the positional-encoding table size
    "d_model": 512,  # model feature size
    "lang_src": "en",  # source language key in the opus_books pairs
    "lang_tgt": "nl",  # target language key in the opus_books pairs
    "model_folder": "weights",  # folder component of the checkpoint path
    "model_basename": "my_first_run",  # checkpoint file name without the ".pt" suffix
    "tokenizer_file": "tokenizer_{0}.json",  # tokenizer path template; {0} is the language code
    "experiment_name": "runs/tmodel",  # TensorBoard log directory
    "limit_ds": 300, # keep only the first N rows; None uses the whole dataset
    "validation_num_examples": 2  # validation stops after printing this many examples
}

def get_weights_file_path(config):
    """Return the checkpoint path ``/kaggle/working/<model_folder>/<model_basename>.pt``."""
    folder = config["model_folder"]
    checkpoint_name = "{}.pt".format(config["model_basename"])
    return os.path.join("/kaggle", "working", folder, checkpoint_name)

**Validation**

In [None]:
def decode_completely(config, model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, device):
    """Greedily decode one sentence.

    Starting from [SOS], the most likely next token is appended until [EOS]
    is produced or the sequence reaches ``config["seq_len"]`` tokens. Only a
    batch of one sentence is supported.

    Args:
        config: Provides ``seq_len``, the maximum output length.
        model: Object exposing ``encode``, ``decode`` and ``project``.
        encoder_input: Source token ids for a single sentence.
        encoder_mask: Mask passed to the encoder and the cross-attention.
        tokenizer_src: Unused; kept so existing callers keep working.
        tokenizer_tgt: Tokenizer supplying the [SOS] and [EOS] ids.
        device: Device on which the decoder input is created.

    Returns:
        1-D tensor of generated token ids, starting with [SOS].
    """
    # The decoder works in the target vocabulary, so its special-token ids
    # must come from the target tokenizer.
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # The source is encoded once and reused at every step
    encoder_output = model.encode(encoder_input, encoder_mask)
    # Initialize the decoder input with the sos token
    decoder_input = torch.full((1, 1), sos_idx, dtype=encoder_input.dtype, device=device)

    while decoder_input.size(1) < config["seq_len"]:
        # Make a mask
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(encoder_mask).to(device)

        # Only the last position is projected: it predicts the next token
        model_out = model.decode(decoder_input, encoder_output, decoder_mask, encoder_mask)
        proj = model.project(model_out[:, -1])
        _, next_word_idx = torch.max(proj, dim=1)
        next_token = next_word_idx.item()

        next_column = torch.full((1, 1), next_token, dtype=encoder_input.dtype, device=device)
        decoder_input = torch.cat([decoder_input, next_column], dim=1)

        if next_token == eos_idx:
            break

    return decoder_input.squeeze(0)
    

def validate(config, model, val_dataloader, tokenizer_src, tokenizer_tgt, device):
    """Print source, reference and greedy prediction for a few validation items.

    The model is switched to eval mode and left in it. The loop stops after
    ``config["validation_num_examples"]`` batches; the first item of each
    batch is the one printed.
    """
    model.eval()

    with torch.no_grad():
        for count, batch in enumerate(val_dataloader, start=1):
            encoder_input = batch['encoder_input'].to(device) # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)

            model_out = decode_completely(config, model, encoder_input, encoder_mask, 
                                          tokenizer_src, tokenizer_tgt, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            predicted_ids = model_out.detach().cpu().numpy()
            predicted_text = tokenizer_tgt.decode(predicted_ids)

            # Log results
            print("-----------------------------")
            print("SOURCE:", source_text)
            print("TARGET:", target_text)
            print("PREDICTED:", predicted_text)

            if count == config["validation_num_examples"]:
                break

**Train**

In [None]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def train_model(config):
    """Train the model, printing a few validation examples and saving a checkpoint each epoch.

    Reads ``num_epochs``, ``lr`` and ``experiment_name`` from ``config``
    directly; the remaining keys are consumed by the helpers it calls. The
    checkpoint holds the model and optimizer state dicts and is overwritten
    after every epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Using device:", device)

    # Create the directory the checkpoint is actually written to, so that
    # saving at the end of the first epoch cannot fail on a missing folder.
    model_filename = get_weights_file_path(config)
    checkpoint_dir = os.path.dirname(model_filename)
    if checkpoint_dir and not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir, exist_ok=True)
        print("Created:", checkpoint_dir)
    else:
        print(checkpoint_dir)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    # The targets are target-vocabulary ids, so the padding id to ignore
    # must come from the target tokenizer.
    pad_idx = tokenizer_tgt.token_to_id("[PAD]")
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx).to(device)
    # TensorBoard
    writer = SummaryWriter(config['experiment_name'])

    global_step = 0
    try:
        for epoch in tqdm(range(config["num_epochs"])):
            # validate() leaves the model in eval mode, so switch back here
            model.train()
            for batch in tqdm(train_dataloader):
                encoder_input = batch['encoder_input'].to(device) # (B, seq_len)
                decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
                encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
                decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

                # size : (batch, seq_len, d_model)
                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(decoder_input, encoder_output, decoder_mask, encoder_mask)
                # size : (batch, seq_len, tgt_vocab_size)
                proj_output = model.project(decoder_output)

                # size : (batch, seq_len)
                target = batch['target'].to(device)

                loss = loss_fn(proj_output.view(-1, tgt_vocab_size), target.view(-1))

                # Log the loss against a per-batch step counter; using the
                # epoch number would put every batch of an epoch on the same
                # x-value.
                writer.add_scalar('train loss', loss.item(), global_step)
                writer.flush()

                # Backpropagation
                loss.backward()

                # Update the weights
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1

            # Validate
            validate(config, model, val_dataloader, tokenizer_src, tokenizer_tgt, device)

            # Save the model after every epoch
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, model_filename)
    finally:
        # Release the event file even if training is interrupted
        writer.close()

In [None]:
# Run training with the module-level `config` dictionary.
train_model(config)