In [1]:
import torch
import torch.nn as nn
import numpy as np
import random
import time
import sys
from torch.utils.tensorboard import SummaryWriter
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
from torchtext.data.metrics import bleu_score

In [2]:
def tokenize(text):
    p = np.array([])
    for s in text.split(','):
        s1 = s.split(' ')
        for s2 in s1:
            s2 = s2.split('.')
            if s2!=['']:
                p = np.append(p, s2)
    if p[-1]=='':
        return list(p[:-1])
    else:
        return list(p)

In [3]:
english = Field(
    sequential=True, 
    use_vocab=True, 
    tokenize=tokenize, 
    lower=True,
    init_token='<sos>',
    eos_token='<eos>'
)
german = Field(
    sequential=True, 
    use_vocab=True, 
    tokenize=tokenize, 
    lower=True,
    init_token='<sos>',
    eos_token='<eos>'
)

In [4]:
train_data, val_data, test_data = Multi30k.splits(
    exts = ('.de', '.en'), # (Source language, Target Language)
    fields = (german, english) # And then map the source and target to the respective variables
)

In [5]:
# Build a vocabulary
english.build_vocab(train_data, max_size = 10000, min_freq = 2) # We won't add words used ONLY once, should occur atleast twice
german.build_vocab(train_data, max_size = 10000, min_freq = 2) # We won't add words used ONLY once

In [6]:
batch_size = 512
device = torch.device('cpu')

In [7]:
train_iterator, val_iterator, test_iterator = BucketIterator.splits(
    (train_data, val_data, test_data),
    batch_sizes=(batch_size, batch_size, batch_size),
    sort_within_batch = True, # The batches are formed based on the length of the sentences
    sort_key = lambda x: len(x.src), # This would prioritise the similar length sentences in the batch
    device = device
)

# Transformer

In [8]:
# To understand some functions used in masking
ss1 = (
    torch.arange(0, 5).unsqueeze(1) # (5, 1) --> Added one dimension in index 1
)
ss2 = (
    torch.arange(0, 5).unsqueeze(1).expand(5, 10)
) # For 10 sentences
ss1, ss2

(tensor([[0],
         [1],
         [2],
         [3],
         [4]]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))

In [9]:
class Transformer(nn.Module):
    def __init__(
        self,
        source_vocab_size,
        target_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        embedding_size,
        forward_expansion,
        dropout_prob,
        max_length,
        device
        ) -> None:
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(source_vocab_size, embedding_size, device=device)
        self.src_position_embedding = nn.Embedding(max_length, embedding_size, device=device)
        self.trg_word_embedding = nn.Embedding(target_vocab_size, embedding_size, device=device)
        self.trg_position_embedding = nn.Embedding(max_length, embedding_size, device=device)

        self.device = device

        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            dim_feedforward = forward_expansion*embedding_size,
            dropout = dropout_prob,
            device = self.device
        )

        self.fc_out = nn.Linear(embedding_size, target_vocab_size, device = device)
        self.dropout = nn.Dropout(dropout_prob)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src):
        # src shape: (src_len, batch_size)
        # But Pytorch takes (batch_size, src_len)
        src_mask = (src.transpose(0,1) == self.src_pad_idx).to(self.device)
        # src_mask shape: (batch_size, src_len)
        return src_mask
    
    def forward(self, src, tgt):
        src_seq_len, batch_size = src.shape
        trg_seq_len, batch_size = tgt.shape

        # Create positions for Position Embeddings
        src_positions = (
            torch.arange(0, src_seq_len)
            .unsqueeze(1)
            .expand(src_seq_len, batch_size)
            .to(self.device)
        )
        # All positions of words would be labelled as 0, 1, 2,...src_length
        # When inputted into a trainable position embedding, the value would represent a certain vector for that position, which will be added to the word_emebdding

        trg_positions = (
            torch.arange(0, trg_seq_len)
            .unsqueeze(1)
            .expand(trg_seq_len, batch_size)
            .to(self.device)
        )

        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions)) 
        ) # output: (sequence_length, batch_size, embedding_dimension)

        embed_trg = self.dropout(
            (self.trg_word_embedding(tgt) + self.trg_position_embedding(trg_positions))
        )

        src_padding_mask = self.make_src_mask(src) # (batch_size, sequence_length)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).to(self.device)

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask = src_padding_mask,
            tgt_mask = trg_mask
        )

        out = self.fc_out(out)
        
        return out

## Setting up the training phase
### Hyperparameters

In [10]:
load_model = False
save_model = True

In [11]:
num_epochs = 50
learning_rate = 3e-4

In [1]:
src_vocab_size = len(german.vocab)
trg_vocab_size = len(english.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout_prob = 0.1
max_length = 100 # For max length of positional embedding
forward_expansion = 4

NameError: name 'german' is not defined

In [13]:
src_pad_idx = german.vocab.stoi["<pad>"]
trg_pad_idx = english.vocab.stoi["<pad>"]

In [14]:
writer = SummaryWriter("runs/Transformer_Loss_Plot/")
step = 0

### Model

In [15]:
transformer = Transformer(
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    trg_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    embedding_size,
    forward_expansion,
    dropout_prob,
    max_length,
    device
)

In [16]:
optimizer = torch.optim.Adam(transformer.parameters(), lr = learning_rate)

In [17]:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    factor=0.1, 
    patience=10, 
    verbose=True
)

In [18]:
criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)

### Utility functions

In [19]:
def save_checkpoint(state, filename="models_state_dict/Seq2Seq_Transformer_checkpoint.pth.tar"):
    print("Saving Checkpoint...")
    torch.save(state, filename)
    print("Saved!")

In [20]:
def load_checkpoint(checkpoint, model, optimizer):
    print("Loading checkpoint...")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    print("Successfully loaded!")

In [21]:
def translate_sentence(model, sentence, source, target, device, max_length=50):
    # Load source tokenizer
    if type(sentence) == str:
        tokens = [token.lower() for token in tokenize(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, source.init_token)
    tokens.append(source.eos_token)

    # Go through each source token and convert to an index
    text_to_indices = [source.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [target.vocab.stoi["<sos>"]]
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == target.vocab.stoi["<eos>"]:
            break

    translated_sentence = [target.vocab.itos[idx] for idx in outputs]
    # remove start token
    return translated_sentence[1:]

In [22]:
def bleu(data, model, source, target, device):
    targets = []
    outputs = []

    for example in data:
        src = vars(example)["src"]
        trg = vars(example)["trg"]

        prediction = translate_sentence(model, src, german, english, device)
        prediction = prediction[:-1] # Removing <eos> token

        targets.append([trg])
        outputs.append(prediction)
    
    return bleu_score(outputs, targets)

### Training

In [23]:
with torch.no_grad():
    total_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad) # Number of Parameters
    print(f'Total number of trainable parameters = {total_params}')

Total number of trainable parameters = 54352204


In [24]:
if load_model:
    load_checkpoint(torch.load("models_state_dict/Seq2Seq_Transformer_checkpoint.pth.tar"), transformer, optimizer)

In [25]:
sentence = "ein boot mit mehreren männern darauf wird von einem großen pferdegespann ans ufer gezogen."
max1 = 0

In [26]:
print(src_vocab_size)
print(trg_vocab_size)

7805
5964


In [27]:
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    losses = []

    if save_model:
        checkpoint = {
            "state_dict": transformer.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        save_checkpoint(checkpoint)
    
    transformer.eval() # Will turn off dropout
    translated_sentence = translate_sentence(transformer, sentence, german, english, device, max_length=100)
    translated_sentence_final = ''
    for i, word in enumerate(translated_sentence[:-1]):
        if i != len(translated_sentence)-1:
            translated_sentence_final+=word+' '
        else:
            translated_sentence_final+=word+'.'
    print(f"Translated Example Sentence: \n {translated_sentence_final}")

    transformer.train()

    tic = time.time()
    for batch_idx, batch in enumerate(train_iterator):
        input_data = batch.src.to(device) # (sequence_length, batch_size)
        target = batch.trg.to(device) # (sequence_length, batch_size)

        # Forward Propagation
        output = transformer(input_data, target[:-1, :]) # The output must be shifted by ONE step right - so the <eos> token will be removed
        # output shape: (sequence_length, batch_size, target_vocab_size)
        # Output is of shape (trg_len, batch_size, output_dim) but Cross Entropy Loss
        # doesn't take input in that form. For example if we have MNIST we want to have
        # output to be: (N, 10) and targets just (N). Here we can view it in a similar
        # way that we have output_words * batch_size that we want to send in into
        # our cost function, so we need to do some reshaping.

        output = output.reshape(-1, output.shape[2])
        target = target[1:].reshape(-1) # Now we will remove the first word
        # So the input is the FIRST word to the SECOND LAST word
        # The output we want to compare it with is SECOND word to LAST word
        # We need to combine the batch with the trg_length to put into the loss function
        loss = criterion(output, target)
        losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(transformer.parameters(), max_norm = 1)
        optimizer.step()

        writer.add_scalar("Training Loss", loss, global_step = step)
        step+=1
    
    mean_loss = sum(losses) / len(losses)
    # scheduler.step(mean_loss)
    
    print(f"Time taken: {(time.time() - tic)//60:.0f}m {(time.time() - tic)%60:.0f}s")

Epoch 1/50
Saving Checkpoint...
Saved!
Translated Example Sentence: 
 sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks tuba sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks sparks 
Time taken: 3m 45s
Epoch 2/50
Saving Checkpoint...
Saved!
Translated Example Sentence: 
 a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a

## Testing

In [None]:
if load_model:
    load_checkpoint(torch.load("models_state_dict/Seq2Seq_Attention_checkpoint.pth.tar"), transformer, optimizer)
transformer.eval()

In [None]:
score = bleu(test_data, transformer, german, english, device)
print(f"Bleu Score for test data = {score*100:.2f}")

### Trials

In [None]:
sentence = "Es gibt so viele verschiedene Möglichkeiten für Eiscreme"

**Expected translation**: *There is so much variety in the options for icecream available*

In [None]:
with torch.no_grad():
    translated_sentence = translate_sentence(transformer, sentence, german, english, device, max_length=100)
    translated_sentence_final = ''
    for i, word in enumerate(translated_sentence[:-1]):
        if i != len(translated_sentence)-1:
            translated_sentence_final+=word+' '
        else:
            translated_sentence_final+=word+'.'
    print(f"Translated Example Sentence: \n {translated_sentence_final}")