<a href="https://colab.research.google.com/github/ymoslem/PyTorchNLP/blob/main/Ex3-NMT-Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NMT with Attention

* **Paper:** <a href="https://arxiv.org/pdf/1409.0473.pdf">Neural Machine Translation by Jointly Learning to Align and Translate</a>

* **Method:** Extending the encoder–decoder architecture by allowing a model to automatically (soft-)search for parts of a source sentence that are relevant to predicting a target word, without having to form these parts as a hard segment explicitly.


In [None]:
import torch
import torch.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils import clip_grad_norm_

from torchtext.vocab import build_vocab_from_iterator

import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
import unicodedata
import re
import time
import random
import spacy

from utils import translate, load_checkpoint, save_checkpoint, load_checkpoint_for_inference

print(torch.__version__)  # 1.11.0+cu113


In [None]:
# Load the Multi30k German-to-English dataset
# Info: https://pytorch.org/text/stable/datasets.html#multi30k

from torchtext.datasets import Multi30k
# Multi30k() is called with its default arguments; the result is unpacked
# into the three iterators used below for training, validation and testing.
train_iter, valid_iter, test_iter = Multi30k()

In [None]:
# Peek at the first (source, target) pair of the training split
src_sentence, tgt_sentence = next(iter(train_iter))
print(f"{src_sentence}\n{tgt_sentence}")

In [None]:
# Number of segments: walk the training iterator once and tally its items
count = sum(1 for _ in train_iter)
print(count)

In [None]:
# One-time setup: uncomment the two lines below to download the spaCy models
# that are loaded right after.
#!python3 -m spacy download de_core_news_sm
#!python3 -m spacy download en_core_web_sm

# spacy_de backs tokenizer_de and spacy_en backs tokenizer_en.
spacy_de = spacy.load("de_core_news_sm")
spacy_en = spacy.load("en_core_web_sm")

def tokenizer_de(text):
    """Return the token strings produced by the `spacy_de` tokenizer for *text*."""
    doc = spacy_de.tokenizer(text)
    return [token.text for token in doc]

def tokenizer_en(text):
    """Return the token strings produced by the `spacy_en` tokenizer for *text*."""
    doc = spacy_en.tokenizer(text)
    return [token.text for token in doc]

In [None]:
# Sanity check: tokenize a short English string (the notebook displays the
# returned list of token strings).
tokenizer_en("here is a test")

In [None]:
# Build Vocabulary
# Info: https://pytorch.org/text/stable/vocab.html?highlight=build%20vocab#torchtext.vocab.build_vocab_from_iterator

def yield_tokens(train_iter, direction):
    """Yield one list of tokens per sentence pair, for vocabulary building.

    The text is lower-cased before tokenization, exactly as `collate_batch`
    does when it numericalizes a batch. Building the vocabulary from cased
    text while looking up lower-cased tokens would send every capitalised
    entry (e.g. all German nouns) to "<unk>".

    Args:
        train_iter: iterable of (source, target) string pairs.
        direction: "source" to tokenize the source side with `tokenizer_de`,
            "target" to tokenize the target side with `tokenizer_en`.

    Yields:
        The list of token strings for the selected side of each pair.

    Raises:
        ValueError: if `direction` is neither "source" nor "target". As this
            is a generator, the error surfaces when iteration starts.
    """
    # Validate once up front rather than once per sentence pair.
    if direction not in ("source", "target"):
        raise ValueError("direction should 'source' or 'target'")

    use_source = direction == "source"
    tokenize = tokenizer_de if use_source else tokenizer_en

    for source, target in train_iter:
        text = source if use_source else target
        yield tokenize(text.lower())


# Source-side vocabulary, built from the tokens yielded for direction "source".
# Tokens seen fewer than twice are dropped and the size is capped at 50,000.
source_vocab = build_vocab_from_iterator(
    yield_tokens(train_iter, "source"),
    min_freq=2,
    specials=["<unk>", "<pad>", "<s>", "</s>"],
    max_tokens=50000,
)

# Target-side vocabulary, built with the same specials and limits.
target_vocab = build_vocab_from_iterator(
    yield_tokens(train_iter, "target"),
    min_freq=2,
    specials=["<unk>", "<pad>", "<s>", "</s>"],
    max_tokens=50000,
)

# Any token missing from a vocabulary is looked up as "<unk>".
source_vocab.set_default_index(source_vocab["<unk>"])
target_vocab.set_default_index(target_vocab["<unk>"])

In [None]:
# Vocabulary sizes: source first, then target
print(f"{len(source_vocab)} {len(target_vocab)}")

In [None]:
# Sanity check: convert a sample token sequence to indices. Tokens that are not
# in the vocabulary come back as the default index configured for target_vocab.
target_vocab(['<s>', 'here', 'is', 'an', 'example', '</s>'])

In [None]:
# Info: https://colab.research.google.com/github/pytorch/text/blob/master/examples/legacy_tutorial/migration_tutorial.ipynb

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of sentence pairs per batch handed to the DataLoaders.
batch_size = 32
# Index of the "<pad>" token in the target vocabulary.
pad_idx = target_vocab["<pad>"]


def collate_batch(batch):
    """Turn a list of raw (source, target) string pairs into padded index tensors.

    Each sentence is lower-cased, tokenized, wrapped in "<s>" / "</s>" and
    mapped to vocabulary indices. Sequences are then padded to the longest
    one in the batch and moved to `device`.

    Args:
        batch: list of (source, target) string pairs.

    Returns:
        A (sources, targets) tuple of int64 tensors, each shaped
        (max_seq_length, len(batch)) -- `pad_sequence` is used with its
        default time-major layout.
    """
    # Pad each side with its own vocabulary's "<pad>" index rather than relying
    # on both vocabularies happening to assign it the same number.
    source_pad_idx = source_vocab["<pad>"]

    sources, targets = [], []
    for source, target in batch:
        source_tokens = ["<s>"] + tokenizer_de(source.lower()) + ["</s>"]
        target_tokens = ["<s>"] + tokenizer_en(target.lower()) + ["</s>"]

        sources.append(torch.tensor(source_vocab(source_tokens), dtype=torch.int64))
        targets.append(torch.tensor(target_vocab(target_tokens), dtype=torch.int64))

    sources = pad_sequence(sources, padding_value=source_pad_idx).to(device)
    # The target side keeps using the module-level pad_idx.
    targets = pad_sequence(targets, padding_value=pad_idx).to(device)

    return sources, targets

# One DataLoader per split. Each batch of raw sentence pairs is converted to
# tensors by collate_batch; shuffling is requested for all three splits.
train_dataloader = DataLoader(train_iter, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(valid_iter, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_iter, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)


In [None]:
# Check first item in the dataloader
# print(*next(iter(train_dataloader)), sep="\n\t")

In [None]:
# Example usage
# for x_data, y_data in train_dataloader:
#    x_data, y_data = x_data.to(device), y_data.to(device)

# Encoder

In [None]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder.

    Embeds the source token indices, runs them through a bidirectional LSTM
    and projects the concatenated forward/backward final states down to
    `hidden_size`, so they can initialise a unidirectional decoder.

    Args:
        vocab_size: number of entries in the source vocabulary.
        embedding_size: dimensionality of the token embeddings.
        hidden_size: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout_p: dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, num_layers, dropout_p):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.dropout = nn.Dropout(dropout_p)
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # Inter-layer LSTM dropout (dropout=dropout_p) is intentionally not used.
        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, bidirectional=True)

        # Project the concatenated forward/backward states (2 * hidden_size) to hidden_size.
        self.fc_hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.fc_cell = nn.Linear(hidden_size * 2, hidden_size)

    def _merge_directions(self, state, projection):
        """Concatenate forward/backward states of every layer and project them.

        `state` is shaped (num_layers * 2, N, hidden_size), ordered layer by
        layer with the forward direction before the backward one. The result
        is shaped (num_layers, N, hidden_size).
        """
        batch = state.shape[1]
        per_layer = state.reshape(self.num_layers, 2, batch, self.hidden_size)
        both_directions = torch.cat((per_layer[:, 0], per_layer[:, 1]), dim=2)
        return projection(both_directions)

    def forward(self, x):
        """Encode a batch of source sequences.

        Args:
            x: token indices shaped (seq_length, N), where N is the batch size.

        Returns:
            encoder_states: per-position LSTM outputs,
                shaped (seq_length, N, hidden_size * 2).
            hidden: projected final hidden state, (num_layers, N, hidden_size).
            cell: projected final cell state, (num_layers, N, hidden_size).
        """
        embedding = self.dropout(self.embedding(x))
        # embedding shape: (seq_length, N, embedding_size)

        encoder_states, (hidden, cell) = self.rnn(embedding)

        # Works for any num_layers; for one layer this is exactly
        # cat(hidden[0:1], hidden[1:2]) followed by the projection.
        hidden = self._merge_directions(hidden, self.fc_hidden)
        cell = self._merge_directions(cell, self.fc_cell)

        # Return all encoder states (for attention) plus the initial decoder state
        return encoder_states, hidden, cell


# Decoder

In [None]:
class Decoder(nn.Module):
    """Single-step LSTM decoder with additive-style attention over encoder states.

    At every step the current decoder hidden state is scored against each
    encoder state. The softmax-normalised scores weight the encoder states
    into a context vector, which is concatenated with the embedded input
    token and fed to the LSTM.

    Args:
        vocab_size: number of entries in the target vocabulary.
        embedding_size: dimensionality of the token embeddings.
        hidden_size: LSTM hidden size (encoder states are hidden_size * 2 wide).
        num_layers: number of stacked LSTM layers.
        dropout_p: dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, num_layers, dropout_p):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.dropout = nn.Dropout(dropout_p)
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # Input = context vector (hidden_size * 2) + embedded token.
        # Inter-layer LSTM dropout (dropout=dropout_p) is intentionally not used.
        self.rnn = nn.LSTM(hidden_size * 2 + embedding_size, hidden_size, num_layers)

        # Scores one (decoder state, encoder state) pair: hidden_size + hidden_size * 2 inputs.
        self.energy = nn.Linear(hidden_size * 3, 1)
        # Normalise over the source-sequence dimension.
        self.softmax = nn.Softmax(dim=0)
        self.relu = nn.ReLU()

        # One logit per token of the target vocabulary.
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, encoder_states, hidden, cell):
        """Run one decoding step.

        Args:
            x: previous target token indices, shaped (N).
            encoder_states: encoder outputs, (seq_length, N, hidden_size * 2).
            hidden: decoder hidden state, (num_layers, N, hidden_size).
            cell: decoder cell state, (num_layers, N, hidden_size).

        Returns:
            predictions: unnormalised scores over the vocabulary, (N, vocab_size).
            hidden: updated hidden state, same shape as the input.
            cell: updated cell state, same shape as the input.
        """
        # x shape: (N), but the LSTM needs (1, N)
        x = x.unsqueeze(0)

        embedding = self.dropout(self.embedding(x))
        # embedding shape: (1, N, embedding_size)

        # --- Attention ---
        sequence_len = encoder_states.shape[0]
        # Use the top layer's state as the attention query so the concat below
        # is valid for any num_layers (for one layer this is `hidden` itself).
        # expand() broadcasts without copying.
        query = hidden[-1:].expand(sequence_len, -1, -1)
        # query shape: (seq_length, N, hidden_size)

        energy = self.relu(self.energy(torch.cat((query, encoder_states), dim=2)))
        attention = self.softmax(energy)
        # attention shape: (seq_length, N, 1)
        attention = attention.permute(1, 2, 0)
        # attention shape (reordered): (N, 1, seq_length)
        encoder_states = encoder_states.permute(1, 0, 2)
        # encoder_states shape (reordered): (N, seq_length, hidden_size*2)

        # (N, 1, hidden_size*2) --> (1, N, hidden_size*2)
        context_vector = torch.bmm(attention, encoder_states).permute(1, 0, 2)

        rnn_input = torch.cat((context_vector, embedding), dim=2)

        outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        # outputs shape: (1, N, hidden_size)

        predictions = self.fc(outputs).squeeze(0)
        # predictions shape: (N, vocab_size)

        return predictions, hidden, cell
        

# Seq2Seq  
Combining the Encoder and Decoder in the Seq2Seq model

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes a target sequence step by step.

    Args:
        encoder: module returning (encoder_states, hidden, cell) for a source batch.
        decoder: module mapping (token, encoder_states, hidden, cell) to
            (scores, hidden, cell) for one step.
        target_vocab_size: width of the decoder's output scores. When omitted
            it is read from the decoder's final projection layer
            (`decoder.fc.out_features`), so the model does not depend on a
            module-level vocabulary object.
    """

    def __init__(self, encoder, decoder, target_vocab_size=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        if target_vocab_size is None:
            target_vocab_size = decoder.fc.out_features
        self.target_vocab_size = target_vocab_size

    def forward(self, source, target, teacher_force_ratio=0.5):
        """Compute decoder scores for every target position.

        Args:
            source: source token indices, shaped (source_len, N).
            target: target token indices, shaped (target_len, N); row 0 is
                the start token.
            teacher_force_ratio: probability, per step, of feeding the real
                target token as the next input instead of the model's own
                best guess.

        Returns:
            Tensor shaped (target_len, N, target_vocab_size). Row 0 is never
            written and stays all zeros.
        """
        # Imported locally so the lookup cannot be broken by a module-level
        # `from random import random`, which rebinds the global name `random`
        # to a function and makes `random.random()` raise AttributeError.
        import random

        batch_size = source.shape[1]
        target_len = target.shape[0]

        # Scores live in the 3rd dimension; allocate on the inputs' device.
        outputs = torch.zeros(
            target_len, batch_size, self.target_vocab_size, device=source.device
        )

        # Encoder states feed the attention; hidden/cell initialise the decoder.
        encoder_states, hidden, cell = self.encoder(source)

        # Grab the start token
        x = target[0]

        for t in range(1, target_len):
            output, hidden, cell = self.decoder(x, encoder_states, hidden, cell)
            # output shape: (N, target_vocab_size)
            outputs[t] = output

            # Highest-scoring token for each sentence in the batch
            best_guess = output.argmax(1)

            # Teacher forcing: feed the real target token with probability
            # teacher_force_ratio, otherwise feed the model's own prediction.
            x = target[t] if random.random() < teacher_force_ratio else best_guess

        return outputs

# Helper Functions

In [None]:
import torch
import spacy
from torchtext.data.metrics import bleu_score
import sys
from random import random


def translate(text, model, tokenizer, source_vocab, target_vocab, device, max_length=50):
    """Greedily translate one sentence with an encoder-decoder model.

    The source sentence is encoded once and the decoder is then stepped one
    token at a time, always feeding back its own best guess. Calling
    `model(source, target)` here would be wrong for the Seq2Seq model: with a
    length-1 target its decoding loop never runs (so the first "prediction"
    is always index 0), and its teacher forcing stays random at inference.

    Args:
        text: raw source sentence.
        model: model exposing `encoder` and `decoder` sub-modules with the
            call signatures used by Seq2Seq.
        tokenizer: callable splitting `text` into token strings.
        source_vocab: callable mapping a list of tokens to a list of indices.
        target_vocab: vocabulary supporting call, `[...]` lookup and `get_itos()`.
        device: device the input tensors are created on.
        max_length: maximum number of tokens to generate.

    Returns:
        The generated tokens joined by single spaces. The start token is
        removed; the end token "</s>" is kept when it was generated.
    """
    # Tokenize the text, lower-case it and add the sentence boundary tokens
    tokenized_text = ["<s>"] + [token.lower() for token in tokenizer(text)] + ["</s>"]

    # Convert to indices, then to a (seq_length, 1) tensor (batch of one)
    sentence_tensor = torch.LongTensor(source_vocab(tokenized_text)).unsqueeze(1).to(device)

    end_idx = target_vocab["</s>"]
    outputs = target_vocab(["<s>"])

    model.eval()
    with torch.no_grad():
        # Encode once; only the decoder runs inside the generation loop.
        encoder_states, hidden, cell = model.encoder(sentence_tensor)

        for _ in range(max_length):
            previous_token = torch.LongTensor([outputs[-1]]).to(device)
            output, hidden, cell = model.decoder(previous_token, encoder_states, hidden, cell)

            best_guess = output.argmax(1).item()
            outputs.append(best_guess)

            if best_guess == end_idx:
                break

    target_vocab_itos = target_vocab.get_itos()
    # outputs[0] is the start token; skip it
    return " ".join(target_vocab_itos[idx] for idx in outputs[1:])


def bleu(data_iter, model, tokenizer, source_vocab, target_vocab, device, target_tokenizer=None):
    """Compute the corpus BLEU score of the model's translations.

    `bleu_score` is given token lists: one list per candidate and a list of
    reference token lists per candidate. Passing raw strings would score
    individual characters instead of words.

    Args:
        data_iter: iterable of (source, target) string pairs.
        model, tokenizer, source_vocab, target_vocab, device: forwarded
            unchanged to `translate`.
        target_tokenizer: optional callable splitting a reference sentence
            into tokens. When omitted, references are split on whitespace.

    Returns:
        The value returned by `bleu_score` for the collected translations.
    """
    targets = []
    outputs = []

    for source, target in data_iter:
        prediction = translate(source, model, tokenizer, source_vocab, target_vocab, device)
        prediction = prediction.split()
        # translate() keeps the end token when it was generated; drop it
        if prediction and prediction[-1] == "</s>":
            prediction = prediction[:-1]

        # Lower-case the reference, as translate() does for its input text
        reference = target.lower()
        if target_tokenizer is not None:
            reference = target_tokenizer(reference)
        else:
            reference = reference.split()

        targets.append([reference])
        outputs.append(prediction)

    return bleu_score(outputs, targets)


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize `state` to `filename` with `torch.save`, announcing it on stdout.

    Args:
        state: object to save; the training loop passes a dict holding the
            model and optimizer state dicts.
        filename: destination path; an existing file at that path is overwritten.
    """
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from an already-loaded checkpoint.

    Args:
        checkpoint: mapping with a "state_dict" entry (model weights) and an
            "opt" entry (optimizer state).
        model: module whose `load_state_dict` receives checkpoint["state_dict"].
        optimizer: optimizer whose `load_state_dict` receives checkpoint["opt"].
    """
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["opt"])


def load_checkpoint_for_inference(model, checkpoint_path):
    """Load the model weights stored in a checkpoint file into `model`.

    The checkpoint tensors are mapped onto the device the model currently
    lives on, so a checkpoint saved on a GPU can be restored on a CPU-only
    machine (and vice versa).

    Args:
        model: module to receive the weights.
        checkpoint_path: path of a file written by `torch.save` that holds a
            mapping with a "state_dict" entry.
    """
    # A parameter-less model has no device of its own; fall back to the CPU.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else "cpu"

    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    print("Model checkpoint loaded")

# Training Setup

In [None]:
# Training Hyperparameters
num_epochs = 100
learning_rate = 3e-4
batch_size = 32  # make sure it is the same as in data preparation

# Model Hyperparameters
load_model = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
source_vocab_size = len(source_vocab)  # input size of the encoder
target_vocab_size = len(target_vocab)  # input size and output size of the decoder
embedding_size = 256
hidden_size = 1024
num_layers = 1
dropout = 0.0

# Tensorboard: training loss is logged under runs/loss_plot, one point per step
writer = SummaryWriter("runs/loss_plot")
step = 0

encoder_network = Encoder(source_vocab_size, embedding_size, hidden_size, num_layers, dropout).to(device)
decoder_network = Decoder(target_vocab_size, embedding_size, hidden_size, num_layers, dropout).to(device)

model = Seq2Seq(encoder_network, decoder_network).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Padding positions in the target do not contribute to the loss
pad_idx = target_vocab["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

if load_model:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
    load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=device), model, optimizer)

# Training Loop

In [None]:
for epoch in range(num_epochs):
    print(f"Epoch [{epoch} / {num_epochs}]")

    # important if model.eval() was called earlier
    model.train()

    for source_batch, target_batch in train_dataloader:
        source = source_batch.to(device)
        target = target_batch.to(device)

        output = model(source, target)
        # output shape: (target_len, batch_size, output_dim)

        # Exclude the start token: output[1:]
        # Reshape to match the accepted input form of CrossEntropyLoss:
        # keep the output dimension (whose size is vocab_size) and flatten the first two
        output = output[1:].reshape(-1, output.shape[2])
        target = target[1:].reshape(-1)

        optimizer.zero_grad()
        loss = criterion(output, target)

        loss.backward()

        # Clip to avoid exploding gradients
        clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        # Log a plain float so the logger does not hold on to the loss tensor
        writer.add_scalar("Training Loss", loss.item(), global_step=step)
        step += 1

    # Save after the epoch's updates. Saving before them would leave the last
    # epoch out of the checkpoint, and reloading that file afterwards would
    # overwrite the trained model with older weights.
    checkpoint = {"state_dict": model.state_dict(),
                  "opt": optimizer.state_dict(),
                  "encoder_type": "lstm"}
    save_checkpoint(checkpoint)


In [None]:
# Peek at the first (source, target) pair of the test split
src_test_sentence, tgt_test_sentence = next(iter(test_iter))
print(f"{src_test_sentence}\n{tgt_test_sentence}")

In [None]:
# Restore the model weights from the checkpoint file; the path is the default
# filename used by save_checkpoint.
checkpoint_path = "my_checkpoint.pth.tar"
load_checkpoint_for_inference(model, checkpoint_path)

In [None]:
# Translate a sample German sentence; swap in the commented-out line to try
# another one. The notebook displays the string returned by translate().
sentence = "Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt."
# sentence = "Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche."
translate(sentence, model, tokenizer_de, source_vocab, target_vocab, device, max_length=50)