## Seq2Se2: Neural Machine Translation with Attention

### 1. Preparing Data

In [1]:
import torch
import torchtext
import torch.nn as nn
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.data import functional
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm
import sys
import random
import spacy

In [2]:
de_tokenizer = get_tokenizer('spacy', 'de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', 'en_core_web_sm')

def de_yield_tokens(train_iter):
    for de_text, _ in train_iter:
        yield de_tokenizer(de_text[:-1].lower())
        
def en_yield_tokens(train_iter):
    for _, en_text in train_iter:
        yield en_tokenizer(en_text[:-1].lower())
        
special_tokens = ["<unk>", "<pad>", "<sos>", "<eos>"]

train_iter = Multi30k(split=("train"))
source_vocab = build_vocab_from_iterator(de_yield_tokens(train_iter), 
                                         min_freq = 1,
                                         specials = special_tokens)
source_vocab.set_default_index(source_vocab["<unk>"])

train_iter = Multi30k(split=("train"))
target_vocab = build_vocab_from_iterator(en_yield_tokens(train_iter),
                                        min_freq = 1,
                                        specials = special_tokens)
target_vocab.set_default_index(target_vocab["<unk>"])

In [3]:
# Load data
train_iter, valid_iter, test_iter = Multi30k()

en_text_pipeline = lambda x:  target_vocab(["<sos>"] + en_tokenizer(x)[:-1] + ["<eos>"]) 
de_text_pipeline = lambda x: source_vocab(["<sos>"] + de_tokenizer(x)[:-1] + ["<eos>"])

BATCH_SIZE = 100

def collate_batch_input(batch):
    source_list, target_list = [], []
    for source_text, target_text in batch:  
        text_seq = de_text_pipeline(source_text.lower()) 
        source_list.append(torch.tensor(text_seq, dtype=torch.int64))
        text_seq = en_text_pipeline(target_text.lower())
        target_list.append(torch.tensor(text_seq[:-1], dtype=torch.int64))
               
    source_tensor = pad_sequence(source_list, batch_first=True, padding_value=source_vocab["<pad>"])
    target_tensor = pad_sequence(target_list, batch_first=True, padding_value=source_vocab["<pad>"])
    
    return source_tensor, target_tensor, 


train_dataset = functional.to_map_style_dataset(train_iter)
valid_dataset = functional.to_map_style_dataset(valid_iter)
test_dataset = functional.to_map_style_dataset(test_iter)

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE,
                          shuffle = True, collate_fn = collate_batch_input)
valid_loader = DataLoader(valid_dataset, batch_size = BATCH_SIZE,
                          shuffle=False, collate_fn = collate_batch_input)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE,
                          shuffle=False, collate_fn = collate_batch_input)

### 2. Define Model

#### Encoder

<img src="img/seq2seq8.png"/>

In [4]:
class Encoder(nn.Module):
    def __init__(self, input_dim, embed_dim, enc_hidden_dim, dec_hidden_dim, pad_index):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim, padding_idx=pad_index)
        self.gru = nn.GRU(embed_dim, enc_hidden_dim, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(enc_hidden_dim * 2, dec_hidden_dim)
        
    def forward(self, src):
        # src = [batch_size, src_len]
        
        embedded = self.embedding(src)
        # embedded = [batch_size, src_len, embed_dim]
        
        outputs, hidden = self.gru(embedded)
        # outputs = [batch_size, src_len, 2 * enc_hidden_dim] : bidirectional
        # hidden = [2, batch_size, enc_hidden_dim]
        
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
        # hidden = [batch_size, dec_hidden_dim]
        
        return outputs, hidden

#### Attention

<img src="img/seq2seq9.png"/>

In [5]:
class Attention(nn.Module):
    def __init__(self, enc_hidden_dim, dec_hidden_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear((enc_hidden_dim * 2) + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Linear(dec_hidden_dim, 1, bias=False)
        
    def forward(self, hidden, encoder_outputs):
        
        # hidden = [batch_size, dec_hidden_dim]
        # encoder_outputs = [batch_size, src_len, enc_hidden_dim * 2]
        
        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]
        
        # repeat decoder hidden state src_len times
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        # hidden = [batch_size, src_len, dec_hidden_dim]
        
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        # energy = [batch, src_len, dec_hidden_dim]
        
        attention = self.v(energy).squeeze(2)
        # attention = [batch_size, src_len]
        
        return F.softmax(attention, dim=1)

#### Decoder

<img src="img/seq2seq10.png" />



In [6]:
class Decoder(nn.Module):
    def __init__(self, output_dim, embed_dim, enc_hidden_dim, dec_hidden_dim, pad_index, 
                 attention):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, embed_dim, padding_idx=pad_index)
        self.gru = nn.GRU((enc_hidden_dim * 2) + embed_dim, dec_hidden_dim, batch_first=True)
        self.fc = nn.Linear((enc_hidden_dim * 2) + dec_hidden_dim + embed_dim, output_dim)
        
    def forward(self, input, hidden, encoder_outputs):
        # input = [batch_size]
        # hidden = [batch_size, dec_hidden_dim]
        # encoder_outputs = [batch_size, src_len, enc_hidden_dim * 2]
        
        input = input.unsqueeze(1)
        # input = [batch_size, 1]
        
        embedded = self.embedding(input)
        # embedded = [batch_size, 1, embed_dim]
        
        a = self.attention(hidden, encoder_outputs)
        # a = [batch_size, src_len]
        a = a.unsqueeze(1)
        # a = [batch_size, 1, src_len]
        
        weighted = torch.bmm(a, encoder_outputs)
        # weighted = [batch_size, 1, enc_hidden_dim * 2]
                
        gru_input = torch.cat((embedded, weighted), dim = 2)
        # gru_input = [batch_size, 1, enc_hidden_dim * 2 + embed_dim]
        
        output, hidden = self.gru(gru_input, hidden.unsqueeze(0))
        # output = [batch_size, 1, dec_hidden_dim]
        # hidden = [1, batch_size, dec_hidden_size]
                
        assert (output.permute(1,0,2) == hidden).all()
        
        embedded = embedded.squeeze(1)
        # embedded = [batch_size, embed_dim]        
        output = output.squeeze(1)
        # weighted = [batch_size, enc_hidden_dim]
        weighted = weighted.squeeze(1)
        # weighted = [batch_size, enc_hidden_dim*2]
        
        prediction = self.fc(torch.cat((output, weighted, embedded), dim = 1))
        # prediction = [batch_size, output_dim]
        
        return prediction, hidden.squeeze(0)

#### Seq2Seq

In [7]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, src, trg, teacher_force_ratio = 0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim
        
        # tensor to store decoder outputs
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size)
        
        encoder_outputs, hidden = self.encoder(src)
                
        input = trg[:,0]
        
        for t in range(1, trg_len):
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            output = output.unsqueeze(1)
            outputs[:, t, :] = output[:, 0, :]
            best_guess = output.argmax(2).squeeze(1)
            input = trg[:,t] if random.random() < teacher_force_ratio else best_guess
            
        return outputs

In [8]:
INPUT_DIM = len(source_vocab)
OUTPUT_DIM = len(target_vocab)
EMBED_DIM = 50
ENC_HIDDEN_DIM = 60
DEC_HIDDEN_DIM = 60

src_pad_idx = source_vocab["<pad>"]
trg_pad_idx = target_vocab["<pad>"]

attention = Attention(ENC_HIDDEN_DIM, DEC_HIDDEN_DIM)
enc = Encoder(INPUT_DIM, EMBED_DIM, ENC_HIDDEN_DIM, DEC_HIDDEN_DIM, src_pad_idx)
dec = Decoder(OUTPUT_DIM, EMBED_DIM, ENC_HIDDEN_DIM, DEC_HIDDEN_DIM, trg_pad_idx, attention)

model = Seq2Seq(enc, dec)

In [9]:
# Calculate the number of trainable parameters in the model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 3,786,105 trainable parameters


In [10]:
def train(model, dataloader):
    model.train()
    epoch_loss = 0
    
    for src, trg in tqdm(dataloader, desc='training...', file=sys.stdout):
        optimizer.zero_grad()
        
        output = model(src, trg)
        
        output_dim = output.shape[-1]

        output = output.reshape(-1, output_dim)
        trg = trg.reshape(-1)
        
        loss = criterion(output, trg)
        
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        epoch_loss += loss.item()
    
    return epoch_loss/len(dataloader)


def evaluate(model, dataloader):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for src, trg in dataloader:
            output = model(src, trg)
            output_dim = output.shape[-1]
            output = output.reshape(-1, output_dim)
            trg = trg.reshape(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss/len(dataloader)  

In [23]:
N_EPOCHS = 10

optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index = trg_pad_idx)

for epoch in range(1, N_EPOCHS+1):
    train_loss = train(model, train_loader)
    val_loss = evaluate(model, valid_loader)
    print(f"| Epoch: {epoch}/{N_EPOCHS} | Train Loss: {train_loss} | Val Loss: {val_loss}")
    
eval_loss = evaluate(model, test_loader)
print("="*60)
print(eval_loss)    

torch.save(model, "ml_model_attention.pt")

training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [16:59<00:00,  3.52s/it]
| Epoch: 1/10 | Train Loss: 4.826547466475388 | Val Loss: 4.3754730658097705
training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [17:27<00:00,  3.61s/it]
| Epoch: 2/10 | Train Loss: 4.2147707651401385 | Val Loss: 4.046748529780995
training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [17:48<00:00,  3.68s/it]
| Epoch: 3/10 | Train Loss: 3.8770462669175245 | Val Loss: 3.8506043390794233
training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [17:33<00:00,  3.63s/it]
| Epoch: 4/10 | Train Loss: 3.634876620358434 | Val Loss: 3.7206129377538506
training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [17:38<00:00,  3.65s/it]
| Epoch: 5/10 | Train Loss: 3.4526881069972597 | Val Loss: 3.669716639952226
training...: 1

In [12]:
# model = torch.load("ml_model_attention.pt")

In [24]:
def translate_sentence(model, sentence):
    model.eval()
    tokens = ["<sos>"] + de_tokenizer(sentence.lower()) + ["<eos>"]
    sequence = source_vocab(tokens)
    sent_tensor = torch.LongTensor(sequence).unsqueeze(0)
    print(sent_tensor)
    
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(sent_tensor)
        
    outputs = target_vocab(["<sos>"])
        
    for _ in range(len(sequence)):
        previous_word = torch.LongTensor([outputs[-1]])
        with torch.no_grad():
            output, hidden = model.decoder(previous_word, hidden, encoder_outputs)
            best_guess = output.argmax(1).item()
            outputs.append(best_guess)
            if output.argmax(1).item() == source_vocab["<eos>"]:
                break
                
    tranlated_sent = target_vocab.lookup_tokens(outputs)
    return ' '.join(tranlated_sent[1:])


In [35]:
sentence = 'Ein kleines Mädchen mit einem Diadem, das jemandem auf dem Schoß sitzt und etwas isst.'
print(translate_sentence(model, sentence))

tensor([[   2,    5,   66,   25,   11,    6, 3178,    9,   39,  690,   12,   24,
          649,   31,   10,   75,  238,    4,    3]])
a little girl with a tiara is sitting on her lap and sitting and something . . . .


In [34]:
train_dataset[650]

('Ein kleines Mädchen mit einem Diadem, das jemandem auf dem Schoß sitzt und etwas isst.\n',
 'A little girl with a tiara eating in someones lap.\n')

#### Reference
1. https://github.com/bentrevett