In [1]:
import os
import sys
import math
from collections import Counter
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import nltk

In [2]:
# implementation of https://arxiv.org/pdf/1508.04025.pdf

In [3]:
# nltk.download('punkt')

In [4]:
# a simple dataset that does en-cn translation
def load_data(in_file):
    cn = []
    en = []
    num_examples = 0
    with open(in_file, 'r') as f:
        for line in f:
            line = line.strip().split("\t")
            
            en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
            # split chinese sentence into characters for simplification
            cn.append(["BOS"] + [c for c in line[1]] + ["EOS"])
    return en, cn

In [5]:
train_file = "./nmt/en-cn/train.txt"
dev_file = "./nmt/en-cn/dev.txt"
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)

In [6]:
train_en[0:2]

[['BOS', 'anyone', 'can', 'do', 'that', '.', 'EOS'],
 ['BOS', 'how', 'about', 'another', 'piece', 'of', 'cake', '?', 'EOS']]

In [7]:
train_cn[0:2]

[['BOS', '任', '何', '人', '都', '可', '以', '做', '到', '。', 'EOS'],
 ['BOS', '要', '不', '要', '再', '來', '一', '塊', '蛋', '糕', '？', 'EOS']]

In [8]:
# build vocabulary
UNK_IDX = 0
PAD_IDX = 1

def build_dict(sentences, max_words = 50000):
    word_count = Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s] += 1
    vocab = word_count.most_common(max_words) # (word, count)
    total_words = len(vocab) + 2 # plus UNK and PAD
    word_dict = {w[0]: index+2 for index, w in enumerate(vocab)}
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, total_words

In [9]:
en_dict, en_total_words = build_dict(train_en)

In [10]:
cn_dict, cn_total_words = build_dict(train_cn)

In [11]:
inv_en_dict = {v: k for k, v in en_dict.items()}
inv_cn_dict = {v: k for k, v in cn_dict.items()}

In [12]:
# encode sentences into numbers
def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len = True):
    
    length = len(en_sentences)
    out_en_sentences = [[en_dict.get(w, 0) for w in sen] for sen in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, 0) for w in sen] for sen in cn_sentences]
    
    # sort sentences by english length to make sure each batch has similar length
    def len_argsort(seq):
        return sorted(range(len(seq)), key = lambda x : len(seq[x])) # index of smaller length comes first
    
    if sort_by_len:
        sorted_index = len_argsort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]
    
    return out_en_sentences, out_cn_sentences

In [13]:
train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)

In [14]:
# checking
k = 10000
print(" ".join([inv_cn_dict[i] for i in train_cn[k]]))
print(" ".join([inv_en_dict[i] for i in train_en[k]]))

BOS 他 来 这 里 的 目 的 是 什 么 ？ EOS
BOS for what purpose did he come here ? EOS


In [15]:
def get_minibatches(n, minibatch_size, shuffle=True):
    idx_list = np.arange(0, n, minibatch_size)
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx+minibatch_size, n)))
    return minibatches

In [16]:
get_minibatches(100, 10)

[array([20, 21, 22, 23, 24, 25, 26, 27, 28, 29]),
 array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69]),
 array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39]),
 array([90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 array([70, 71, 72, 73, 74, 75, 76, 77, 78, 79]),
 array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
 array([80, 81, 82, 83, 84, 85, 86, 87, 88, 89]),
 array([40, 41, 42, 43, 44, 45, 46, 47, 48, 49]),
 array([50, 51, 52, 53, 54, 55, 56, 57, 58, 59])]

In [17]:
# add padding to every sentence and make every sentence have the same length
def prepare_data(seqs):
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    max_len = np.max(lengths)
    
    x = np.zeros((n_samples, max_len)).astype('int32')
    x_lengths = np.array(lengths).astype('int32')
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths # sentences after padding, actual length before padding

In [18]:
def gen_examples(en_sentences, cn_sentences, batch_size):
    minibatches = get_minibatches(len(en_sentences), batch_size)
    all_examples = []
    for minibatch in minibatches:
        mb_en_sentences = [en_sentences[t] for t in minibatch]
        mb_cn_sentences = [cn_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_en_sentences)
        mb_y, mb_y_len = prepare_data(mb_cn_sentences)
        all_examples.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_examples

In [19]:
batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size)

In [20]:
len(train_data)

228

In [21]:
train_data[0][0].shape

(64, 8)

In [22]:
torch.LongTensor(train_data[0][1]).sort(0, descending=True)

torch.return_types.sort(
values=tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]),
indices=tensor([41, 40, 46, 45, 44, 43, 42, 48, 49, 39, 38, 37, 36, 35, 47, 33, 57, 56,
        62, 61, 60, 59, 58, 31, 50, 55, 54, 53, 52, 51, 34,  0, 32,  8,  7, 13,
        12, 11, 10,  9, 15, 16,  6,  5,  4,  3,  2, 14,  1, 24, 23, 29, 28, 27,
        26, 25, 30, 17, 22, 21, 20, 19, 18, 63]))

In [23]:
dev_data = gen_examples(dev_en, dev_cn, batch_size)

In [24]:
class SimpleEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(SimpleEncoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        '''
        x: (batch_size, max_length_in_batch)
        lengths: tensor of shape (batch_size)
        '''
        sorted_len, sorted_idx = lengths.sort(0, descending=True) # sorted_len: (batch_size), sorted_idx: (batch_size)
        x_sorted = x[sorted_idx.long()] # every sentence is sorted from long to short in each batch now
        embedded = self.dropout(self.embed(x_sorted)) # (batch_size, max_length, hidden_size)
        
        # avoid computing the hidden state for padding using pack_padded_sequence, this function has to sort sentence by length first
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded) # packed_out: (batch_size, max_length, hidden_size), hid: (1, batch, hidden_size)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True) # out: (batch_size, max_length, hidden_size)
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous() # out: (batch_size, max_length, hidden_size)
        hid = hid[:,original_idx.long()].contiguous() # the : selects the first dimention which is 1, index works on batch dimension only
        
        return out, hid[[-1]] # hid[-1]: (batch_size, hidden_size), hid[[-1]]: (1, batch_size, hidden_size)


class SimpleDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(SimpleDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, y, y_lengths, hid):
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()] # (1, batch_size, max_length)

        y_sorted = self.dropout(self.embed(y_sorted))  # (batch_size, max_len, embed_size/hidden_size)

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True) 
        out, hid = self.rnn(packed_seq, hid) # out: (batch_size, max_length, hidden_size), hid: (1, batch, hidden_size)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous() # output_seq: (batch_size, max_length, hidden_size)
        hid = hid[:, original_idx.long()].contiguous() # hid: (1, batch, hidden_size)

        output = F.log_softmax(self.out(output_seq), -1) # (batch_size, max_length, vocab_size)
        
        return output, hid

class SimpleSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(SimpleSeq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self,x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid = self.decoder(y, y_lengths, hid)
        return output, None # placeholder for attention
    
    def translate(self, x, x_lengths, y, max_length = 10):
        encoder_out, hid = self.encoder(x, x_lengths)
        batch_size = x.shape[0]
        preds = []
        for i in range(max_length):
            output, hid = self.decoder(y, torch.ones(batch_size).long().to(y.device), hid)
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
            
        return torch.cat(preds, 1), None

In [25]:
# masked cross entropy loss
class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):
        '''
        input: (batch_size, max_len, vocab_size)
        '''
        input = input.contiguous().view(-1, input.size(2)) # (batch_size * max_len, vocab_size)
        target = target.contiguous().view(-1, 1) # (batch_size * max_len, 1)
        mask = mask.contiguous().view(-1, 1) # (batch_size * max_length, 1)
        output = -input.gather(1, target) * mask # select the part in input where target is equal to 1 and then time mask
        output = torch.sum(output) / torch.sum(mask)

        return output

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout = 0.2
hidden_size = 100
encoder = SimpleEncoder(vocab_size=en_total_words,
                      hidden_size=hidden_size,
                      dropout=dropout)
decoder = SimpleDecoder(vocab_size=cn_total_words,
                      hidden_size=hidden_size,
                      dropout=dropout)
model = SimpleSeq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())

In [28]:
def evaluate(model, data):
    model.eval()
    total_num_words = total_loss = 0.
    with torch.no_grad():
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len-1).to(device).long()

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
    print("Evaluation loss", total_loss/total_num_words)

In [29]:
def train(model, data, num_epochs = 20):
    for epoch in range(num_epochs):
        model.train()
        total_num_words = total_loss = 0
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long() # (batch_size, max_len)
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            # input does not include the last word
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            # expected output is one offset off compared with the mb_input
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len-1).to(device).long()
            
            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len) # (batch_size, max_length, vocab_size)
            
            # (1, max_length) < (batch_size, 1) => (batch_size, max_length)
            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None] 
            # for each row, 1 indicates non-mask, 0 indicates the word is masked (the word is padding)
            mb_out_mask = mb_out_mask.float()
            
            loss = loss_fn(mb_pred, mb_output, mb_out_mask)
            
            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()
            
            if it % 100 == 0:
                print("Epoch", epoch, "iteration", it, "loss", loss.item())
                
        print("Epoch", epoch, "Training loss", total_loss/total_num_words)
        if epoch % 5 == 0:
            evaluate(model, dev_data)
        
train(model, train_data, num_epochs = 30)

Epoch 0 iteration 0 loss 8.110105514526367
Epoch 0 iteration 100 loss 5.299361228942871
Epoch 0 iteration 200 loss 5.251104831695557
Epoch 0 Training loss 5.4507395670388155
Evaluation loss 4.8519430413628495
Epoch 1 iteration 0 loss 4.411635875701904
Epoch 1 iteration 100 loss 4.724141597747803
Epoch 1 iteration 200 loss 4.8460564613342285
Epoch 1 Training loss 4.62219145066978
Epoch 2 iteration 0 loss 3.873622417449951
Epoch 2 iteration 100 loss 4.337270736694336
Epoch 2 iteration 200 loss 4.554322242736816
Epoch 2 Training loss 4.231444733532659
Epoch 3 iteration 0 loss 3.563133955001831
Epoch 3 iteration 100 loss 4.088422775268555
Epoch 3 iteration 200 loss 4.333175182342529
Epoch 3 Training loss 3.9670103747909575
Epoch 4 iteration 0 loss 3.3092098236083984
Epoch 4 iteration 100 loss 3.8547275066375732
Epoch 4 iteration 200 loss 4.176016330718994
Epoch 4 Training loss 3.767932853786133
Epoch 5 iteration 0 loss 3.1728157997131348
Epoch 5 iteration 100 loss 3.6784474849700928
Epoch 

In [30]:
def translate_dev(i):
    en_sent = " ".join([inv_en_dict[w] for w in dev_en[i]])
    print(en_sent)
    cn_sent = " ".join([inv_cn_dict[w] for w in dev_cn[i]])
    print("".join(cn_sent))

    mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
    mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
    bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)

    translation, attn = model.translate(mb_x, mb_x_len, bos)
    translation = [inv_cn_dict[i] for i in translation.data.cpu().numpy().reshape(-1)]
    trans = []
    for word in translation:
        if word != "EOS":
            trans.append(word)
        else:
            break
    print("".join(trans))

for i in range(100,120):
    translate_dev(i)
    print()

BOS you have nice skin . EOS
BOS 你 的 皮 膚 真 好 。 EOS
你有一个好。

BOS you 're UNK correct . EOS
BOS 你 部 分 正 确 。 EOS
你们现在了。

BOS everyone admired his courage . EOS
BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
每个人都知道他的名字

BOS what time is it ? EOS
BOS 几 点 了 ？ EOS
在哪裡？

BOS i 'm free tonight . EOS
BOS 我 今 晚 有 空 。 EOS
我們在那個男孩。

BOS here is your book . EOS
BOS 這 是 你 的 書 。 EOS
这是你的手錶。

BOS they are at lunch . EOS
BOS 他 们 在 吃 午 饭 。 EOS
他們在學校。

BOS this chair is UNK . EOS
BOS 這 把 椅 子 很 UNK 。 EOS
這個房間有一個很好。

BOS it 's pretty heavy . EOS
BOS 它 真 重 。 EOS
它太好。

BOS many attended his funeral . EOS
BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
每个人都在笑。

BOS training will be provided . EOS
BOS 会 有 训 练 。 EOS
這個工作很好。

BOS someone is watching you . EOS
BOS 有 人 在 看 著 你 。 EOS
你有很多人。

BOS i slapped his face . EOS
BOS 我 摑 了 他 的 臉 。 EOS
我的手臂斷了。

BOS i like UNK music . EOS
BOS 我 喜 歡 流 行 音 樂 。 EOS
我喜欢阅读。

BOS tom had no children . EOS
BOS T o m 沒 有 孩 子 。 EOS
汤姆没有任何人都来。

BOS please lock the door . EOS
BOS 請 把 門 鎖 上 。 EOS
請把門關上。

BOS tom has calmed

In [None]:
# implementation of https://arxiv.org/pdf/1508.04025.pdf

In [31]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

    def forward(self, x, lengths):
        '''
        x: (batch_size, max_length_in_batch)
        lengths: tensor of shape (batch_size)
        '''
        sorted_len, sorted_idx = lengths.sort(0, descending=True) # sorted_len: (batch_size), sorted_idx: (batch_size)
        x_sorted = x[sorted_idx.long()] # every sentence is sorted from long to short in each batch now
        embedded = self.dropout(self.embed(x_sorted)) # (batch_size, max_length, embed_size)
        
        # avoid computing the hidden state for padding using pack_padded_sequence, this function has to sort sentence by length first
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded) # packed_out: (batch_size, max_length, enc_hidden_size), hid: (2, batch, enc_hidden_size)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True) # out: (batch_size, max_length, enc_hidden_size)
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous() # out: (batch_size, max_length, enc_hidden_size)
        hid = hid[:,original_idx.long()].contiguous() # the : selects the first dimention which is 2, index works on batch dimension only
        
        hid = torch.cat([hid[-2], hid[-1]], dim = -1) # (batch_size, 2 * enc_hidden_size)
        hid = torch.tanh(self.fc(hid)).unsqueeze(0) # (1, batch_size, dec_hidden_size)
        return out, hid

class Attention(nn.Module):
    def __init__(self, enc_hidden_size, dec_hidden_size):
        super(Attention, self).__init__()

        self.enc_hidden_size = enc_hidden_size
        self.dec_hidden_size = dec_hidden_size

        self.linear_in = nn.Linear(enc_hidden_size*2, dec_hidden_size, bias=False)
        self.linear_out = nn.Linear(enc_hidden_size*2 + dec_hidden_size, dec_hidden_size)
        
    def forward(self, output, context, mask):
        '''
        output: the hidden output of each cell in the decoder, (batch_size, output_len, dec_hidden_size) -> notice in the training step, we feed in actual translated word. So we can compute hidden state for each word rather than feed one by one sequentially
        context: the out of encoder, (batch_size, context_len, 2 * enc_hidden_size)
        '''
        batch_size = output.size(0)
        output_len = output.size(1)
        input_len = context.size(1)
        
        context_in = self.linear_in(context.view(batch_size*input_len, -1)).view(                
            batch_size, input_len, -1) # (batch_size, context_len, dec_hidden_size) -> this is to do W*hs, next step is to dot product with ht (the hidden state from decoder)
        
        attn = torch.bmm(output, context_in.transpose(1,2)) # (batch_size, output_len, context_len)

        attn.data.masked_fill(mask, -1e6) # minimize the effect of padding

        attn = F.softmax(attn, dim=2)  # (batch_size, output_len, context_len)

        context = torch.bmm(attn, context) # (batch_size, output_len, 2 * enc_hidden_size)
        
        output = torch.cat((context, output), dim=2) # (batch_size, output_len, 2 * enc_hidden_size + dec_hidden_size)

        output = output.view(batch_size*output_len, -1) # (batch_size * output_len, 2 * enc_hidden_size + dec_hidden_size)
        output = torch.tanh(self.linear_out(output)) # (batch_size * output_len, dec_hidden_size)
        output = output.view(batch_size, output_len, -1) # (batch_size, output_len, dec_hidden_size)
        return output, attn
        
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)
        self.out = nn.Linear(dec_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    
    def create_mask(self, x_len, y_len):
        # a mask of shape x_len * y_len
        device = x_len.device
        max_x_len = x_len.max()
        max_y_len = y_len.max()
        x_mask = (torch.arange(max_x_len, device=x_len.device)[None, :] < x_len[:, None])
        x_mask = x_mask.float()
        y_mask = (torch.arange(max_y_len, device=x_len.device)[None, :] < y_len[:, None])
        y_mask = y_mask.float()
        mask = (1 - x_mask[:, :, None] * y_mask[:, None, :]).byte()
        return mask # (batch_size, x_len, y_len)
    
    def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()] # (1, batch_size, max_length)

        y_sorted = self.dropout(self.embed(y_sorted))  # (batch_size, max_len, embed_size)

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True) 
        out, hid = self.rnn(packed_seq, hid) # out: (batch_size, max_length, hidden_size), hid: (1, batch, hidden_size)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous() # output_seq: (batch_size, max_length, hidden_size)
        hid = hid[:, original_idx.long()].contiguous() # hid: (1, batch, hidden_size)

        mask = self.create_mask(y_lengths, ctx_lengths)
        output, attn = self.attention(output_seq, ctx, mask)
        output = F.log_softmax(self.out(output), -1) # (batch_size, max_length, vocab_size)
        
        return output, hid, attn


In [32]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid, attn = self.decoder(ctx=encoder_out, 
                    ctx_lengths=x_lengths,
                    y=y,
                    y_lengths=y_lengths,
                    hid=hid)
        return output, attn
    
    def translate(self, x, x_lengths, y, max_length=100):
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        attns = []
        for i in range(max_length):
            output, hid, attn = self.decoder(ctx=encoder_out, 
                    ctx_lengths=x_lengths,
                    y=y,
                    y_lengths=torch.ones(batch_size).long().to(y.device),
                    hid=hid)
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
            attns.append(attn)
        return torch.cat(preds, 1), torch.cat(attns, 1)

In [33]:
dropout = 0.2
embed_size = hidden_size = 100
encoder = Encoder(vocab_size=en_total_words,
                       embed_size=embed_size,
                      enc_hidden_size=hidden_size,
                       dec_hidden_size=hidden_size,
                      dropout=dropout)
decoder = Decoder(vocab_size=cn_total_words,
                      embed_size=embed_size,
                      enc_hidden_size=hidden_size,
                       dec_hidden_size=hidden_size,
                      dropout=dropout)
model = Seq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())

In [35]:
import warnings; warnings.simplefilter('ignore')
train(model, train_data, num_epochs=30)

Epoch 0 iteration 0 loss 2.5341858863830566
Epoch 0 iteration 100 loss 3.253624439239502
Epoch 0 iteration 200 loss 3.719940423965454
Epoch 0 Training loss 3.1440185637054
Evaluation loss 3.361748992043007
Epoch 1 iteration 0 loss 2.454204797744751
Epoch 1 iteration 100 loss 3.115828275680542
Epoch 1 iteration 200 loss 3.635378122329712
Epoch 1 Training loss 3.039223692395219
Epoch 2 iteration 0 loss 2.367974281311035
Epoch 2 iteration 100 loss 3.0356693267822266
Epoch 2 iteration 200 loss 3.524312734603882
Epoch 2 Training loss 2.9415965439987297
Epoch 3 iteration 0 loss 2.325373649597168
Epoch 3 iteration 100 loss 2.894467353820801
Epoch 3 iteration 200 loss 3.4648382663726807
Epoch 3 Training loss 2.859146160381086
Epoch 4 iteration 0 loss 2.252025842666626
Epoch 4 iteration 100 loss 2.8076364994049072
Epoch 4 iteration 200 loss 3.3692665100097656
Epoch 4 Training loss 2.7840369354817716
Epoch 5 iteration 0 loss 2.193591833114624
Epoch 5 iteration 100 loss 2.7523467540740967
Epoch 5

In [36]:
for i in range(100,120):
    translate_dev(i)
    print()

BOS you have nice skin . EOS
BOS 你 的 皮 膚 真 好 。 EOS
你有很苍白。

BOS you 're UNK correct . EOS
BOS 你 部 分 正 确 。 EOS
你真的是重要的。

BOS everyone admired his courage . EOS
BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
每個人都認識他的。

BOS what time is it ? EOS
BOS 几 点 了 ？ EOS
它什么时候到？

BOS i 'm free tonight . EOS
BOS 我 今 晚 有 空 。 EOS
我今晚有好。

BOS here is your book . EOS
BOS 這 是 你 的 書 。 EOS
你的書在書。

BOS they are at lunch . EOS
BOS 他 们 在 吃 午 饭 。 EOS
他們正在吃午餐。

BOS this chair is UNK . EOS
BOS 這 把 椅 子 很 UNK 。 EOS
這件衣服。

BOS it 's pretty heavy . EOS
BOS 它 真 重 。 EOS
它是一樣的。

BOS many attended his funeral . EOS
BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
他的父親都是他的。

BOS training will be provided . EOS
BOS 会 有 训 练 。 EOS
鬱金會將會議。

BOS someone is watching you . EOS
BOS 有 人 在 看 著 你 。 EOS
有人在看你。

BOS i slapped his face . EOS
BOS 我 摑 了 他 的 臉 。 EOS
我把他的手錶了。

BOS i like UNK music . EOS
BOS 我 喜 歡 流 行 音 樂 。 EOS
我喜欢音樂。

BOS tom had no children . EOS
BOS T o m 沒 有 孩 子 。 EOS
汤姆没有人都不能看。

BOS please lock the door . EOS
BOS 請 把 門 鎖 上 。 EOS
請關門。

BOS tom has calme