Reference: https://www.youtube.com/watch?v=M6adRGJe5cQ

In [1]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import spacy
from utils import translate_sentence, bleu, save_checkpoint, load_checkpoint
from torch.utils.tensorboard import SummaryWriter
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

In [2]:
# Load the German and English spaCy pipelines; their tokenizers are what
# tokenize_ger / tokenize_eng use to split raw text into tokens.
spacy_ger = spacy.load("de_core_news_sm")
spacy_eng = spacy.load("en_core_web_sm")

In [3]:
def tokenize_ger(text):
    """Return the token strings produced by the German spaCy tokenizer for ``text``."""
    token_texts = []
    for token in spacy_ger.tokenizer(text):
        token_texts.append(token.text)
    return token_texts

def tokenize_eng(text):
    """Return the token strings produced by the English spaCy tokenizer for ``text``."""
    token_texts = []
    for token in spacy_eng.tokenizer(text):
        token_texts.append(token.text)
    return token_texts

In [4]:
# torchtext fields describing how each language is preprocessed: tokenized with
# the matching spaCy tokenizer, lower-cased, and configured with "<sos>"/"<eos>"
# as the start/end-of-sentence tokens. German is the source language and
# English the target.
german = Field(tokenize=tokenize_ger, lower=True, init_token="<sos>", eos_token="<eos>")
english = Field(tokenize=tokenize_eng, lower=True, init_token="<sos>", eos_token="<eos>")

In [5]:
# Load the Multi30k train/validation/test splits. The ".de" side is processed
# with the `german` field and the ".en" side with the `english` field.
train_data, valid_data, test_data = Multi30k.splits(
    exts=(".de", ".en"), fields=(german, english)
)

In [6]:
# Build both vocabularies from the training split only, using the max_size and
# min_freq limits below (presumably: at most 10,000 entries, dropping tokens
# seen fewer than 2 times — confirm against the torchtext docs).
german.build_vocab(train_data, max_size=10000, min_freq=2)
english.build_vocab(train_data, max_size=10000, min_freq=2)

In [7]:
class Transformer(nn.Module):
    """Sequence-to-sequence translation model wrapping ``nn.Transformer``.

    Source and target tokens are embedded, summed with learned positional
    embeddings, passed through ``nn.Transformer`` and projected onto the
    target vocabulary. Sequence tensors use the ``(seq_len, batch)`` layout.

    Args:
        embedding_size: Width of the token/position embeddings (``d_model``).
        src_vocab_size: Number of entries in the source vocabulary.
        trg_vocab_size: Number of entries in the target vocabulary.
        src_pad_idx: Index of the padding token in the source vocabulary.
        num_heads: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        forward_expansion: Passed to ``nn.Transformer`` as ``dim_feedforward``.
        dropout: Dropout probability used for the embeddings and inside
            ``nn.Transformer``.
        max_len: Size of the positional embedding tables, i.e. the longest
            sequence the model can process.
        device: Stored on ``self.device`` for callers; the forward pass builds
            its helper tensors on the device of its inputs.
    """

    def __init__(self,
                 embedding_size,
                 src_vocab_size,
                 trg_vocab_size,
                 src_pad_idx,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 forward_expansion,
                 dropout,
                 max_len,
                 device,
                ):
        super().__init__()

        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)
        self.device = device

        self.transformer = nn.Transformer(
            d_model=embedding_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=forward_expansion,
            dropout=dropout,
        )

        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """Return a ``(batch, src_len)`` bool mask, True where ``src`` is padding.

        The mask lives on the same device as ``src``.
        """
        return src.transpose(0, 1) == self.src_pad_idx

    def forward(self, src, trg):
        """Compute target-vocabulary logits for every target position.

        Args:
            src: Source token indices of shape ``(src_len, batch)``.
            trg: Target token indices of shape ``(trg_len, batch)``.

        Returns:
            Tensor of shape ``(trg_len, batch, trg_vocab_size)``.

        Raises:
            ValueError: If either sequence is longer than the positional
                embedding table (``max_len``).
        """
        src_seq_len, src_batch = src.shape
        trg_seq_len, trg_batch = trg.shape

        # Fail early with a readable message instead of an embedding index error.
        max_len = self.src_position_embedding.num_embeddings
        if src_seq_len > max_len or trg_seq_len > max_len:
            raise ValueError(
                f"Sequence length exceeds max_len={max_len}: "
                f"src_len={src_seq_len}, trg_len={trg_seq_len}"
            )

        # Build position indices directly on the input device (no CPU round trip).
        src_positions = (
            torch.arange(src_seq_len, device=src.device)
            .unsqueeze(1)
            .expand(src_seq_len, src_batch)
        )
        trg_positions = (
            torch.arange(trg_seq_len, device=trg.device)
            .unsqueeze(1)
            .expand(trg_seq_len, trg_batch)
        )

        embd_src = self.dropout(
            self.src_word_embedding(src) + self.src_position_embedding(src_positions)
        )
        embd_trg = self.dropout(
            self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions)
        )

        src_padding_mask = self.make_src_mask(src)
        # Causal mask: a target position may only attend to earlier positions.
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).to(
            trg.device
        )

        out = self.transformer(
            embd_src,
            embd_trg,
            src_key_padding_mask=src_padding_mask,
            # Also hide padded source positions from the decoder's cross-attention;
            # otherwise the output depends on how much padding the batch added.
            memory_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )

        return self.fc_out(out)

In [8]:
# Set up the training phase
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Whether to restore from / write to the checkpoint file
load_model = False
save_model = True


# Training hyperparameters

num_epochs = 5
learning_rate = 3e-5
batch_size = 32

# Model hyperparameters
src_vocab_size = len(german.vocab)
trg_vocab_size = len(english.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3

dropout = 0.10
max_len = 100
forward_expansion = 256
# The source language is German, so the source padding index must be looked up
# in the German vocabulary (not the English one).
src_pad_idx = german.vocab.stoi["<pad>"]

In [9]:
# Tensorboard

# Tensorboard writer logging under runs/loss_plot; `step` is the global step
# counter passed to add_scalar and incremented once per training batch.
writer = SummaryWriter('runs/loss_plot')
step = 0

In [10]:
# Batch iterators for the three splits, yielding batches on `device`.
# sort_key ranks examples by source-sentence length; presumably this groups
# similarly sized sentences so batches need less padding — confirm against the
# torchtext BucketIterator docs.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=batch_size,
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
    device=device,
)

In [11]:
embedding_size

512

In [12]:
# Instantiate the translation model from the hyperparameters above and move its
# parameters to the selected device.
model = Transformer(
    embedding_size=embedding_size,
    src_vocab_size=src_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    forward_expansion=forward_expansion,
    dropout=dropout,
    max_len=max_len,
    device=device,
)
model = model.to(device)

In [13]:
# Loss: cross entropy that skips padded target positions.
pad_idx = english.vocab.stoi["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

# Optimizer, plus a scheduler that scales the learning rate by 0.1 once the
# monitored loss has stopped improving for `patience` scheduler steps.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,
    patience=10,
    verbose=True,
)

In [14]:
if load_model:
    # map_location remaps stored tensors onto the current device, so a checkpoint
    # written on a GPU machine can still be restored on a CPU-only one.
    load_checkpoint(
        torch.load('mycheckpoint.pth.tar', map_location=device), model, optimizer
    )
    

In [15]:
# Longer example sentence. The trailing space after "großen" matters: adjacent
# string literals are concatenated as-is, and without it the two words fuse
# into a single unknown token.
sentence = (
    "ein boot mit mehreren männern darauf wird von einem großen "
    "pferdegespann ans ufer gezogen."
)

# Shorter example; this assignment overrides the one above and is the sentence
# translated during training.
sentence = "ein pferd geht unter einer brücke neben einem boot."

In [None]:

for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    # Checkpoint at the start of every epoch (captures the previous epoch's weights)
    if save_model:
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

    # Translate the example sentence in eval mode to monitor progress
    model.eval()
    translated_sentence = translate_sentence(
        model, sentence, german, english, device, max_length=50
    )

    print(f"Translated example sentence: \n {translated_sentence}")
    model.train()
    losses = []

    for batch in train_iterator:
        # Get input and targets and move them to the training device
        inp_data = batch.src.to(device)
        target = batch.trg.to(device)

        # Forward prop: the decoder input is the target without its last token
        output = model(inp_data, target[:-1, :])

        # Output is of shape (trg_len, batch_size, output_dim) but Cross Entropy Loss
        # doesn't take input in that form. For example, with MNIST we would want the
        # output to be (N, 10) and the targets just (N). Here we treat it the same
        # way: flatten to (trg_len * batch_size, output_dim) predictions and
        # (trg_len * batch_size) targets, which requires some reshaping.
        # The start token is also dropped from the targets so they line up with
        # the predictions.
        output = output.reshape(-1, output.shape[2])
        target = target[1:].reshape(-1)

        optimizer.zero_grad()

        loss = criterion(output, target)
        # Extract the Python float once; it is used for both the epoch mean and logging
        loss_value = loss.item()
        losses.append(loss_value)

        # Back prop
        loss.backward()
        # Clip to avoid exploding gradient issues, makes sure grads are
        # within a healthy range
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        # Gradient descent step
        optimizer.step()

        # plot to tensorboard
        writer.add_scalar("Training loss", loss_value, global_step=step)
        step += 1

    mean_loss = sum(losses) / len(losses)
    scheduler.step(mean_loss)

# The in-loop checkpoint is written before each epoch trains, so save once more
# to keep the weights produced by the final epoch.
if save_model:
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

# The loop leaves the model in train mode; switch to eval so dropout is
# disabled while scoring.
model.eval()

# running on entire test data takes a while, so only a slice (items 1..99) is scored
score = bleu(test_data[1:100], model, german, english, device)
print(f"Bleu score {score * 100:.2f}")

[Epoch 0 / 5]
=> Saving checkpoint
Translated example sentence: 
 ['cellphone', 'harmonica', 'dog', 'harmonica', 'steeplechase', 'painting', 'belts', 'belts', 'painting', 'squash', 'driveway', 'harmonica', 'aiming', 'jogs', 'countryside', 'belts', 'painting', 'painting', 'painting', 'belts', 'kitten', 'countryside', 'pain', 'against', 'kitten', 'painting', 'harmonica', 'painting', 'countryside', 'tape', 'golfer', 'kitten', 'kitten', 'steeplechase', 'kitten', 'kitten', 'countryside', 'transporting', 'steeplechase', 'kitten', 'countryside', 'pain', 'transporting', 'ford', 'kitten', 'harmonica', 'steeplechase', 'glacier', 'golfer', 'golfer']


### Sample translation

In [None]:
# Example German sentence to translate. The trailing space after "großen" is
# required: adjacent string literals are concatenated without a separator.
sentence = (
    "ein boot mit mehreren männern darauf wird von einem großen "
    "pferdegespann ans ufer gezogen."
)

In [None]:
spacy_ger = spacy.load("de_core_news_sm")

# Inference must run in eval mode so dropout is disabled
model.eval()

# Create tokens using spacy and everything in lower case (which is what our vocab is)
if isinstance(sentence, str):
    tokens = [token.text.lower() for token in spacy_ger(sentence)]
else:
    tokens = [token.lower() for token in sentence]

# Add <SOS> and <EOS> in beginning and end respectively
tokens.insert(0, german.init_token)
tokens.append(german.eos_token)

# Go through each german token and convert to an index
text_to_indices = [german.vocab.stoi[token] for token in tokens]

# Convert to a (src_len, 1) tensor, i.e. a batch containing one sentence
sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

# Greedy decoding: start from <sos> and repeatedly append the most likely next
# token until <eos> is produced or max_len tokens have been generated.
eos_idx = english.vocab.stoi["<eos>"]
outputs = [english.vocab.stoi["<sos>"]]
for _ in range(max_len):
    trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

    with torch.no_grad():
        output = model(sentence_tensor, trg_tensor)
    # Prediction for the last target position of the single batch element
    best_guess = output.argmax(2)[-1, :].item()
    outputs.append(best_guess)

    if best_guess == eos_idx:
        break

translated_sentence = [english.vocab.itos[idx] for idx in outputs]

# Remove the start token, and the end token only if one was actually generated
# (decoding may stop at max_len without ever producing <eos>).
if outputs[-1] == eos_idx:
    translated_tokens = translated_sentence[1:-1]
else:
    translated_tokens = translated_sentence[1:]

print(' '.join(translated_tokens))