# Neural Machine Translation Using Transformers

In [None]:
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader

import numpy as np

## Gathering Data and Preparing the Dataset

In [None]:
import torchtext
from torch.nn.utils.rnn import pad_sequence
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
from torchtext.utils import download_from_url, extract_archive
import io

In [None]:
# Base location of the raw Multi30k task-1 archives and the per-split file
# names. In every pair the German (.de) file comes first, the English (.en)
# file second.
url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
train_urls = ('train.de.gz', 'train.en.gz')
val_urls = ('val.de.gz', 'val.en.gz')
test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')


def _fetch_and_extract(file_names):
    """Download every archive in *file_names* and return one extracted path per archive."""
    local_paths = []
    for file_name in file_names:
        archive_path = download_from_url(url_base + file_name)
        # extract_archive returns a list of paths; only the first one is used.
        local_paths.append(extract_archive(archive_path)[0])
    return local_paths


train_filepaths = _fetch_and_extract(train_urls)
val_filepaths = _fetch_and_extract(val_urls)
test_filepaths = _fetch_and_extract(test_urls)

In [None]:
# spaCy pipelines used for tokenization, keyed by language code.
_SPACY_MODELS = {'de': 'de_core_news_sm', 'en': 'en_core_web_sm'}

de_tokenizer = get_tokenizer('spacy', language=_SPACY_MODELS['de'])
en_tokenizer = get_tokenizer('spacy', language=_SPACY_MODELS['en'])

In [None]:
def build_vocab(filepaths, tokenizer):
    """Build a torchtext vocabulary from every line of the given text files.

    Args:
        filepaths: iterable of paths to UTF-8 text files, read line by line.
        tokenizer: callable mapping a line (str) to a list of string tokens.

    Returns:
        A vocabulary containing every token seen plus the special tokens
        '<unk>', '<pad>', '<bos>' and '<eos>'. Tokens that are not in the
        vocabulary are looked up as '<unk>' instead of raising.
    """
    specials = ['<unk>', '<pad>', '<bos>', '<eos>']

    counter = Counter()
    for filepath in filepaths:
        with io.open(filepath, encoding="utf8") as f:
            for line in f:
                counter.update(tokenizer(line))

    built = vocab(counter, specials=specials)
    # Without a default index, indexing the vocabulary with an unseen token
    # raises; route such tokens to '<unk>' so arbitrary input can be encoded.
    built.set_default_index(built['<unk>'])
    return built

# Each *_filepaths list holds the German file at index 0 and the English file
# at index 1; one vocabulary per language is built over all three splits.
# NOTE(review): this includes the validation and test files in the vocabulary.
_all_split_paths = (train_filepaths, val_filepaths, test_filepaths)

de_vocab = build_vocab([paths[0] for paths in _all_split_paths], de_tokenizer)
en_vocab = build_vocab([paths[1] for paths in _all_split_paths], en_tokenizer)

In [None]:
def data_process(filepaths):
    """Convert a parallel German/English file pair into token-id tensors.

    Args:
        filepaths: sequence whose first element is the German file path and
            whose second element is the English file path. The files are read
            in lockstep, one line from each per example.

    Returns:
        A list of ``(de_tensor, en_tensor)`` pairs of ``torch.long`` tensors
        holding the vocabulary ids of the tokens on each line.
    """
    data = []
    # Open both files in a context manager so the handles are closed even if
    # tokenization or a vocabulary lookup raises.
    with io.open(filepaths[0], encoding="utf8") as de_file, \
            io.open(filepaths[1], encoding="utf8") as en_file:
        for raw_de, raw_en in zip(de_file, en_file):
            de_tensor_ = torch.tensor([de_vocab[token] for token in de_tokenizer(raw_de)], dtype=torch.long)
            en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en)], dtype=torch.long)
            data.append((de_tensor_, en_tensor_))
    return data

# Tensorize the three splits in the same order as before: train, val, test.
train_data, val_data, test_data = (
    data_process(split_paths)
    for split_paths in (train_filepaths, val_filepaths, test_filepaths)
)

### Creating Dataloaders

In [None]:
BATCH_SIZE = 128

# Ids of the special tokens, looked up in the German vocabulary.
PAD_IDX, BOS_IDX, EOS_IDX = (de_vocab[special] for special in ('<pad>', '<bos>', '<eos>'))

def generate_batch(data_batch):
    """Collate ``(de_item, en_item)`` pairs into two padded, batch-first tensors.

    Every sequence is wrapped as ``[BOS_IDX, *tokens, EOS_IDX]`` and then
    right-padded with ``PAD_IDX`` to the longest sequence of its language.

    Returns:
        ``(de_batch, en_batch)``, each of shape (batch, longest_sequence).
    """
    bos = torch.tensor([BOS_IDX])
    eos = torch.tensor([EOS_IDX])

    def _wrap(sequence):
        return torch.cat((bos, sequence, eos), dim=0)

    de_sequences, en_sequences = [], []
    for de_item, en_item in data_batch:
        de_sequences.append(_wrap(de_item))
        en_sequences.append(_wrap(en_item))

    padded_de = pad_sequence(de_sequences, batch_first=True, padding_value=PAD_IDX)
    padded_en = pad_sequence(en_sequences, batch_first=True, padding_value=PAD_IDX)
    return padded_de, padded_en

# Training batches are reshuffled every epoch; the incomplete final batch is
# dropped so that every training step sees BATCH_SIZE examples.
training_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch, drop_last=True)

# Evaluation must cover every example exactly once: no shuffling, and the
# (possibly smaller) final batch is kept instead of being silently discarded.
validation_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=generate_batch, drop_last=False)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=generate_batch, drop_last=False)

## Getting the available device

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

## Defining Various Components of a Transformer

### Positional Encoding

In [None]:
class PositionalEmbedding(nn.Module):
    """Token embedding scaled by sqrt(d_model) plus fixed sinusoidal position encodings.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: width of each embedding vector; must be even because half of
            the channels carry sines and the other half cosines.
        max_length: number of positions for which encodings are pre-computed.
            Inputs longer than this cannot be encoded.
    """

    def __init__(self, vocab_size:int, d_model:int, max_length:int = 2048) -> None:
        super(PositionalEmbedding, self).__init__()
        self.d_model = d_model
        # NOTE(review): padding_idx=0 pins the embedding of token id 0 to zero.
        # The vocabularies in this file are built with specials
        # ['<unk>', '<pad>', ...], i.e. id 0 is '<unk>' and '<pad>' is 1 --
        # confirm which id is meant here.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model, padding_idx=0)
        # Registered as a buffer so that .to(device) / .cuda() / .half() move
        # it together with the parameters. persistent=False keeps it out of
        # state_dict, so the checkpoint keys are the same as with a plain
        # attribute.
        self.register_buffer(
            "pos_encoding",
            self.positional_encoding(length=max_length, depth=d_model),
            persistent=False,
        )


    def forward(self, x:Tensor) -> Tensor:
        """Embed token ids ``x`` (indexed as (batch, seq_length)) and add position encodings."""
        length = x.size(dim=1)
        x = self.embedding(x)
        # Scaling (out of place, so the embedding output is not mutated)
        x = x * (self.d_model ** 0.5)
        return x + self.pos_encoding[:length, :].unsqueeze(0)


    def positional_encoding(self, length:int, depth:int) -> Tensor:
        """Return a (length, depth) float32 tensor of sinusoidal encodings.

        The first ``depth / 2`` channels hold sines and the remaining ones
        cosines (concatenated, not interleaved).

        Raises:
            ValueError: if ``depth`` is odd; the two halves could not add up
                to ``depth`` channels.
        """
        if depth % 2 != 0:
            raise ValueError(f"depth must be even, got {depth}")
        half_depth = depth // 2

        positions = np.arange(length)[:, np.newaxis]                 # (length, 1)
        depths = np.arange(half_depth)[np.newaxis, :] / half_depth   # (1, depth/2)

        angle_rates = 1 / (10000**depths)
        angle_rads = positions * angle_rates                         # (length, depth/2)

        pos_encoding = np.concatenate([np.sin(angle_rads), np.cos(angle_rads)], axis=-1).astype("float32")

        return torch.from_numpy(pos_encoding)


### Multi-Headed Attention

#### Attention Head

In [None]:
class AttentionHead(nn.Module):
    """A single scaled dot-product attention head with its own Q/K/V projections.

    Args:
        dim_in: feature size of the incoming query/key/value tensors.
        dim_k: feature size of the projected queries and keys.
        dim_v: feature size of the projected values (and of the output).
    """

    def __init__(self, dim_in:int, dim_k:int, dim_v:int) -> None:
        super(AttentionHead, self).__init__()
        self.q = nn.Linear(in_features=dim_in, out_features=dim_k)
        self.k = nn.Linear(in_features=dim_in, out_features=dim_k)
        self.v = nn.Linear(in_features=dim_in, out_features=dim_v)


    def forward(self, query:Tensor, key:Tensor, value:Tensor, is_causal:bool) -> Tensor:
        """Project the inputs and attend; returns a (batch, query_len, dim_v) tensor."""
        return self.scaled_dot_product_attention(self.q(query), self.k(key), self.v(value), is_causal)


    def scaled_dot_product_attention(self, query:Tensor, key:Tensor, value:Tensor, is_causal:bool) -> Tensor:
        """Compute ``softmax(Q K^T / sqrt(d_k)) V`` on 3-D (batch, length, features) tensors.

        When ``is_causal`` is true, query position ``i`` may only attend to
        key positions ``j <= i``.
        """
        scale = query.size(-1) ** 0.5
        scores = query.bmm(key.transpose(dim0=1, dim1=2)) / scale

        if is_causal:
            # Blocked positions must be -inf *before* the softmax so that they
            # receive exactly zero weight. Multiplying the scores by 0 leaves
            # exp(0) = 1 and lets information from future tokens leak through.
            mask_size = (1, query.size(1), key.size(1))
            allowed = self.get_causal_mask(mask_size, device=scores.device).to(dtype=torch.bool)
            scores = scores.masked_fill(~allowed, float("-inf"))

        softmax_vals = F.softmax(scores, dim=-1)
        return softmax_vals.bmm(value)


    def get_causal_mask(self, mask_size:tuple, device=None) -> Tensor:
        """Return a float32 lower-triangular mask of shape ``mask_size``.

        Entry ``[..., i, j]`` is 1 where ``j <= i`` (attention allowed) and 0
        elsewhere. ``device`` optionally selects where the mask is created.
        """
        return torch.tril(torch.ones(mask_size, dtype=torch.float32, device=device))

#### Multi-Head Attention (with optional is_causal for causal self-attention)

In [None]:
class MultiHeadedAttention(nn.Module):
    """Run several attention heads in parallel and merge them with a linear layer.

    Args:
        num_heads: number of independent ``AttentionHead`` modules.
        dim_in: feature size of the inputs and of the returned tensor.
        dim_k: per-head query/key size.
        dim_v: per-head value size.
        is_causal: passed to every head on each call; when true the heads
            apply causal masking.
    """

    def __init__(self, num_heads:int, dim_in:int, dim_k:int, dim_v:int, is_causal:bool = False) -> None:
        super(MultiHeadedAttention, self).__init__()
        self.is_causal = is_causal
        self.heads = nn.ModuleList(
            [AttentionHead(dim_in, dim_k, dim_v) for _ in range(num_heads)]
        )
        # Each head outputs dim_v features, so the concatenation is
        # num_heads * dim_v wide (not num_heads * dim_k). The two only
        # coincide when dim_k == dim_v.
        self.linear = nn.Linear(num_heads * dim_v, dim_in)


    def forward(self, query:Tensor, key:Tensor, value:Tensor) -> Tensor:
        """Attend with every head, concatenate on the last axis and project back to ``dim_in``."""
        concat_layer = torch.cat([head(query, key, value, self.is_causal) for head in self.heads], dim=-1)
        return self.linear(concat_layer)

#### Position-wise Feedforward Networks

In [None]:
class Feedforward():
    """Factory for the position-wise feed-forward sub-network.

    Stores the two layer widths; ``get_feedforward`` builds a new network on
    every call.
    """

    def __init__(self, d_model:int = 512, d_ff:int = 2048) -> None:
        self.d_model, self.d_ff = d_model, d_ff


    def get_feedforward(self) -> nn.Module:
        """Return ``Linear(d_model, d_ff) -> ReLU -> Linear(d_ff, d_model)``."""
        expand = nn.Linear(self.d_model, self.d_ff)
        activation = nn.ReLU()
        contract = nn.Linear(self.d_ff, self.d_model)
        return nn.Sequential(expand, activation, contract)

#### Residual Connections

In [None]:
class Residual(nn.Module):
    """Wrap a sub-layer as ``LayerNorm(first_input + Dropout(sublayer(*inputs)))``.

    Args:
        sublayer: module that is called with all tensors passed to ``forward``.
        dimension: size of the last axis, normalised by the LayerNorm.
        dropout: dropout probability applied to the sub-layer output.
    """

    def __init__(self, sublayer:nn.Module, dimension:int, dropout:float = 0.1) -> None:
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dimension)
        self.dropout = nn.Dropout(dropout)


    def forward(self, *tensors:Tensor) -> Tensor:
        # The first tensor (the query, for attention sub-layers) is the one
        # carried along the skip connection.
        skip = tensors[0]
        transformed = self.sublayer(*tensors)
        transformed = self.dropout(transformed)
        return self.norm(skip + transformed)


### Transformer Encoder

#### Encoder Layer

In [None]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Both sub-layers are wrapped in ``Residual`` (skip connection, dropout and
    layer normalisation). The per-head key/value size is
    ``d_model // num_heads``, but at least 1.
    """

    def __init__(self, num_heads:int, d_model:int, d_ff:int, dropout:float = 0.1) -> None:
        super().__init__()
        head_dim = max(d_model // num_heads, 1)

        attention = MultiHeadedAttention(
            num_heads=num_heads, dim_in=d_model, dim_k=head_dim, dim_v=head_dim
        )
        self.global_self_attention = Residual(attention, dimension=d_model, dropout=dropout)

        feedforward_net = Feedforward(d_model=d_model, d_ff=d_ff).get_feedforward()
        self.feedforward = Residual(feedforward_net, dimension=d_model, dropout=dropout)


    def forward(self, x:Tensor) -> Tensor:
        # Self-attention: the same tensor serves as query, key and value.
        attended = self.global_self_attention(x, x, x)
        return self.feedforward(attended)

In [None]:
class TransformerEncoder(nn.Module):
    """Embedding with positional encoding, dropout, then ``num_layers`` encoder layers."""

    def __init__(self, num_layers:int, d_model:int, num_heads:int,
                 d_ff:int, vocab_size:int, dropout:float = 0.1) -> None:
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size, d_model)

        layer_stack = []
        for _ in range(num_layers):
            layer_stack.append(
                TransformerEncoderLayer(
                    num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout
                )
            )
        self.enc_layers = nn.ModuleList(layer_stack)

        self.dropout_layer = nn.Dropout(dropout)


    def forward(self, x:Tensor) -> Tensor:
        """Encode token ids of shape (batch_size, seq_length) into (batch_size, seq_length, d_model)."""
        hidden = self.pos_embedding(x)       # (batch_size, seq_length, d_model)
        hidden = self.dropout_layer(hidden)

        for encoder_layer in self.enc_layers:
            hidden = encoder_layer(hidden)

        return hidden                        # (batch_size, seq_length, d_model)

### Transformer Decoder

#### Decoder Layer

In [None]:
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: causal self-attention, cross-attention, feed-forward.

    Every sub-layer is wrapped in ``Residual``. The per-head key/value size is
    ``d_model // num_heads``, but at least 1.
    """

    def __init__(self, num_heads:int, d_model:int, d_ff:int, dropout:float = 0.1) -> None:
        super().__init__()

        head_dim = max(d_model // num_heads, 1)

        def wrap(sublayer:nn.Module) -> Residual:
            return Residual(sublayer, dimension=d_model, dropout=dropout)

        self.causal_self_attention = wrap(
            MultiHeadedAttention(num_heads=num_heads, dim_in=d_model,
                                 dim_k=head_dim, dim_v=head_dim, is_causal=True)
        )
        self.cross_attention = wrap(
            MultiHeadedAttention(num_heads=num_heads, dim_in=d_model,
                                 dim_k=head_dim, dim_v=head_dim)
        )
        self.feedforward = wrap(
            Feedforward(d_model=d_model, d_ff=d_ff).get_feedforward()
        )


    def forward(self, x:Tensor, context:Tensor) -> Tensor:
        # 1) the target attends to itself (causal flag set on the heads),
        # 2) the target queries the encoder output, 3) position-wise feed-forward.
        self_attended = self.causal_self_attention(x, x, x)
        cross_attended = self.cross_attention(self_attended, context, context)
        return self.feedforward(cross_attended)

In [None]:
class TransformerDecoder(nn.Module):
    """Embedding with positional encoding, dropout, then ``num_layers`` decoder layers."""

    def __init__(self, num_layers:int, d_model:int, num_heads:int,
                 d_ff:int, vocab_size:int, dropout:float = 0.1) -> None:
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size, d_model)

        layer_stack = []
        for _ in range(num_layers):
            layer_stack.append(
                TransformerDecoderLayer(
                    num_heads=num_heads, d_model=d_model, d_ff=d_ff, dropout=dropout
                )
            )
        self.dec_layers = nn.ModuleList(layer_stack)

        self.dropout_layer = nn.Dropout(dropout)


    def forward(self, x:Tensor, context:Tensor) -> Tensor:
        """Decode target token ids (batch_size, seq_length) against the encoder output ``context``."""
        hidden = self.pos_embedding(x)       # (batch_size, seq_length, d_model)
        hidden = self.dropout_layer(hidden)

        for decoder_layer in self.dec_layers:
            hidden = decoder_layer(hidden, context)

        return hidden                        # (batch_size, seq_length, d_model)

### The Transformer

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model with a final linear projection onto the target vocabulary."""

    def __init__(self, num_layers:int, d_model:int, num_heads:int, d_ff:int,
                 input_vocab_size:int, target_vocab_size:int, dropout:float = 0.1) -> None:
        super().__init__()

        # Settings shared by the encoder and the decoder; only the vocabulary
        # size differs between the two.
        shared = dict(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                      d_ff=d_ff, dropout=dropout)

        self.encoder = TransformerEncoder(vocab_size=input_vocab_size, **shared)
        self.decoder = TransformerDecoder(vocab_size=target_vocab_size, **shared)
        self.final_layer = nn.Linear(d_model, target_vocab_size)


    def forward(self, x:Tensor, context:Tensor) -> Tensor:
        """Return logits of shape (batch_size, target_len, target_vocab_size).

        Args:
            x: target-side token ids fed to the decoder.
            context: source-side token ids fed to the encoder.
        """
        encoded = self.encoder(context)      # (batch_size, context_len, d_model)
        decoded = self.decoder(x, encoded)   # (batch_size, target_len, d_model)
        return self.final_layer(decoded)     # (batch_size, target_len, target_vocab_size)

## Model Hyperparameters

In [None]:
# Original Paper settings
num_layers, num_heads = 6, 8    # layers per stack, attention heads per layer
d_model, d_ff = 512, 2048       # model width, feed-forward hidden width
dropout = 0.1

# # Reduced
# num_layers = 4
# d_model = 128
# d_ff = 512
# num_heads = 8
# dropout = 0.1

## Model Instantiation and Training

In [None]:
# German is the source (encoder) language, English the target (decoder)
# language. The model is moved to the selected device because the training and
# validation loops move every batch there before calling it.
transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    dropout=dropout,
    input_vocab_size=len(de_vocab),
    target_vocab_size=len(en_vocab)
).to(device)

### Learning Rate Scheduler Function

In [None]:
warmup_steps = 4000

def get_learning_rate(current_epoch:int) -> float:
    """Learning-rate factor with linear warm-up followed by inverse-sqrt decay.

    Computes ``d_model**-0.5 * min(n**-0.5, n * warmup_steps**-1.5)``, which
    rises linearly until ``n == warmup_steps`` and decays afterwards.

    Args:
        current_epoch: the counter supplied by the LR scheduler, i.e. the
            number of ``scheduler.step()`` calls made so far. Values below 1
            are treated as 1: the scheduler evaluates this function with 0
            before the first update, and returning 0 there would run the whole
            first scheduling period with a learning rate of zero.

    Returns:
        The multiplicative factor applied to the optimizer's base learning rate.
    """
    step = max(current_epoch, 1)
    return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))

### Training Hyperparameters

In [None]:
epochs = 1
initial_learning_rate = 1   # Gets multiplied by LambdaLR Scheduler

optimizer = torch.optim.Adam(transformer.parameters(), lr=initial_learning_rate, betas=(0.9, 0.98), eps=1e-9)

# The training loop advances the scheduler once per epoch, while the schedule
# in get_learning_rate is expressed in steps (warmup_steps = 4000). Feeding it
# the raw epoch counter yields a factor of 0 for the first epoch and an almost
# zero learning rate for thousands of epochs. Translate the epoch counter into
# the number of optimizer updates completed by the end of that epoch instead.
steps_per_epoch = max(len(training_dataloader), 1)


def _epoch_learning_rate(epoch:int) -> float:
    """Evaluate the step-based schedule at the last optimizer step of ``epoch``."""
    return get_learning_rate((epoch + 1) * steps_per_epoch)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_epoch_learning_rate)

### Loss and Metrics

In [None]:
def masked_loss(pred:Tensor, label:Tensor, pad_idx:int = 1):
    """Mean cross-entropy over the non-padding target positions.

    Args:
        pred: unnormalised logits laid out the way ``nn.CrossEntropyLoss``
            expects them, i.e. with the class axis second:
            (batch, num_classes, seq_len).
        label: integer class ids of shape (batch, seq_len).
        pad_idx: id of the padding token. The default of 1 is the position of
            '<pad>' in vocabularies built with the specials
            ['<unk>', '<pad>', '<bos>', '<eos>'].

    Returns:
        A scalar tensor: summed loss of the non-padding positions divided by
        their count, or 0 when every position is padding.
    """
    loss_object = nn.CrossEntropyLoss(reduction="none")
    # CrossEntropyLoss takes (input, target): logits first, labels second.
    loss = loss_object(pred, label)

    mask = (label != pad_idx).to(dtype=loss.dtype)
    loss = loss * mask

    # clamp avoids 0/0 = nan for a batch that consists only of padding.
    return torch.sum(loss) / torch.sum(mask).clamp(min=1)


def masked_accuracy(pred:Tensor, label:Tensor, pad_idx:int = 1):
    """Fraction of non-padding positions whose arg-max prediction equals the label.

    Args:
        pred: scores of shape (batch, seq_len, num_classes); the arg-max is
            taken over the last axis.
        label: integer class ids of shape (batch, seq_len).
        pad_idx: id of the padding token. The default of 1 is the position of
            '<pad>' in vocabularies built with the specials
            ['<unk>', '<pad>', '<bos>', '<eos>'].

    Returns:
        A scalar float32 tensor in [0, 1]; 0 when every position is padding.
    """
    pred = torch.argmax(pred, dim=2)
    label = label.to(dtype=pred.dtype)

    mask = label != pad_idx
    matching = torch.logical_and(label == pred, mask)

    matching = matching.to(dtype=torch.float32)
    mask = mask.to(dtype=torch.float32)

    # clamp avoids 0/0 = nan for a batch that consists only of padding.
    return torch.sum(matching) / torch.sum(mask).clamp(min=1)

### Training and Validation Functions

In [None]:
def train(dataloader:DataLoader, model:nn.Module, loss_fn, optimizer):
    """Run one training epoch with teacher forcing.

    Args:
        dataloader: yields ``(context, x)`` batches of source and target token
            ids, each shaped (batch, seq_len).
        model: called as ``model(decoder_input, context)``; must return logits
            of shape (batch, target_len, vocab).
        loss_fn: called as ``loss_fn(logits, labels)`` with the logits
            transposed to (batch, vocab, target_len).
        optimizer: optimizer updating ``model``'s parameters.

    Batches are moved to the module-level ``device`` before use. Progress is
    printed every 10 batches and once more after the last batch.
    """
    size = len(dataloader.dataset)
    model.train()

    last_loss, current = None, 0

    for batch, (context, x) in enumerate(dataloader):
        context, x = context.to(device), x.to(device)

        # Teacher forcing: the decoder reads the target without its final
        # token and is trained to predict the target shifted one step left.
        # Using the same tensor as input and as label would only teach the
        # model to copy the token it is currently looking at.
        decoder_input, labels = x[:, :-1], x[:, 1:]

        # Computing prediction error
        pred = model(decoder_input, context)
        loss = loss_fn(pred.transpose(dim0=1, dim1=2), labels)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        last_loss, current = loss, (batch+1) * len(context)

        if batch % 10 == 0:
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]", end="\r")

    # Final line: loss of the last batch together with the matching sample
    # count (nan / 0 when the dataloader produced no batch at all).
    final_loss = float("nan") if last_loss is None else last_loss.item()
    print(f"loss: {final_loss:>7f}  [{current:>5d}/{size:>5d}]")


def validate(dataloader, model, loss_fn, accuracy_metric):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        dataloader: yields ``(context, x)`` batches of source and target token
            ids, each shaped (batch, seq_len).
        model: called as ``model(decoder_input, context)``; must return logits
            of shape (batch, target_len, vocab).
        loss_fn: called as ``loss_fn(logits, labels)`` with the logits
            transposed to (batch, vocab, target_len).
        accuracy_metric: called as ``accuracy_metric(logits, labels)`` and
            expected to return a per-batch accuracy as a scalar tensor.

    Returns:
        ``(average_loss, average_accuracy)``, both averaged over the batches
        (both 0 when the dataloader is empty).
    """
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0

    with torch.no_grad():
        for context, x in dataloader:
            context, x = context.to(device), x.to(device)
            # Same input/label shift as in training: predict token t+1 from
            # the tokens up to and including t.
            decoder_input, labels = x[:, :-1], x[:, 1:]
            pred = model(decoder_input, context)
            test_loss += loss_fn(pred.transpose(dim0=1, dim1=2), labels).item()
            correct += accuracy_metric(pred, labels).item()

    # accuracy_metric already returns a ratio per batch, so the sum has to be
    # divided by the number of batches, not by the number of samples.
    if num_batches:
        test_loss /= num_batches
        correct /= num_batches
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return test_loss, correct

### Model Training

In [None]:
num_parameters = sum(param.numel() for param in transformer.parameters())
print(f"Total number of parameters = {num_parameters}")

# Padding positions of the English targets are excluded from the loss.
PAD_IDX = en_vocab.get_stoi()['<pad>']
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Metrics of the best checkpoint written so far.
curr_loss, curr_acc = 1_000_000_000, 0
save_on_loss = True   # if True, saves model when loss reduces, otherwise saves model when accuracy increases
weights_path = "transformer_weights.pth"

for t in range(epochs):
    print(f"Epoch {t+1} \n ----------------------------")
    train(training_dataloader, transformer, loss_fn, optimizer)
    loss, acc = validate(validation_dataloader, transformer, loss_fn, masked_accuracy)

    # Model Save Checkpoints: only the selected criterion decides whether this
    # epoch is an improvement; both metrics of that epoch are recorded.
    improved = (loss < curr_loss) if save_on_loss else (acc > curr_acc)
    if improved:
        curr_loss, curr_acc = loss, acc
        torch.save(transformer.state_dict(), weights_path)

    # NOTE(review): the scheduler is advanced once per epoch here, so its
    # lr_lambda is evaluated per epoch, not per batch.
    scheduler.step()

print("-------------Training Complete------------------")
print(f"Saved Model's Loss = {curr_loss}")
print(f"Saved Model's Accuracy = {100*curr_acc}%")

## The Translator

In [None]:
class Translator():
    """Greedy sentence translator built on a trained encoder-decoder model.

    Args:
        model: module called as ``model(target_ids, source_ids)`` that returns
            logits of shape (batch, target_len, vocab).
        model_weights: path (or file-like object) of a saved ``state_dict``
            that is loaded into ``model``.
        input_language_tokenizer: callable mapping a sentence to string tokens.
        output_language_tokenizer: stored, but not used by ``translate``.
        input_language_vocab: source vocabulary; must provide ``get_stoi()``
            and contain '<bos>' and '<eos>'.
        output_language_vocab: target vocabulary; must provide ``get_stoi()``
            and ``lookup_token()`` and contain '<bos>' and '<eos>'.
        max_tokens: upper bound on the number of generated tokens.
    """

    def __init__(self, model:nn.Module, model_weights, input_language_tokenizer, output_language_tokenizer,
                 input_language_vocab, output_language_vocab, max_tokens=200) -> None:
        self.model = model
        model.load_state_dict(torch.load(model_weights, map_location=device))
        self.input_language_tokenizer = input_language_tokenizer
        self.output_language_tokenizer = output_language_tokenizer
        self.input_language_vocab = input_language_vocab
        self.output_language_vocab = output_language_vocab
        self.max_tokens = max_tokens


    def translate(self, input_sentence:str) -> str:
        """Translate ``input_sentence`` by greedy decoding.

        The source is encoded as ``<bos> tokens <eos>``, the same framing the
        training batches use. Decoding starts from '<bos>' and stops at '<eos>'
        or after ``max_tokens`` tokens. The model is switched to eval mode
        (and left in it) so that dropout does not perturb the predictions.

        Returns:
            The generated tokens joined by single spaces ('' when '<eos>' is
            predicted immediately).

        Raises:
            KeyError: if a source token is not in the input vocabulary and
                that vocabulary has no '<unk>' entry.
        """
        print(f"Input Sentence: {input_sentence}")
        self.model.eval()

        # Create the inputs on the device that holds the model's weights.
        first_param = next(self.model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")

        in_stoi = self.input_language_vocab.get_stoi()
        out_stoi = self.output_language_vocab.get_stoi()
        bos_val, eos_val = out_stoi['<bos>'], out_stoi['<eos>']
        unk_val = in_stoi.get('<unk>')

        source_ids = [in_stoi['<bos>']]
        for token in self.input_language_tokenizer(input_sentence):
            token_id = in_stoi.get(token, unk_val)
            if token_id is None:
                raise KeyError(f"token {token!r} is not in the input vocabulary")
            source_ids.append(token_id)
        source_ids.append(in_stoi['<eos>'])

        context_tensor = torch.tensor([source_ids], dtype=torch.long, device=model_device)
        print(context_tensor)

        output_token_id_list = [bos_val]
        for _ in range(self.max_tokens):
            x_tensor = torch.tensor([output_token_id_list], dtype=torch.long, device=model_device)
            with torch.no_grad():
                predicted_val = self.model(x_tensor, context_tensor)

            # .item() yields a plain int; a 0-dim tensor must neither be
            # stored in the id list nor be passed to lookup_token.
            last_predicted = int(predicted_val[0, -1, :].argmax().item())
            if last_predicted == eos_val:
                break
            output_token_id_list.append(last_predicted)

        # Skip the leading '<bos>' when converting the ids back to text.
        output_tokens = [self.output_language_vocab.lookup_token(token_id)
                         for token_id in output_token_id_list[1:]]
        return " ".join(output_tokens)

In [None]:
# Wire the trained German->English model into a Translator.
translator = Translator(
    model=transformer,
    model_weights=weights_path,
    input_language_tokenizer=de_tokenizer,
    output_language_tokenizer=en_tokenizer,
    input_language_vocab=de_vocab,
    output_language_vocab=en_vocab,
)

# Translate only the first line of the German test file as a smoke test.
with io.open(test_filepaths[0], encoding="utf8") as f:
    string_ = next(f, None)
    if string_ is not None:
        translated_sentence = translator.translate(string_)
        print(translated_sentence)