In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import numpy as np
import time

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.utils.data import Dataset

from nltk.translate.bleu_score import corpus_bleu

# Seed the Python, NumPy and PyTorch RNGs so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Prepare data

In [2]:
PAD_IDX = 0
UNK_IDX = 1
SOS_IDX = 2
EOS_IDX = 3

class Language:
    """Word <-> index vocabulary for one language.

    Indices 0-3 are reserved for the special tokens <PAD>, <UNK>, <SOS>, <EOS>
    (PAD_IDX, UNK_IDX, SOS_IDX, EOS_IDX); every other word gets the next free index.
    """
    def __init__(self, name):
        self.name = name
        self.idx2word = ["<PAD>", "<UNK>", "<SOS>", "<EOS>"]
        # word -> index, built as the inverse of idx2word so the two always agree.
        self.word2idx = {word: idx for idx, word in enumerate(self.idx2word)}
        self.sentence_list = []  # not used by this class's methods; kept for compatibility

    def build_vocab(self, sentence_list):
        """Add every word of ``sentence_list`` (a list of token lists) to the vocabulary.

        Words are numbered in order of first appearance, so indices are reproducible
        across runs (iterating a set of strings is not, because of hash randomization).
        Words already present -- including the <SOS>/<EOS> markers that read_data puts
        around each sentence -- keep their existing index, so word2idx["<SOS>"] stays
        SOS_IDX. Calling it again only appends words not seen before.
        """
        for sentence in sentence_list:
            for word in sentence:
                if word not in self.word2idx:
                    self.word2idx[word] = len(self.idx2word)
                    self.idx2word.append(word)

In [3]:
# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicode2ascii(s):
    """Drop combining marks after NFD decomposition (e.g. 'café' -> 'cafe').

    Characters that have no ASCII decomposition (e.g. CJK) are kept unchanged, so the
    result is not guaranteed to be pure ASCII.
    """
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalize_string(s):
    """Lowercase, ASCII-fold and clean a string for whitespace tokenization.

    Sentence punctuation (. ! ?) is kept but separated from the preceding text by a
    space; every other run of non-letter characters becomes a single space. The
    result is stripped at the end, because the substitutions can themselves create
    leading/trailing spaces (e.g. "(hi)" -> " hi ").
    """
    s = unicode2ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s.strip()

In [4]:
def read_data(language_name, path, data_type):
    """Read one tokenized split and wrap every sentence in <SOS> ... <EOS>.

    Reads ``path + '<data_type>.tok.<language_name>'`` (e.g. ``train.tok.en``) as UTF-8,
    one sentence per line.

    English lines are passed through normalize_string as a whole and then split on
    whitespace, so punctuation that normalization separates becomes its own token and
    no token contains whitespace or is empty. Lines in any other language are split on
    whitespace unchanged.

    Returns:
        A list of token lists, one per line.
    """
    file_name = path + '%s.tok.%s' % (data_type, language_name)
    # `with` closes the file handle, which the previous one-liner leaked.
    with open(file_name, encoding='utf-8') as data_file:
        file_lines = data_file.read().strip().split('\n')

    sentence_list = []
    for line in file_lines:
        tokens = normalize_string(line).split() if language_name == "en" else line.split()
        sentence_list.append(["<SOS>"] + tokens + ["<EOS>"])
    return sentence_list

path = "./iwslt-zh-en/"
# Element 0 is the source (zh) side of the parallel corpus, element 1 the target (en) side.
train_data = [read_data(lang_code, path, "train") for lang_code in ("zh", "en")]

In [5]:
# One vocabulary per side of the parallel corpus: zh is the source, en the target.
source_language, target_language = Language("zh"), Language("en")
source_language.build_vocab(train_data[0])
target_language.build_vocab(train_data[1])

### Dataset and Dataloader

In [6]:
class LanguageDataset(Dataset):
    """Parallel corpus of tokenized sentence pairs, served as vocabulary indices.

    Args:
        dataset: ``[source_sentences, target_sentences]``, two equally long lists of
            token lists (as produced by read_data).
        source_vocab, target_vocab: optional objects with a ``word2idx`` dict (e.g.
            Language). When omitted, the notebook-level ``source_language`` /
            ``target_language`` are looked up each time an item is fetched.
    """
    def __init__(self, dataset, source_vocab=None, target_vocab=None):
        self.source_sentence = dataset[0]
        self.target_sentence = dataset[1]
        assert len(self.source_sentence) == len(self.target_sentence)
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab

    def __len__(self):
        return len(self.source_sentence)

    def __getitem__(self, idx):
        """Return ``((source_ids, target_ids), (source_len, target_len))``.

        Words missing from a vocabulary are mapped to UNK_IDX.
        """
        # Resolve the globals lazily so LanguageDataset(train_data) behaves as before.
        source_vocab = source_language if self.source_vocab is None else self.source_vocab
        target_vocab = target_language if self.target_vocab is None else self.target_vocab

        source_idx_list = [source_vocab.word2idx.get(cur_word, UNK_IDX)
                           for cur_word in self.source_sentence[idx]]
        target_idx_list = [target_vocab.word2idx.get(cur_word, UNK_IDX)
                           for cur_word in self.target_sentence[idx]]
        return ((source_idx_list, target_idx_list), (len(source_idx_list), len(target_idx_list)))

In [8]:
MAX_SENTENCE_LENGTH = 200
BATCH_SIZE = 32

def padding(batch):
    """Collate dataset items into fixed-width, right-padded int64 tensors.

    Each item is ``((source_ids, target_ids), (source_len, target_len))`` as returned by
    LanguageDataset.__getitem__. Pairs where either side is longer than
    MAX_SENTENCE_LENGTH are dropped, so the output can hold fewer rows than the input
    (possibly zero).

    Returns:
        ``((source, target), (source_lengths, target_lengths))`` where source/target have
        shape (kept, MAX_SENTENCE_LENGTH) and are padded with 0 (PAD_IDX), and the length
        tensors have shape (kept,). All four are int64, even when nothing is kept.
    """
    kept = [item for item in batch
            if item[1][0] <= MAX_SENTENCE_LENGTH and item[1][1] <= MAX_SENTENCE_LENGTH]

    # Preallocate with an explicit dtype: building from an empty list would otherwise
    # yield 1-D float64 tensors that nn.Embedding cannot consume.
    padded_source = np.zeros((len(kept), MAX_SENTENCE_LENGTH), dtype=np.int64)
    padded_target = np.zeros((len(kept), MAX_SENTENCE_LENGTH), dtype=np.int64)
    source_lengths = np.zeros(len(kept), dtype=np.int64)
    target_lengths = np.zeros(len(kept), dtype=np.int64)

    for row, ((source_ids, target_ids), (source_len, target_len)) in enumerate(kept):
        padded_source[row, :source_len] = source_ids
        padded_target[row, :target_len] = target_ids
        source_lengths[row] = source_len
        target_lengths[row] = target_len

    return ((torch.from_numpy(padded_source), torch.from_numpy(padded_target)),
            (torch.from_numpy(source_lengths), torch.from_numpy(target_lengths)))


# The loader needs the dataset object; without this line train_dataset is undefined
# on a fresh kernel.
train_dataset = LanguageDataset(train_data)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=padding,  # drops over-long pairs and pads each side to MAX_SENTENCE_LENGTH
    # NOTE(review): shuffle=False keeps batches in file order; shuffle=True is usual for SGD.
    shuffle=False,
)

### Encoder

In [9]:
class EncoderRNN(nn.Module):
    """GRU encoder over padded, index-encoded source sentences (batch_first).

    Args:
        embed_dim: size of the token embeddings.
        hidden_dim: size of the GRU hidden state.
        layer_num: number of stacked GRU layers.
        vocab_size: number of rows in the embedding table.
        batch_size: unused; kept so existing constructor calls keep working.
    """
    def __init__(self, embed_dim, hidden_dim, layer_num, vocab_size, batch_size):
        super(EncoderRNN, self).__init__()

        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.layer_num = layer_num
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        # num_layers must equal layer_num, otherwise the (layer_num, B, H) state from
        # init_hidden is rejected by the GRU whenever layer_num != 1.
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=layer_num,
                          batch_first=True, bidirectional=False)

    def init_hidden(self, batch_size, device=None):
        """Return a zero initial state of shape (layer_num, batch_size, hidden_dim).

        Zeros (instead of random noise) keep the encoder deterministic, e.g. at
        evaluation time. ``device`` defaults to the device of the module's parameters.
        """
        if device is None:
            device = self.embedding.weight.device
        return torch.zeros(self.layer_num, batch_size, self.hidden_dim, device=device)

    def forward(self, sentence_list, sentence_length_list):
        """Encode a batch.

        Args:
            sentence_list: (batch, seq_len) tensor of token indices.
            sentence_length_list: true length of each row, as a list or a tensor on any device.

        Returns:
            outputs: (batch, longest_length_in_batch, hidden_dim) GRU outputs, zero past
                each row's length.
            hidden: (layer_num, batch, hidden_dim) final hidden state, in input row order.
        """
        batch_size, _ = sentence_list.size()
        # pack_padded_sequence requires the lengths on the CPU.
        lengths = torch.as_tensor(sentence_length_list).cpu()
        # enforce_sorted=False also accepts unsorted batches; batches already sorted by
        # descending length (as batch_sort produces) give the same result as before.
        embed = pack(self.embedding(sentence_list), lengths, batch_first=True, enforce_sorted=False)
        hidden = self.init_hidden(batch_size, device=sentence_list.device)
        packed_outputs, hidden = self.gru(embed, hidden)
        outputs, _ = unpack(packed_outputs, batch_first=True)

        return outputs, hidden
        

In [10]:
def batch_sort(src_sentence_list, src_length_list, tgt_sentence_list):
    """Reorder a batch by descending source length.

    This is the order pack_padded_sequence expects with its default enforce_sorted=True.

    Args:
        src_sentence_list: padded source batch, indexed along its first dimension.
        src_length_list: 1-D source lengths (torch.Tensor or NumPy array).
        tgt_sentence_list: padded target batch, reordered in step with the source.

    Returns:
        ``[sorted_sources, sorted_lengths, sorted_targets]``.
    """
    if isinstance(src_length_list, torch.Tensor):
        # Sort natively: np.argsort on a Tensor only works through NumPy's array
        # conversion fallback, which fails for tensors on a CUDA device.
        sort_idx = torch.sort(src_length_list, descending=True)[1]
    else:
        src_length_list = np.asarray(src_length_list)
        # Stable, so rows of equal length keep their relative order.
        sort_idx = np.argsort(-src_length_list, kind="stable")
    return [src_sentence_list[sort_idx], src_length_list[sort_idx], tgt_sentence_list[sort_idx]]

### Decoder without attention

In [11]:
class DecoderRNN(nn.Module):
    """GRU decoder without attention, run one step at a time on batch-first input.

    Args:
        embed_dim: size of the target-token embeddings.
        hidden_dim: size of the GRU hidden state; must match the encoder state it is seeded with.
        decoder_out_dim: number of output classes (normally the target vocabulary size).
        layer_num: number of stacked GRU layers; must match the incoming hidden state.
        vocab_size: number of rows in the embedding table.
    """
    def __init__(self, embed_dim, hidden_dim, decoder_out_dim, layer_num, vocab_size):
        super(DecoderRNN, self).__init__()

        self.hidden_dim = hidden_dim
        self.layer_num = layer_num
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # num_layers must equal layer_num so (layer_num, B, H) hidden states are accepted.
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=layer_num,
                          batch_first=True, bidirectional=False)
        self.fc = nn.Linear(hidden_dim, decoder_out_dim)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one decoding step.

        Args:
            input: (batch, steps) tensor of previous tokens, normally steps == 1.
            hidden: (layer_num, batch, hidden_dim) previous hidden state.

        Returns:
            (batch, decoder_out_dim) log-probabilities predicted from the last step, and
            the new hidden state.
        """
        embed = F.relu(self.embedding(input))
        output, hidden = self.gru(embed, hidden)  # output: (batch, steps, hidden_dim)
        # The GRU is batch_first, so output[0] would pick the first *sequence*, not the
        # current time step; take the last step of every sequence instead.
        output = self.softmax(self.fc(output[:, -1]))
        return output, hidden

    def init_hidden(self, batch_size, device=None):
        """Return a zero state of shape (layer_num, batch_size, hidden_dim).

        ``device`` defaults to the device of the module's parameters.
        """
        if device is None:
            device = self.embedding.weight.device
        return torch.zeros(self.layer_num, batch_size, self.hidden_dim, device=device)

### Decoder with attention

In [None]:
class AttentionDecoderRNN(nn.Module):
    """Single-step GRU decoder with attention over a fixed number of encoder positions.

    Attention weights over ``max_length`` positions are predicted from the current input
    embedding and the decoder's hidden state. They are then used to take a weighted sum
    of the encoder outputs, which is combined with the embedding and fed to the GRU.

    Args:
        hidden_size: embedding size and GRU hidden size; must equal the size of the
            encoder outputs and of the incoming hidden state.
        output_size: target vocabulary size.
        dropout_p: dropout probability applied to the input embedding.
        max_length: number of encoder positions the attention layer scores; defaults to
            MAX_SENTENCE_LENGTH, the width padding() pads sentences to.
    """
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_SENTENCE_LENGTH):
        super(AttentionDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: (batch, 1) tensor with the previous target token.
            hidden: (1, batch, hidden_size) previous decoder hidden state.
            encoder_outputs: (batch, src_len, hidden_size), or (src_len, hidden_size) for
                a single sentence, with src_len <= max_length; shorter inputs are
                zero-padded to max_length.

        Returns:
            log-probabilities (batch, output_size), new hidden state
            (1, batch, hidden_size), and attention weights (batch, max_length).

        Raises:
            ValueError: if encoder_outputs has more than max_length positions.
        """
        if encoder_outputs.dim() == 2:
            encoder_outputs = encoder_outputs.unsqueeze(0)
        src_len = encoder_outputs.size(1)
        if src_len > self.max_length:
            raise ValueError("encoder_outputs has %d positions but max_length is %d"
                             % (src_len, self.max_length))
        if src_len < self.max_length:
            # The attention layer always scores max_length slots; give each slot a vector.
            encoder_outputs = F.pad(encoder_outputs, (0, 0, 0, self.max_length - src_len))

        embedded = self.dropout(self.embedding(input))                # (batch, 1, hidden)

        attn_input = torch.cat((embedded[:, 0], hidden[-1]), dim=1)  # (batch, 2 * hidden)
        attn_weights = F.softmax(self.attn(attn_input), dim=1)       # (batch, max_length)
        attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # (batch, 1, hidden)

        output = self.attn_combine(torch.cat((embedded, attn_applied), dim=2))  # (batch, 1, hidden)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[:, 0]), dim=1)        # (batch, output_size)
        return output, hidden, attn_weights

    def init_hidden(self, batch_size=1, device=None):
        """Return a zero state of shape (1, batch_size, hidden_size).

        ``device`` defaults to the device of the module's parameters.
        """
        if device is None:
            device = self.embedding.weight.device
        return torch.zeros(1, batch_size, self.hidden_size, device=device)

### Training the Model

In [None]:
EPOCH_NUM = 15
log_step = 100

# NOTE(review): the models, loss and learning rate used below were never created in this
# notebook. These are placeholder hyper-parameters so the cell runs after Restart & Run All;
# tune them before relying on the results.
EMBED_DIM = 256
HIDDEN_DIM = 256
LAYER_NUM = 1
learning_rate = 0.01

source_vocab_size = len(source_language.idx2word)
target_vocab_size = len(target_language.idx2word)
encoder = EncoderRNN(EMBED_DIM, HIDDEN_DIM, LAYER_NUM, source_vocab_size, BATCH_SIZE).to(device)
decoder = DecoderRNN(EMBED_DIM, HIDDEN_DIM, target_vocab_size, LAYER_NUM, target_vocab_size).to(device)
# Targets are right-padded with 0 (PAD_IDX); padded positions must not contribute to the loss.
criterion = nn.NLLLoss(ignore_index=PAD_IDX)

encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

for epoch in range(EPOCH_NUM):

    encoder.train()
    decoder.train()

    total_loss = 0.0
    batch_count = 0

    for batch, (sentence_pair, length_pair) in enumerate(train_loader):
        # padding() drops over-long pairs, so a batch may be smaller than BATCH_SIZE or empty.
        if sentence_pair[0].size(0) == 0:
            continue

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        src_sentence_list, src_length_list, tgt_sentence_list = batch_sort(sentence_pair[0], length_pair[0], sentence_pair[1])
        # Lengths stay on the CPU: pack_padded_sequence rejects lengths on a CUDA device.
        encoder_output, encoder_hidden = encoder(src_sentence_list.to(device), src_length_list)
        tgt_sentence_list = tgt_sentence_list.to(device=device, dtype=torch.long)

        # Size the <SOS> column from the actual batch, not BATCH_SIZE.
        current_batch_size = tgt_sentence_list.size(0)
        decoder_input = torch.full((current_batch_size, 1), target_language.word2idx['<SOS>'],
                                   dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden

        # Positions past the longest target in this batch are pure padding (and would give a
        # NaN loss once every target is ignored), so stop there instead of MAX_SENTENCE_LENGTH.
        max_target_length = int(length_pair[1].max())
        loss = 0
        for target_length in range(1, max_target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss += criterion(decoder_output, tgt_sentence_list[:, target_length])
            # Teacher forcing: the gold token is the next decoder input.
            decoder_input = tgt_sentence_list[:, target_length].unsqueeze(1)

        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()
        batch_count += 1

        if batch % log_step == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch, loss.item()))

    print('Epoch {} Average Loss {:.4f}'.format(epoch + 1, total_loss / max(batch_count, 1)))
            