## Basic Seq2Seq Model: Machine Translation

<img src="img/seq2seq1.png"/>

In [1]:
import torch
import torchtext
import torch.nn as nn
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.data import functional
from torch import optim
from tqdm import tqdm
import sys
import random

### 1. Prepare Data

#### Build Vocabulary for source and target languages

In [2]:
de_tokenizer = get_tokenizer('spacy', 'de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', 'en_core_web_sm')

def de_yield_tokens(train_iter):
    for de_text, _ in train_iter:
        yield de_tokenizer(de_text[:-1].lower())
        
def en_yield_tokens(train_iter):
    for _, en_text in train_iter:
        yield en_tokenizer(en_text[:-1].lower())
        
special_tokens = ["<unk>", "<pad>", "sos", "eos"]

train_iter = Multi30k(split=("train"))
source_vocab = build_vocab_from_iterator(de_yield_tokens(train_iter), 
                                         min_freq = 1,
                                         specials = special_tokens)
source_vocab.set_default_index(source_vocab["<unk>"])

train_iter = Multi30k(split=("train"))
target_vocab = build_vocab_from_iterator(en_yield_tokens(train_iter),
                                        min_freq = 1,
                                        specials = special_tokens)
target_vocab.set_default_index(target_vocab["<unk>"])

#### Build Dataloader

In [3]:
# Load data
train_iter, valid_iter, test_iter = Multi30k()

en_text_pipeline = lambda x:  target_vocab(["sos"] + en_tokenizer(x)[:-1] + ["eos"]) 
de_text_pipeline = lambda x: source_vocab(["sos"] + de_tokenizer(x)[:-1] + ["eos"])

BATCH_SIZE = 100

def collate_batch_input(batch):
    source_list, target_list = [], []
    for source_text, target_text in batch:  
        text_seq = de_text_pipeline(source_text.lower()) 
        source_list.append(torch.tensor(text_seq, dtype=torch.int64))
        text_seq = en_text_pipeline(target_text.lower())
        target_list.append(torch.tensor(text_seq[:-1], dtype=torch.int64))
               
    source_tensor = pad_sequence(source_list, batch_first=True, padding_value=source_vocab["<pad>"])
    target_tensor = pad_sequence(target_list, batch_first=True, padding_value=source_vocab["<pad>"])
    
    return source_tensor, target_tensor, 


train_dataset = functional.to_map_style_dataset(train_iter)
valid_dataset = functional.to_map_style_dataset(valid_iter)
test_dataset = functional.to_map_style_dataset(test_iter)

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE,
                          shuffle = True, collate_fn = collate_batch_input)
valid_loader = DataLoader(valid_dataset, batch_size = BATCH_SIZE,
                          shuffle=False, collate_fn = collate_batch_input)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE,
                          shuffle=False, collate_fn = collate_batch_input)

### 2. Define Model

#### Encoder

<img src="img/seq2seq2.png"/>

In [4]:
class Encoder(nn.Module):
    def __init__(self, input_dim, embed_dim, hidden_dim, pad_idx,
                n_layers=1, dropout=0.5):
        super(Encoder, self).__init__()
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(input_dim, embed_dim, padding_idx = pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, n_layers, dropout=dropout, batch_first = True)
        
    def forward(self, src):
        embedded = self.embedding(src)
        output, (hidden, cell) = self.lstm(embedded)
        return hidden, cell

#### Decoder

<img src="img/seq2seq3.png"/>

In [5]:
class Decoder(nn.Module):
    def __init__(self, output_dim, embed_dim, hidden_dim, pad_idx,
                n_layers=1, dropout=0.5):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.n_layers= n_layers
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(output_dim, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, n_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, input, last_hidden, last_cell):
        input = input.unsqueeze(1)
        embedded = self.embedding(input)
        outputs, (hidden, cell) = self.lstm(embedded, (last_hidden, last_cell))
        fc_out = self.fc(outputs.squeeze(0))
        return fc_out, hidden, cell
    

#### Seq2Seq

<img src="img/seq2seq4.png"/>

In [6]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        
        assert encoder.hidden_dim == decoder.hidden_dim, \
            "Hidden dimensions of encoder and decoder must be equal!"
        assert encoder.n_layers == decoder.n_layers, \
            "Number of layers of encoder and decoder must be equal!"
        
    def forward(self, src, trg, teacher_force_ratio = 0.5):
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim
        
        # tensor to store decoder outputs
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size)

        hidden, cell = self.encoder(src)
        
        input = trg[:,0]
        
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[:, t, :] = output[:, 0, :]
            best_guess = output.argmax(2).squeeze(1)
            input = trg[:,t] if random.random() < teacher_force_ratio else best_guess
            
        return outputs

In [39]:
INPUT_DIM = len(source_vocab)
OUTPUT_DIM = len(target_vocab)
EMBED_DIM = 50
HIDDEN_DIM = 60
N_LAYERS = 1
DROPOUT = 0.5
src_pad_idx = source_vocab["<pad>"]
trg_pad_idx = target_vocab["<pad>"]

enc = Encoder(INPUT_DIM, EMBED_DIM, HIDDEN_DIM, src_pad_idx, N_LAYERS, DROPOUT)
dec = Decoder(OUTPUT_DIM, EMBED_DIM, HIDDEN_DIM, trg_pad_idx, N_LAYERS, DROPOUT)

model = Seq2Seq(enc, dec)

In [8]:
# Calculate the number of trainable parameters in the model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 2,074,455 trainable parameters


In [9]:
def train(model, dataloader):
    model.train()
    epoch_loss = 0
    
    for src, trg in tqdm(dataloader, desc='training...', file=sys.stdout):
        optimizer.zero_grad()
        
        output = model(src, trg)
        
        output_dim = output.shape[-1]

        output = output.reshape(-1, output_dim)
        trg = trg.reshape(-1)
        
        loss = criterion(output, trg)
        
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        epoch_loss += loss.item()
    
    return epoch_loss/len(dataloader)


def evaluate(model, dataloader):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for src, trg in dataloader:
            output = model(src, trg)
            output_dim = output.shape[-1]
            output = output.reshape(-1, output_dim)
            trg = trg.reshape(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss/len(dataloader)  

In [10]:
N_EPOCHS = 10

optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index = trg_pad_idx)

for epoch in range(1, N_EPOCHS+1):
    train_loss = train(model, train_loader)
    val_loss = evaluate(model, valid_loader)
    print(f"| Epoch: {epoch}/{N_EPOCHS} | Train Loss: {train_loss} | Val Loss: {val_loss}")
    
eval_loss = evaluate(model, test_loader)
print("="*60)
print(eval_loss)    

torch.save(model, "ml_model.pt")

training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [11:14<00:00,  2.32s/it]
| Epoch: 1/10 | Train Loss: 6.262484307124697 | Val Loss: 5.749391685832631
training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [09:45<00:00,  2.02s/it]
| Epoch: 2/10 | Train Loss: 5.619860931922649 | Val Loss: 5.468654025684703
training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [09:49<00:00,  2.03s/it]
| Epoch: 3/10 | Train Loss: 5.356939120128237 | Val Loss: 5.252150795676491
training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [09:52<00:00,  2.04s/it]
| Epoch: 4/10 | Train Loss: 5.161158211477872 | Val Loss: 5.122170535000888
training...: 100%|███████████████████████████████████████████████████████████████████| 290/290 [10:02<00:00,  2.08s/it]
| Epoch: 5/10 | Train Loss: 5.0203484732529216 | Val Loss: 4.961549758911133
training...: 100%|█

In [44]:
def translate_sentence(model, sentence):  
    model.eval()
    tokens = ["sos"] + de_tokenizer(sentence.lower()) + ["eos"]
    sequence = source_vocab(tokens)
    sent_tensor = torch.LongTensor(sequence).unsqueeze(0)
    
    with torch.no_grad():
        hidden, cell = model.encoder(sent_tensor)
        
    outputs = target_vocab(["sos"])
        
    for _ in range(len(sequence)):
        previous_word = torch.LongTensor([outputs[-1]])
        
        with torch.no_grad():
            output, hidden, cell = model.decoder(previous_word, hidden, cell)
            best_guess = output.argmax(1).item()
            outputs.append(best_guess)
            if output.argmax(1).item() == source_vocab["eos"]:
                break
                
    tranlated_sent = target_vocab.lookup_tokens(outputs)
    return ' '.join(tranlated_sent[1:])


In [45]:
model = torch.load("ml_model.pt")

sentence = 'Eine Frau mit schwarzen Haaren in gestreiftem Oberteil steht vor ein paar Ständen.'
print(translate_sentence(model, sentence))

a woman in a red shirt and a a shirt is a a a . .


In [36]:
train_dataset[600]

('Eine Frau mit schwarzen Haaren in gestreiftem Oberteil steht vor ein paar Ständen.\n',
 'A woman in a striped shirt and black hair stands facing some booths.\n')

#### Reference
1. https://github.com/bentrevett