In [1]:
from __future__ import unicode_literals, print_function, division

# import basic lib
import re
import math
import random
import string
import unicodedata
import numpy as np
from io import open

# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import optim
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm_

# import helper function
import masked_cross_entropy
from infer_eval import bleu

# check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
SOS_idx = 0
EOS_idx = 1
UNK_idx = 2
PAD_idx = 3

USE_CUDA = True

class Preprocessor:
    '''
    class for preprocessing
    '''
    def __init__(self, name):
        '''
        initialize vocab and counter
        '''
        self.name = name
        self.w2idx = {"<sos>" : 0, "<eos>" : 1, "<unk>" : 2, "<pad>" : 3}
        self.counter = {}
        self.idx2w = {0: "<sos>", 1: "<eos>", 2:"<unk>", 3:"<pad>"}
        self.num = 4

    def SentenceAdder(self, sentence):
        '''
        Add a sentence to dataset
        '''
        for word in sentence.split(' '):
            self.WordAdder(word)

    def WordAdder(self, word):
        '''
        Add single word to dataset and update vocab and counter
        '''
        if word in self.w2idx:
            self.counter[word] += 1
        else:
            self.w2idx[word] = self.num
            self.counter[word] = 1
            self.idx2w[self.num] = word
            self.num += 1
            
    def trim(self, min_count=1):
        '''
        Trim to remove non-frequent word
        '''
        keep = []
        for k, v in self.counter.items():
            if v >= min_count: keep.append(k)
        print(self.name+':')
        print('Total words', len(self.w2idx))
        print('After Trimming', len(keep))
        print('Keep Ratio %', 100 * len(keep) / len(self.w2idx))
        self.w2idx = {"<sos>" : 0, "<eos>" : 1, "<unk>" : 2, "<pad>" : 3}
        self.counter = {}
        self.idx2w = {0: "<sos>", 1: "<eos>", 2:"<unk>", 3:"<pad>"}
        self.num = 4
        for w in keep:
            self.WordAdder(w)

In [3]:
def Uni2Ascii(s):
    '''
    transfer from unicode to ascii
    '''
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                    if unicodedata.category(c) != 'Mn')

def StrCleaner(s):
    '''
    trim, delete non-letter and lowercase string
    '''
    s = Uni2Ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

def DataReader(path, lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open(path, encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    #pairs = [[StrCleaner(s) for s in l.split('<------>')] for l in lines]
    pairs = [[s.lower() for s in l.split('<------>')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Preprocessor(lang2)
        output_lang = Preprocessor(lang1)
    else:
        input_lang = Preprocessor(lang1)
        output_lang = Preprocessor(lang2)

    return input_lang, output_lang, pairs

In [4]:
MIN_LENGTH = 3
MAX_LENGTH = 50

def filterPair(p):
    '''
    Filter to get expected pairs with specific length
    '''
    return MIN_LENGTH <= len(p[0].split(' ')) <= MAX_LENGTH and \
        MIN_LENGTH <= len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [5]:
def prepareData(path, lang1, lang2, reverse=True):
    input_lang, output_lang, pairs = DataReader(path, lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.SentenceAdder(pair[0])
        output_lang.SentenceAdder(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.num)
    print(output_lang.name, output_lang.num)
    return input_lang, output_lang, pairs


src, tgt, pairs = prepareData('data/train.txt', 'english', 'chinese')
src.trim()
tgt.trim()
print(random.choice(pairs))

Reading lines...
Read 90000 sentence pairs
Trimmed to 77083 sentence pairs
Counting words...
Counted words:
chinese 46064
english 30499
chinese:
Total words 46064
After Trimming 46060
Keep Ratio % 99.99131642931573
english:
Total words 30499
After Trimming 30495
Keep Ratio % 99.9868848158956
['周恩来 同志 亲自 制定 中国 援外 八 原则 , 至今 仍是 我们 开展 国际 经济 技术 合作 的 重要 指南 .', "comrade zhou enlai worked out eight principles on supporting foreign countries , which thus far still serve as an important guidance for developing international economic and technological cooperation . comrade xiaoping was the chief architect of china 's reform and opening up ."]


In [6]:
test_src, test_tgt, test_pairs = prepareData('data/test.txt', 'english', 'chinese')
test_src.trim()
test_tgt.trim()
print(random.choice(test_pairs))
test_src.w2idx, test_src.idx2w, test_src.num = src.w2idx, src.idx2w, src.num
test_tgt.w2idx, test_tgt.idx2w, test_tgt.num = tgt.w2idx, tgt.idx2w, tgt.num
test_pairs.sort(key=lambda x: len(x[0].split()))

Reading lines...
Read 10000 sentence pairs
Trimmed to 8572 sentence pairs
Counting words...
Counted words:
chinese 17294
english 13049
chinese:
Total words 17294
After Trimming 17290
Keep Ratio % 99.9768705909564
english:
Total words 13049
After Trimming 13045
Keep Ratio % 99.96934631006208
['加快 国际 通道 建设 .', 'we should step up the building of international communications .']


In [37]:
def sentence2idx(preprocessor, sentence):
    '''
    Read sentence and translate into word index plus eos
    '''
    return [SOS_idx] + [preprocessor.w2idx[w] if w in preprocessor.w2idx \
            else UNK_idx for w in sentence.split(' ')] + [EOS_idx]

def pad(seq, max_len):
    '''
    Add padding to sentence with different length
    '''
    seq += [PAD_idx for i in range(max_len - len(seq))]
    return seq

def random_batch(src, tgt, pairs, batch_size, batch_idx):
    '''
    Randomly generate batch data
    '''
    inputs, target = [], []
    
    # Choose batch
    for s in pairs[batch_idx*batch_size:(batch_idx+1)*batch_size]:
        inputs.append(sentence2idx(src, s[0]))
        target.append(sentence2idx(tgt, s[1]))
        
    # Sort by length
    seq_pairs = sorted(zip(inputs, target), key=lambda p: len(p[0]), reverse=True)
    inputs, target = zip(*seq_pairs)
    
    # Obtain length of each sentence and pad
    input_lens = [len(s) for s in inputs]
    input_max = max(input_lens)
    input_padded = [pad(s, input_max) for s in inputs]
    target_lens = [len(s) for s in target]
    target_max = max(target_lens)
    target_padded = [pad(s, target_max) for s in target]

    # Create Variable
    if USE_CUDA:
        input_vars = Variable(torch.LongTensor(input_padded).cuda()).transpose(0, 1)
        input_lens = Variable(torch.LongTensor(input_lens).cuda())
        target_vars = Variable(torch.LongTensor(target_padded).cuda()).transpose(0, 1)
        target_lens = Variable(torch.LongTensor(target_lens).cuda())
    else:
        input_vars = Variable(torch.LongTensor(input_padded)).transpose(0, 1)
        input_lens = Variable(torch.LongTensor(input_lens))
        target_vars = Variable(torch.LongTensor(target_padded)).transpose(0, 1)
        target_lens = Variable(torch.LongTensor(target_lens))

    return input_vars, input_lens, target_vars, target_lens

def user_input(inputs, src):
    inp_list = [sentence2idx(src, inputs)]
    input_lens = [len(s) for s in inp_list]
    input_max = max(input_lens)
    input_padded = [pad(s, input_max) for s in inp_list]
    if USE_CUDA:
        input_vars = Variable(torch.LongTensor(input_padded).cuda()).transpose(0, 1)
        input_lens = Variable(torch.LongTensor(input_lens).cuda())
    else:
        input_vars = Variable(torch.LongTensor(input_padded)).transpose(0, 1)
        input_lens = Variable(torch.LongTensor(input_lens))
    return input_vars, input_lens

In [8]:
class Encoder(nn.Module):
    '''
    Define encoder and forward process
    '''
    def __init__(self, dim_input, dim_embed, dim_hidden, num_layers, dropout):
        super(Encoder, self).__init__()
        self.dim_input = dim_input
        self.dim_hidden = dim_hidden
        self.dim_embed = dim_embed
        self.embed = nn.Embedding(dim_input, dim_embed)
        self.cell = nn.GRU(dim_embed, dim_hidden, 
                          num_layers, dropout=dropout, 
                          bidirectional=True)
        
    def init_hidden(self):
        if USE_CUDA:
            return Variable(torch.zeros(self.n_layers, 1, self.hidden_size).cuda())
        else:
            return Variable(torch.zeros(self.n_layers, 1, self.hidden_size))
        
    def forward(self, inputs, inputs_lens, hidden=None):
        '''
        We need to sum the outputs since bi-diretional is used
        '''
        embedded = self.embed(inputs)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, inputs_lens)
        outputs, hidden = self.cell(packed, hidden)
        outputs, output_lengths = nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:, :, :self.dim_hidden] + \
                    outputs[:, :, self.dim_hidden:]
        return outputs, hidden


class Attention(nn.Module):
    '''
    Define attention mechanism
    '''
    def __init__(self, dim_hidden):
        super(Attention, self).__init__()
        self.dim_hidden = dim_hidden
        # 2*dim_hidden is needed since bi-direction is used
        self.attn = nn.Linear(2*self.dim_hidden, dim_hidden)
        self.v = nn.Parameter(torch.rand(dim_hidden))
        stdv = 1. / math.sqrt(self.v.size(0))
        self.v.data.uniform_(-stdv, stdv)

    def forward(self, hidden, encoder_outputs):
        timestep = encoder_outputs.size(0)
        h = hidden.repeat(timestep, 1, 1).transpose(0, 1)
        encoder_outputs = encoder_outputs.transpose(0, 1)
        scores = self.score(h, encoder_outputs)
        return F.relu(scores).unsqueeze(1)

    def score(self, hidden, encoder_outputs):
        e = F.softmax(self.attn(torch.cat([hidden, encoder_outputs], 2)),dim=1)
        e = e.transpose(1, 2)
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)
        e = torch.bmm(v, e)
        return e.squeeze(1)


class Decoder(nn.Module):
    '''
    Define decoder with attention
    '''
    def __init__(self, dim_embed, dim_hidden, dim_output, num_layers, dropout):
        super(Decoder, self).__init__()
        self.dim_embed = dim_embed
        self.dim_hidden = dim_hidden
        self.dim_output = dim_output
        self.num_layers = num_layers

        self.embed = nn.Embedding(dim_output, dim_embed)
        self.dropout = nn.Dropout(dropout, inplace=True)
        self.attention = Attention(dim_hidden)
        self.cell = nn.GRU(dim_hidden + dim_embed, dim_hidden,
                          num_layers, dropout=dropout)
        self.out = nn.Linear(2*dim_hidden, dim_output)

    def forward(self, inputs, last_hidden, encoder_outputs):
        
        embedded = self.embed(inputs).unsqueeze(0)  # (1,B,N)
        embedded = self.dropout(embedded)
        
        attn_weights = self.attention(last_hidden[-1], encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # (B,1,N)
        context = context.transpose(0, 1)  # (1,B,N)
        
        rnn_input = torch.cat([embedded, context], 2)
        output, hidden = self.cell(rnn_input, last_hidden)
        output = output.squeeze(0)  # (1,B,N) -> (B,N)
        context = context.squeeze(0)
        output = self.out(torch.cat([output, context], 1))
        output = F.log_softmax(output, dim=1)
        return output, hidden, attn_weights

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, src_len, tgt, tgt_len, teacher_forcing_ratio=0.5):
        batch_size = src.size(1)
        max_len = tgt.size(0)
        vocab_size = self.decoder.dim_output
        if USE_CUDA:
            outputs = Variable(torch.zeros(max_len, batch_size, vocab_size).cuda())
        else:
            outputs = Variable(torch.zeros(max_len, batch_size, vocab_size))
        encoder_output, hidden = self.encoder(src, src_len)
        hidden = hidden[:self.decoder.num_layers]
        # Put <sos> at first position
        if USE_CUDA:
            output = Variable(tgt.data[0, :].cuda())
        else:
            output = Variable(tgt.data[0, :])
        for t in range(1, max_len):
            output, hidden, attn_weights = self.decoder(
                    output, hidden, encoder_output)
            outputs[t] = output
            # Randomly choose whether to use teacher force or not
            is_teacher = random.random() < teacher_forcing_ratio
            top1 = output.data.max(1)[1]
            if USE_CUDA:
                output = Variable(tgt.data[t].cuda() if is_teacher else top1.cuda())
            else:
                output = Variable(tgt.data[t] if is_teacher else top1)
        return outputs
    
    def inference(self, src, src_len, max_len = MAX_LENGTH):
        pred_idx = []
        batch_size = src.size(1)
        vocab_size = self.decoder.dim_output
        if USE_CUDA:
            outputs = Variable(torch.zeros(max_len, batch_size, vocab_size).cuda())
        else:
            outputs = Variable(torch.zeros(max_len, batch_size, vocab_size))
        
        encoder_output, hidden = self.encoder(src, src_len)
        hidden = hidden[:self.decoder.num_layers]
        # Put <sos> at first position
        if USE_CUDA:
            output = Variable(src.data[0, :].cuda())
        else:
            output = Variable(src.data[0, :])
        pred_idx.append(SOS_idx)
        for t in range(1, max_len):
            output, hidden, attn_weights = self.decoder(
                    output, hidden, encoder_output)
            outputs[t] = output
            top1 = output.data.max(1)[1]
            pred_idx.append(top1.item())
            if USE_CUDA:
                output = Variable(top1.cuda())
            else:
                output = Variable(top1)
            if top1 == EOS_idx: break
        return outputs, pred_idx

In [9]:
batch_size = 64
hidden_size = 512
embed_size = 256
encoder_n_layers = 2
decoder_n_layers = 1
encoder_test = Encoder(src.num, embed_size, hidden_size, encoder_n_layers, dropout=0.2)
decoder_test = Decoder(embed_size, hidden_size, tgt.num, decoder_n_layers, dropout=0.2)

  "num_layers={}".format(dropout, num_layers))


In [10]:
net = Seq2Seq(encoder_test,decoder_test).cuda()
print(net)

Seq2Seq(
  (encoder): Encoder(
    (embed): Embedding(46064, 256)
    (cell): GRU(256, 512, num_layers=2, dropout=0.2, bidirectional=True)
  )
  (decoder): Decoder(
    (embed): Embedding(30499, 256)
    (dropout): Dropout(p=0.2, inplace)
    (attention): Attention(
      (attn): Linear(in_features=1024, out_features=512, bias=True)
    )
    (cell): GRU(768, 512, dropout=0.2)
    (out): Linear(in_features=1024, out_features=30499, bias=True)
  )
)


In [11]:
opt = optim.Adam(net.parameters(),lr=0.0001)

In [12]:
net.load_state_dict(torch.load('./saved_model_1.pt'))

In [None]:
grad_clip = 5
num_batch = len(pairs) // batch_size
print_every_batches = 200
save_every_batches = 200
valid_every_epochs = print_every_batches
pairs.sort(key=lambda x: len(x[0].split()))

for epoch in range(1,50000):
    total_loss = 0
    tmp_loss = 0
    for batch_idx in range(num_batch):
        input_batches, input_lengths,\
            target_batches, target_lengths = random_batch(src,tgt,pairs,batch_size,batch_idx)
        
        opt.zero_grad()
        output = net(input_batches, input_lengths, target_batches, target_lengths)

        #loss = masked_cross_entropy.compute_loss(
        #    output.transpose(0, 1).contiguous(),
        #    target_batches.transpose(0, 1).contiguous(),
        #    target_lengths
        #)
        loss = F.nll_loss(output[1:].view(-1,tgt.num),
                          target_batches[1:].contiguous().view(-1),
                          ignore_index=PAD_idx)
        
        tmp_loss += loss.item()
        if (batch_idx+1) % save_every_batches == 0:
            torch.save(net.state_dict(), './saved.pt')
        clip_grad_norm_(net.parameters(), grad_clip)
        loss.backward()
        opt.step()
        if (batch_idx+1) % print_every_batches == 0:
            opt.zero_grad()
            input_batches, input_lengths,\
                target_batches, target_lengths = random_batch(src,tgt,pairs,batch_size,batch_idx)
            for test_idx in range(1):
                _, pred = net.inference(input_batches[:,test_idx].reshape(input_lengths[0].item(),1),input_lengths[0].reshape(1))
                inp = ' '.join([src.idx2w[t] for t in input_batches[:,test_idx].cpu().numpy()])
                mt = ' '.join([tgt.idx2w[t] for t in pred if t!= PAD_idx])
                ref = ' '.join([tgt.idx2w[t] for t in target_batches[:,test_idx].cpu().numpy()])
                print('INPUT:\n' + inp)
                print('REF:\n' + ref)
                print('PREDICTION:\n' + mt)
                print('BLEU = %f' % bleu([mt],[[ref]],4))
                print("------")
            output = net(input_batches, input_lengths, target_batches, target_lengths)

            #loss = masked_cross_entropy.compute_loss(
            #    output.transpose(0, 1).contiguous(),
            #    target_batches.transpose(0, 1).contiguous(),
            #    target_lengths
            #)
            loss = F.nll_loss(output[1:].view(-1,tgt.num),
                              target_batches[1:].contiguous().view(-1),
                              ignore_index=PAD_idx)
            print("Epoch %d Batch Num %d Train Loss: %f Test Loss: %f"%(epoch, batch_idx+1, tmp_loss/print_every_batches, loss.item()))
            tmp_loss = 0

    print('epoch %d finished !'%(epoch))
    random.shuffle(pairs)

INPUT:
<sos> 这是 我国 几千 年 来 最 深刻 , 最 伟大 的 社会变革 . <eos>
REF:
<sos> this has been the most profound and the greatest social change that has ever taken place in china over the past several thousand years . <eos> <pad> <pad> <pad>
PREDICTION:
<sos> this has been the most profound and the greatest that the has ever in china 's social and social years . <eos>
BLEU = 0.388644
------
Epoch 1 Batch Num 200 Train Loss: 1.425621 Test Loss: 1.522334
INPUT:
<sos> 这次 会议 是 总政治部 委托 北京地区 军队 院校 协作 中心 在 装备 指挥 技术 学院 召开 的 . <eos>
REF:
<sos> this forum was held in the armament command skills academy by the beijing area military academy coordinating center with the authorization of the general political department . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
PREDICTION:
<sos> this forum was held by the armament department of the armament military academy coordinating skills , coordinating the general political department , coordinating the coordinating of th

Epoch 2 Batch Num 1200 Train Loss: 2.732009 Test Loss: 2.522732
epoch 2 finished !
INPUT:
<sos> 据介绍 , 这 两 大 频道 的 开设 , 体现 了 中央 电视台 对 频道 整体 布局 的 日臻 全面 , 形成 了 一个 以 一套 节目 为 龙头 , 以 综合 频道 , 专业 频道 门类 日趋 齐全 为 特征 的 网络 体系 . <eos>
REF:
<sos> it is learned that the launch of the two channels shows cctv 's increasingly comprehensive overall programming . it has established a network with program one as its principal channel and included increasingly complete categories of general and special channels in its programming . <eos> <pad> <pad> <pad> <pad>
PREDICTION:
<sos> it has learned that the establishment of the central system has increasingly cctv , it has increasingly increasingly a comprehensive channel of the overall channels of its overall channels and its overall channels of its overall channels and its overall channels . <eos>
BLEU = 0.181869
------
Epoch 3 Batch Num 200 Train Loss: 2.642128 Test Loss: 2.458873


# Check Inference result of train data

In [30]:
pairs.sort(key=lambda x: len(x[0].split()))

In [33]:
from infer_eval import bleu

test_range = 50
ave_bleu = 0
for test_idx in range(test_range):
    input_batches, input_lengths,\
        target_batches, target_lengths = random_batch(src,tgt,pairs,1,test_idx)
    _, pred = net.inference(input_batches[:,0].reshape(input_lengths[0].item(),1),input_lengths[0].reshape(1))
    inp = ' '.join([src.idx2w[t] for t in input_batches[:,0].cpu().numpy()])
    mt = ' '.join([tgt.idx2w[t] for t in pred if t!= PAD_idx])
    ref = ' '.join([tgt.idx2w[t] for t in target_batches[:,0].cpu().numpy() if t != PAD_idx])
    print('INPUT:\n' + inp)
    print('REF:\n' + ref)
    print('PREDICTION:\n' + mt)
    tmp_score = bleu([mt],[[ref]],4)
    ave_bleu += tmp_score
    print('BLEU = %f' % tmp_score)
    print("------")
print('Average BLEU = '+str(ave_bleu/test_range))

INPUT:
<sos> 江苏 宜兴市 . <eos>
REF:
<sos> yixing city , jiangsu . <eos>
PREDICTION:
<sos> yixing city , jiangsu was <eos>
BLEU = 0.643459
------
INPUT:
<sos> 总理 朱鎔基 二零零一年八月二日 <eos>
REF:
<sos> prime minister zhu rongji 2 august 2001 <eos>
PREDICTION:
<sos> prime minister zhu rongji 2 august 2001 <eos>
BLEU = 1.000000
------
INPUT:
<sos> 再说 科索沃 . <eos>
REF:
<sos> or take kosovo . <eos>
PREDICTION:
<sos> take take kosovo . <eos>
BLEU = 0.537285
------
INPUT:
<sos> 1995年 离休 . <eos>
REF:
<sos> he retired in 1995 . <eos>
PREDICTION:
<sos> in 1995 and the . <eos>
BLEU = 0.698534
------
INPUT:
<sos> 谢谢 各位 ! <eos>
REF:
<sos> thank you ! <eos>
PREDICTION:
<sos> thank you ! <eos>
BLEU = 1.000000
------
INPUT:
<sos> 杜青林 说 . <eos>
REF:
<sos> du qinglin said . <eos>
PREDICTION:
<sos> du qinglin said . <eos>
BLEU = 1.000000
------
INPUT:
<sos> ( 王永霞 ) <eos>
REF:
<sos> ( by wang yongxia ) <eos>
PREDICTION:
<sos> ( by wang yongxia ) <eos>
BLEU = 1.000000
------
INPUT:
<sos> 朱鎔基 问 . <eos>
REF:
<sos> zhu ro

# Check inference result of test data

In [34]:
test_src, test_tgt, test_pairs = prepareData('data/test.txt', 'english', 'chinese')
test_src.trim()
test_tgt.trim()
print(random.choice(test_pairs))

Reading lines...
Read 10000 sentence pairs
Trimmed to 8572 sentence pairs
Counting words...
Counted words:
chinese 17294
english 13049
chinese:
Total words 17294
After Trimming 17290
Keep Ratio % 99.9768705909564
english:
Total words 13049
After Trimming 13045
Keep Ratio % 99.96934631006208
['曾宪梓 说 : " 我 的 立场 没有 改变 , 我 的 立场 是 坚决 拥护 中央政府 的 政策 , 这点 很 明确 .', 'tsang said : " i have not changed my position , which is to resolutely the central government policy .']


In [35]:
test_src.w2idx, test_src.idx2w, test_src.num = src.w2idx, src.idx2w, src.num
test_tgt.w2idx, test_tgt.idx2w, test_tgt.num = tgt.w2idx, tgt.idx2w, tgt.num
test_pairs.sort(key=lambda x: len(x[0].split()))

In [36]:
from infer_eval import bleu

test_range = 50
ave_bleu = 0
for test_idx in range(test_range):
    input_batches, input_lengths,\
        target_batches, target_lengths = random_batch(src,tgt,test_pairs,1,test_idx)
    _, pred = net.inference(input_batches[:,0].reshape(input_lengths[0].item(),1),input_lengths[0].reshape(1))
    inp = ' '.join([test_src.idx2w[t] for t in input_batches[:,0].cpu().numpy()])
    mt = ' '.join([test_tgt.idx2w[t] for t in pred if t!= PAD_idx])
    ref = ' '.join([test_tgt.idx2w[t] for t in target_batches[:,0].cpu().numpy() if t != PAD_idx])
    print('INPUT:\n' + inp)
    print('REF:\n' + ref)
    print('PREDICTION:\n' + mt)
    tmp_score = bleu([mt],[[ref]],4)
    ave_bleu += tmp_score
    print('BLEU = %f' % tmp_score)
    print("------")
print('Average BLEU = '+str(ave_bleu/test_range))

INPUT:
<sos> 时光 <unk> . <eos>
REF:
<sos> time really flies like a shuttle . <eos>
PREDICTION:
<sos> as a total of whole . <eos>
BLEU = 0.456227
------
INPUT:
<sos> 分担 经费 . <eos>
REF:
<sos> sharing the expense . <eos>
PREDICTION:
<sos> there are for funds . <eos>
BLEU = 0.516973
------
INPUT:
<sos> 中国 ! " <eos>
REF:
<sos> china ! " <eos>
PREDICTION:
<sos> " china ! ! " <eos>
BLEU = 0.516973
------
INPUT:
<sos> 提高 素质 迫在眉睫 <eos>
REF:
<sos> this is the only way to succeed . <eos>
PREDICTION:
<sos> raising the quality of raising their urgent . <eos>
BLEU = 0.459150
------
INPUT:
<sos> 权责 一体 . <eos>
REF:
<sos> integration of right and responsibility . <eos>
PREDICTION:
<sos> there are no . <eos>
BLEU = 0.402935
------
INPUT:
<sos> 显然 不是 . <eos>
REF:
<sos> obvious not . <eos>
PREDICTION:
<sos> obviously , not a . <eos>
BLEU = 0.555524
------
INPUT:
<sos> 你们 一定 能行 ! <eos>
REF:
<sos> you surely can do it ! <eos>
PREDICTION:
<sos> you you you you ! <eos>
BLEU = 0.572688
------
INPUT:
<sos> 王兆国 主

# User Input Translation

In [57]:
inputs = '国家主席 : 江泽民 .'
user_var, user_len = user_input(inputs, src)
_, pred = net.inference(user_var[:,0].reshape(user_len[0].item(),1),user_len[0].reshape(1))
inp = ' '.join([src.idx2w[t] for t in user_var[:,0].cpu().numpy()])
mt = ' '.join([tgt.idx2w[t] for t in pred if t!= PAD_idx])
print('INPUT:\n' + inp)
print('PREDICTION:\n' + mt)

INPUT:
<sos> 国家主席 : 江泽民 . <eos>
PREDICTION:
<sos> president jiang zemin . <eos>
