<a href="https://colab.research.google.com/github/ymoslem/PyTorchNLP/blob/main/Ex5-NMT-Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NMT with PyTorch nn.Transformer

* **Paper:** <a href="https://arxiv.org/abs/1706.03762">Attention is all you need</a>

* **PyTorch Transformer Class:** https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
import spacy

print(torch.__version__)  # 1.11.0+cu113

In [None]:
from torchtext.datasets import Multi30k

# Multi30k German-to-English corpus, returned as (train, valid, test) splits.
# Info: https://pytorch.org/text/stable/datasets.html#multi30k
dataset_splits = Multi30k()
train_iter, valid_iter, test_iter = dataset_splits

In [None]:
# Peek at the first (German, English) pair of the training split
first_pair = next(iter(train_iter))
src_sentence, tgt_sentence = first_pair
print(f"{src_sentence}\n{tgt_sentence}")

In [None]:
# Number of segments (sentence pairs) in the training split
count = sum(1 for _ in train_iter)
print(count)

In [None]:
#!python3 -m spacy download de_core_news_sm
#!python3 -m spacy download en_core_web_sm

# spaCy pipelines; only their tokenizer component is used in this notebook.
spacy_de = spacy.load("de_core_news_sm")
spacy_en = spacy.load("en_core_web_sm")


def _spacy_tokenize(nlp, text):
    """Split *text* with the tokenizer of spaCy pipeline *nlp* and return the token strings."""
    return [token.text for token in nlp.tokenizer(text)]


def tokenizer_de(text):
    """Tokenize German *text* into a list of token strings."""
    return _spacy_tokenize(spacy_de, text)


def tokenizer_en(text):
    """Tokenize English *text* into a list of token strings."""
    return _spacy_tokenize(spacy_en, text)

In [None]:
# Sanity check: show the spaCy English tokenization of a short sentence.
tokenizer_en("here is a test")

In [None]:
# Build Vocabulary
# Info: https://pytorch.org/text/stable/vocab.html?highlight=build%20vocab#torchtext.vocab.build_vocab_from_iterator

SPECIAL_TOKENS = ["<unk>", "<pad>", "<s>", "</s>"]


def yield_tokens(train_iter, direction, lowercase=True):
    """Yield the token list of one side of every (source, target) pair.

    Args:
        train_iter: iterable of (German source, English target) string pairs.
        direction: "source" tokenizes the German side with ``tokenizer_de``;
            "target" tokenizes the English side with ``tokenizer_en``.
        lowercase: lower-case each sentence before tokenizing (default True).
            ``collate_batch`` and ``translate`` lower-case text before the
            vocabulary lookup. A vocabulary built from cased text would
            contain e.g. "Pferd" but not "pferd", so every capitalised German
            noun would be looked up as <unk>.

    Raises:
        ValueError: if *direction* is neither "source" nor "target". Because
            this is a generator, the error surfaces on the first ``next()``.
    """
    if direction not in ("source", "target"):
        raise ValueError("direction should be 'source' or 'target'")
    tokenize = tokenizer_de if direction == "source" else tokenizer_en

    for source, target in train_iter:
        text = source if direction == "source" else target
        if lowercase:
            text = text.lower()
        yield tokenize(text)


def _build_vocab(direction):
    """Build the vocabulary of one side of the training data, falling back to <unk> for unseen tokens."""
    vocab = build_vocab_from_iterator(yield_tokens(train_iter, direction),
                                      specials=SPECIAL_TOKENS,
                                      min_freq=1,
                                      max_tokens=50000)
    vocab.set_default_index(vocab["<unk>"])
    return vocab


source_vocab = _build_vocab("source")
target_vocab = _build_vocab("target")

In [None]:
# Vocabulary sizes (source German, target English); both include the 4 special tokens.
print(len(source_vocab), len(target_vocab))

In [None]:
# Map sample English tokens to target-vocabulary indices (unknown tokens map to the <unk> index).
target_vocab(['<s>', 'here', 'is', 'an', 'example', '</s>'])

In [None]:
# Info: https://colab.research.google.com/github/pytorch/text/blob/master/examples/legacy_tutorial/migration_tutorial.ipynb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 256  # examples
pad_idx = target_vocab["<pad>"]         # padding index of the target (English) vocabulary
source_pad_idx = source_vocab["<pad>"]  # padding index of the source (German) vocabulary


def _to_index_tensor(text, tokenizer, vocab):
    """Lower-case and tokenize *text*, wrap it in <s> ... </s>, and return its vocab indices as an int64 tensor."""
    tokens = ["<s>"] + tokenizer(text.lower()) + ["</s>"]
    return torch.tensor(vocab(tokens), dtype=torch.int64)


def collate_batch(batch):
    """Convert a list of (German, English) string pairs into padded index tensors.

    Returns:
        (sources, targets): int64 tensors of shape (max_src_len, batch) and
        (max_tgt_len, batch). ``pad_sequence`` is sequence-first by default.
        Each side is padded with its own vocabulary's <pad> index, and both
        tensors are moved to ``device``.
    """
    sources, targets = [], []
    for source, target in batch:
        sources.append(_to_index_tensor(source, tokenizer_de, source_vocab))
        targets.append(_to_index_tensor(target, tokenizer_en, target_vocab))

    sources = pad_sequence(sources, padding_value=source_pad_idx).to(device)
    targets = pad_sequence(targets, padding_value=pad_idx).to(device)

    return sources, targets


# [To-Do] Add batch_sampler to act as a bucket iterator

train_dataloader = DataLoader(train_iter, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
# Evaluation results do not depend on example order, so these loaders are not shuffled.
valid_dataloader = DataLoader(valid_iter, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_iter, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

In [None]:
# Check first item in the dataloader
# print(*next(iter(train_dataloader)), sep="\n\t")

In [None]:
# Example usage
# for x_data, y_data in train_dataloader:
#    x_data, y_data = x_data.to(device), y_data.to(device)

# Transformer

In [None]:
class Transformer(nn.Module):
    """Sequence-to-sequence model built around ``nn.Transformer``.

    For both source and target, token embeddings are summed with learned
    positional embeddings. The sums pass through ``nn.Transformer``
    (sequence-first layout) and are projected to target-vocabulary logits.

    Args:
        embedding_size: model dimension (``d_model``) of embeddings and transformer.
        src_vocab_size: size of the source vocabulary.
        tgt_vocab_size: size of the target vocabulary (number of output logits).
        src_pad_idx: source padding index; these positions are masked out in
            encoder self-attention and decoder cross-attention.
        num_heads: number of attention heads.
        num_encoder_layers: depth of the encoder stack.
        num_decoder_layers: depth of the decoder stack.
        forward_expantion: hidden size of the position-wise feed-forward layers
            (passed as ``dim_feedforward``).
        dropout: dropout probability for the embeddings and inside the transformer.
        max_len: number of learned positions; longer sequences cannot be embedded.
        device: device on which position indices and masks are created.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        tgt_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expantion,
        dropout,
        max_len,
        device
    ):
        super(Transformer, self).__init__()

        # Learned (not sinusoidal) positional embeddings, one table per side.
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_positional_embedding = nn.Embedding(max_len, embedding_size)

        self.tgt_word_embedding = nn.Embedding(tgt_vocab_size, embedding_size)
        self.tgt_positional_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.src_pad_idx = src_pad_idx
        self.dropout = nn.Dropout(dropout)

        # Positional arguments: d_model, nhead, num_encoder_layers,
        # num_decoder_layers, dim_feedforward, dropout.
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expantion,
            dropout
        )

        self.fc_out = nn.Linear(embedding_size, tgt_vocab_size)

    def make_src_mask(self, src):
        """Return a (N, src_len) bool mask that is True at source padding positions.

        ``src`` is (src_len, N). The result is transposed to the (N, src_len)
        layout that ``nn.Transformer`` requires for ``src_key_padding_mask``
        and ``memory_key_padding_mask``.
        """
        src_mask = src.transpose(0, 1) == self.src_pad_idx
        return src_mask.to(self.device)

    def forward(self, src, tgt):
        """Compute target-vocabulary logits.

        Args:
            src: (src_len, N) source token indices.
            tgt: (tgt_len, N) target token indices fed to the decoder.

        Returns:
            (tgt_len, N, tgt_vocab_size) logits. A causal mask ensures that
            position t only attends to target positions <= t.
        """
        src_seq_length, N = src.shape
        tgt_seq_length, N = tgt.shape

        # Position index of every token, broadcast across the batch: (seq_len, N).
        src_positions = (
            torch.arange(0, src_seq_length)
            .unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )

        tgt_positions = (
            torch.arange(0, tgt_seq_length)
            .unsqueeze(1)
            .expand(tgt_seq_length, N)
            .to(self.device)
        )

        src_embedding = self.dropout(
            (self.src_word_embedding(src) + self.src_positional_embedding(src_positions))
        )

        tgt_embedding = self.dropout(
            (self.tgt_word_embedding(tgt) + self.tgt_positional_embedding(tgt_positions))
        )

        src_padding_mask = self.make_src_mask(src)
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_seq_length).to(self.device)

        out = self.transformer(
            src_embedding,
            tgt_embedding,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=tgt_mask,
            # Decoder cross-attention must also ignore padded source positions;
            # otherwise predictions depend on how much padding the batch has.
            memory_key_padding_mask=src_padding_mask,
        )

        out = self.fc_out(out)

        return out

# Helper Functions

In [None]:
import torch
import spacy
from torchtext.data.metrics import bleu_score
import sys
from random import random


def translate(text, model, tokenizer, source_vocab, target_vocab, device, max_length=50):
    """Greedily translate *text* with *model*, one target token at a time.

    The text is lower-cased *before* tokenization, matching the training
    preprocessing in ``collate_batch`` (``tokenizer_de(source.lower())``).
    It is then wrapped in <s> ... </s>. Decoding starts from <s> and stops
    once </s> is produced or after *max_length* generated tokens. The model
    is left in eval mode.

    Returns:
        The generated target tokens joined by spaces, without the leading <s>.
        A trailing "</s>" is included when the model produced it.
    """
    tokens = ["<s>"] + tokenizer(text.lower()) + ["</s>"]
    # Shape (src_len, 1): a batch containing a single sentence.
    sentence_tensor = torch.tensor(source_vocab(tokens), dtype=torch.long).unsqueeze(1).to(device)

    model.eval()

    eos_idx = target_vocab["</s>"]
    outputs = target_vocab(["<s>"])

    for _ in range(max_length):
        trg_tensor = torch.tensor(outputs, dtype=torch.long).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        # The logits at the last decoder position choose the next token.
        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == eos_idx:
            break

    itos = target_vocab.get_itos()
    # outputs[0] is the <s> start token; leave it out of the returned text.
    return " ".join(itos[idx] for idx in outputs[1:])


def bleu(data_iter, model, tokenizer, source_vocab, target_vocab, device, target_tokenizer=None):
    """Corpus BLEU of greedy ``translate`` outputs for *data_iter* against its references.

    ``bleu_score`` expects each candidate to be a list of tokens and each
    reference entry to be a list of token lists. Passing raw strings instead
    silently yields a character-level score.

    Args:
        data_iter: iterable of (source sentence, reference translation) pairs.
        tokenizer: source-side tokenizer forwarded to ``translate``.
        target_tokenizer: tokenizer for the references. For a faithful score,
            pass the tokenizer used to build ``target_vocab`` (e.g.
            ``tokenizer_en``). When None, a regex that splits words and
            punctuation is used as an approximation.

    Returns:
        The corpus BLEU value returned by ``bleu_score``.
    """
    import re

    if target_tokenizer is None:
        target_tokenizer = re.compile(r"\w+|[^\w\s]").findall

    candidates = []
    references = []

    for source, target in data_iter:
        prediction = translate(source, model, tokenizer, source_vocab, target_vocab, device).split()
        # translate() already drops <s>; also drop the end-of-sentence token if present.
        if prediction and prediction[-1] == "</s>":
            prediction = prediction[:-1]

        candidates.append(prediction)
        # The model emits lower-cased tokens, so compare against a lower-cased reference.
        references.append([target_tokenizer(target.lower())])

    return bleu_score(candidates, references)


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize *state* (e.g. a dict of model/optimizer state dicts) to *filename* with ``torch.save``."""
    message = "=> Saving checkpoint"
    print(message)
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore *model* and *optimizer* from a checkpoint dict holding "state_dict" and "opt" entries."""
    print("=> Loading checkpoint")
    # Model first, then optimizer.
    for component, key in ((model, "state_dict"), (optimizer, "opt")):
        component.load_state_dict(checkpoint[key])


def load_checkpoint_for_inference(model, checkpoint_path, map_location=None):
    """Load the "state_dict" entry of the checkpoint file at *checkpoint_path* into *model*.

    Args:
        model: module to receive the weights.
        checkpoint_path: path to a file written with ``torch.save``.
        map_location: forwarded to ``torch.load``. When None, tensors are
            mapped onto the device of *model*'s first parameter (CPU if it
            has none). This lets a checkpoint saved on a GPU load on a
            CPU-only machine.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    print("Model checkpoint loaded")

# Training Setup

In [None]:
load_model = False
save_model = True

# Training Hyperparameters
num_epochs = 100
learning_rate = 3e-4
# Informational only: the DataLoaders were created earlier with their own
# batch_size, so keep the two values in sync when changing either one.
batch_size = 256  # examples

# Model Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
src_vocab_size = len(source_vocab)
tgt_vocab_size = len(target_vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3  # the paper's base model uses 6
num_decoder_layers = 3  # the paper's base model uses 6
dropout = 0.1
max_len = 100  # number of learned positions, i.e. the longest sequence the model can embed
forward_expansion = 2048  # feed-forward hidden size (dim_feedforward)
src_pad_idx = source_vocab["<pad>"]


# Tensorboard
writer = SummaryWriter("runs/loss_plot")
step = 0  # global step counter for TensorBoard logging

model = Transformer(
    embedding_size,
    src_vocab_size,
    tgt_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device
).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Multiply the learning rate by 0.1 once the monitored loss stops improving
# for more than `patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10, verbose=True
)

# Padded target positions do not contribute to the loss.
tgt_pad_idx = target_vocab["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=tgt_pad_idx)

if load_model:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=device), model, optimizer)

# Training Loop

In [None]:
# "Butter and cheese are made from milk."
# sentence = "Butter und Käse werden aus Milch gemacht."
# "A horse walks under a bridge next to a boat."
sentence = "ein pferd geht unter einer brücke neben einem boot."
filename = "my_checkpoint.pth.tar"

for epoch in range(num_epochs):
    print(f"Epoch [{epoch} / {num_epochs}]")

    # translate() switches the model to eval mode, so re-enable training mode every epoch.
    model.train()

    losses = []

    for source_batch, target_batch in train_dataloader:
        source = source_batch.to(device)
        target = target_batch.to(device)

        # Teacher forcing: the decoder input is the target without its last token.
        output = model(source, target[:-1, :])
        # output shape: (target_len - 1, batch_size, tgt_vocab_size)

        # Labels are the target shifted by one (the start token is excluded).
        # CrossEntropyLoss expects (num_tokens, num_classes) logits and (num_tokens,)
        # labels, so flatten the time and batch dimensions and keep the vocab dimension.
        output = output.reshape(-1, output.shape[2])
        target = target[1:].reshape(-1)

        optimizer.zero_grad()
        loss = criterion(output, target)
        loss_value = loss.item()
        losses.append(loss_value)

        # Back propagation
        loss.backward()

        # Clip to avoid exploding gradients, makes sure grads are within a healthy range
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        # Gradient descent step
        optimizer.step()

        writer.add_scalar("Training Loss", loss_value, global_step=step)
        step += 1

    # Guard against an empty loader, which would otherwise divide by zero.
    if losses:
        mean_loss = sum(losses) / len(losses)
        scheduler.step(mean_loss)

    # Save AFTER training so the checkpoint contains this epoch's updates.
    # Saving at the start of each epoch never stored the final epoch.
    if save_model:
        checkpoint = {"state_dict": model.state_dict(),
                      "opt": optimizer.state_dict(),
                      "encoder_type": "transformer"
                      }
        torch.save(checkpoint, filename)

    # Translate the example sentence with the weights just trained
    translated_sentence = translate(sentence, model, tokenizer_de, source_vocab, target_vocab, device, max_length=50)
    print(f"Translated example:\n {translated_sentence}")

writer.flush()

# Evaluate

In [None]:
# Peek at the first (German, English) pair of the test split
first_test_pair = next(iter(test_iter))
src_test_sentence, tgt_test_sentence = first_test_pair
print(f"{src_test_sentence}\n{tgt_test_sentence}")

In [None]:
# Restore the trained weights from the checkpoint file written during training.
checkpoint_path = "my_checkpoint.pth.tar"
load_checkpoint_for_inference(model, checkpoint_path)

In [None]:
# Translate a German sentence ("A man with an orange hat staring at something.").
# translate() lower-cases its input, so the original casing can be kept here.
sentence = "Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt."
# sentence = "Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche."
# ("Two young white men are outside near many bushes.")
translate(sentence, model, tokenizer_de, source_vocab, target_vocab, device, max_length=50)

In [None]:
# Calculate BLEU on the test split; bleu_score returns a value in [0, 1], shown here as a percentage
score = bleu(test_iter, model, tokenizer_de, source_vocab, target_vocab, device)
print(f"BLEU score: {score*100:.2f}")