In [1]:
import os
import sys

notebook_dir = os.path.dirname(os.path.abspath('__file__'))
sys.path.append(os.path.join(notebook_dir, '..'))

In [2]:
import pickle
from src.dataset import Vocabulary, ViTokenizer, EnTokenizer, MTDataset
from tqdm import tqdm
import time

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

In [3]:
with open('../ckpts/en_vocab.pkl', 'rb') as f:
    en_vocab = pickle.load(f)

with open('../ckpts/vi_vocab.pkl', 'rb') as f:
    vi_vocab = pickle.load(f)

en_vocab.n_words, vi_vocab.n_words

(32971, 29811)

In [4]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        # embedding layer
        self.embedding = nn.Embedding(input_size, hidden_size)
        # GRU cell
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

In [5]:
# Decoder
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
    
    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None, max_length=10, sos_token=2, device='cpu'):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(sos_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(max_length):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if target_tensor is not None:
                decoder_input = target_tensor[:, i].unsqueeze(1)
            else:
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden, None

In [6]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)
    
    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)
        return context, weights
    
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)
    
    def forward_step(self, input, hidden, encoder_outputs):
        embedded = self.dropout(self.embedding(input))
        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)
        
        return output, hidden, attn_weights
    
    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None, max_length=10, sos_token=2, device='cpu'):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(sos_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(max_length):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input forcing
                decoder_input = target_tensor[:, i].unsqueeze(1)
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

In [7]:
def read_mt_dataset(type: str):
    # type : "train", "dev" or "test"
    folder_path = os.path.join("..", "data")
    en_file_path = os.path.join(folder_path, type, f"{type}.en")
    vi_file_path = os.path.join(folder_path, type, f"{type}.vi")
    with open(en_file_path, "r") as f:
        en_dataset = f.readlines()
        en_dataset = [line.strip() for line in en_dataset]
    with open(vi_file_path, "r") as f:
        vi_dataset = f.readlines()
        vi_dataset = [line.strip() for line in vi_dataset]
    return en_dataset, vi_dataset

In [8]:
def prepare_dataset():
    vi_tokenizer = ViTokenizer()
    en_tokenizer = EnTokenizer()

    train_en_sens, train_vi_sens = read_mt_dataset("train")
    val_en_sens, val_vi_sens = read_mt_dataset("dev")

    train_en_sens = train_en_sens[:100000]
    train_vi_sens = train_vi_sens[:100000]

    train_en_sens = [en_tokenizer.tokenize(sent) for sent in tqdm(train_en_sens)]
    train_vi_sens = [vi_tokenizer.tokenize(sent) for sent in tqdm(train_vi_sens)]

    val_en_sens = [en_tokenizer.tokenize(sent) for sent in tqdm(val_en_sens)]
    val_vi_sens = [vi_tokenizer.tokenize(sent) for sent in tqdm(val_vi_sens)]

    global en_vocab, vi_vocab

    train_dataset = MTDataset(
        inputs = [en_vocab.sentence_to_index(sent) for sent in train_en_sens],
        outputs= [vi_vocab.sentence_to_index(sent) for sent in train_vi_sens],
        max_length=50,
    )

    val_dataset = MTDataset(
        inputs = [en_vocab.sentence_to_index(sent) for sent in val_en_sens],
        outputs= [vi_vocab.sentence_to_index(sent) for sent in val_vi_sens],
        max_length=50,
    )

    return train_dataset, val_dataset

In [32]:
def train_epoch(
    encoder: nn.Module,
    decoder: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    encoder_optimizer: optim.Optimizer,
    decoder_optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: str='cpu',
):
    encoder.train()
    decoder.train()
    total_loss = 0
    for i, (inputs, outputs) in enumerate(dataloader):
        inputs = inputs.to(device)
        outputs = outputs.to(device)

        encoder_outputs, encoder_hidden = encoder(inputs)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, outputs, max_length=outputs.size(1))

        # print(decoder_outputs.size())
        # print(outputs.size())

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            outputs.view(-1),
        )

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

        print(loss.item())

    return total_loss / len(dataloader)

In [20]:
import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [21]:
def train(
    encoder: nn.Module,
    decoder: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    n_epochs: int=10,
    lr: float=0.001,
    print_every: int=2,
    plot_every: int=2,
    device: str='cpu',
):
    encoder.train()
    decoder.train()
    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=lr)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=lr)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(
            encoder, decoder, train_loader,
            encoder_optimizer, decoder_optimizer, criterion
        )
        print_loss_total += loss
        plot_loss_total += loss
        
        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))
        
        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
    
    return plot_losses

In [12]:
train_dataset, val_dataset = prepare_dataset()

100%|██████████| 100000/100000 [00:06<00:00, 15050.40it/s]
100%|██████████| 100000/100000 [00:10<00:00, 9332.05it/s]
100%|██████████| 18719/18719 [00:01<00:00, 16938.25it/s]
100%|██████████| 18719/18719 [00:01<00:00, 11722.99it/s]


In [13]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

In [14]:
X, y = next(iter(train_loader))
X.shape, y.shape

(torch.Size([64, 50]), torch.Size([64, 50]))

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [None]:
hidden_size = 256

encoder = EncoderRNN(en_vocab.n_words, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, vi_vocab.n_words).to(device)

losses = train(encoder, decoder, train_loader, 1, print_every=1, plot_every=1, device=device)

10.370078086853027
9.850102424621582
9.371803283691406
8.836904525756836
8.406574249267578
7.8715739250183105
7.491021633148193
7.1359992027282715
6.939545154571533
6.82679557800293
6.540205478668213
6.3052778244018555
6.010951519012451
5.901511192321777
6.061260223388672
5.1821608543396
4.954463481903076
4.847982883453369
4.688955307006836
4.824389934539795
4.397134780883789
4.372640132904053
3.969245195388794
3.9526941776275635
3.910526752471924
3.8788366317749023
4.019517421722412
3.6609811782836914
3.726144790649414
3.734893798828125
3.5194473266601562
3.478846788406372
3.5124731063842773
3.1087441444396973
3.438567876815796
3.709352493286133
3.3611693382263184
3.546229839324951
3.347602605819702
3.101534128189087
2.869824171066284
2.959648370742798
3.1138906478881836
3.2847185134887695
3.2812137603759766
2.9202442169189453
2.9269304275512695
2.8253705501556396
2.9735934734344482
3.2509396076202393
2.8569347858428955
2.9640698432922363
2.997912645339966
3.0266201496124268
3.0470914

KeyboardInterrupt: 

In [None]:
# evaluate
def evaluate(
    encoder: nn.Module, 
    decoder: nn.Module, 
    sentence: str,
    en_tokenizer: EnTokenizer,
):
    global en_vocab, vi_vocab
    encoder.eval()
    decoder.eval()
    with torch.inference_mode():
        inputs = en_tokenizer.tokenize(sentence)
        # print(inputs)
        inputs = en_vocab.sentence_to_index(inputs)
        inputs = torch.tensor(inputs, dtype=torch.long, device=device).unsqueeze(0)
        # print(inputs)
        
        encoder_outputs, encoder_hidden = encoder(inputs)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == vi_vocab.word_to_index(vi_vocab.EOS_TOKEN):
                decoded_words.append(vi_vocab.EOS_TOKEN)
                break
            decoded_words.append(vi_vocab.index_to_word(idx.item()))
    return decoded_words, decoder_attn

In [50]:
train_dataset[1]

(tensor([    2,   866,  4307, 15825,     4, 20936,     4,     8,   442,    13,
         25877,   463,    98,  1397,    36,     5,  3418,  8625,   976,  7294,
             7,     3,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]),
 tensor([    2,   998,  2093,     4,   518,   993,     4,    71,  5490,     4,
           373,   205,   434,     4,    60, 10404,     4,    13,     9,   348,
           499,   814,  1277,  2620,     5,     3,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]))

In [51]:
en_vocab.indexes_to_sentence([    2,    34,  1771,    23,     8, 18180,     7,     3])

['<SOS>', 'It', 'begins', 'with', 'a', 'countdown', '.', '<EOS>']

In [52]:
en_tokenizer = EnTokenizer()
evaluate(encoder, decoder, "It begins with a countdown.", en_tokenizer)

['It', 'begins', 'with', 'a', 'countdown', '.']
tensor([[    2,    34,  1771,    23,     8, 18180,     7,     3]])


(['<SOS>',
  'Tôi',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>',
  '<PAD>'],
 tensor([[[0.0155, 0.0153, 0.0810, 0.2713, 0.3418, 0.0640, 0.0480, 0.1632],
          [0.0171, 0.0166, 0.0818, 0.2666, 0.3399, 0.0650, 0.0495, 0.1634],
          [0.0185, 0.0180, 0.0833, 0.2627, 0.3352, 0.0665, 0.0521, 0.1637],
          [0.0270, 0.0265, 0.0873, 0.2467, 0.3065, 0.0735, 0.0638, 0.1685],
          [0.0605, 0.0589, 0.1086, 0.1989, 0.2242, 0.0978, 0.0933, 0.1578],
          [0.0693, 0.0675, 0.1119, 0.1873, 0.2067, 0.1025, 0.0999, 0.1550],
          [0.0712, 0.0694, 0.1125, 0.1848, 0.2031, 0.1034, 0.1012, 0.1543],
          [0.0717, 0.0699, 0.1127, 0.1841, 0.2022, 0.1037, 0.1015, 0.1542],
          [0.0719, 0.0701, 0.1128, 0.1838, 0.2018, 0.1038, 0.1016, 0.1541],
          [0.0720, 0.0702, 0.1128, 0.1837, 0.2016, 0.1038, 0.1017, 0.1541]]]))