In [1]:
from __future__ import unicode_literals, print_function, division

# import basic lib
import re
import math
import random
import string
import unicodedata
from io import open

# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import optim
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm_

# import loss func
import masked_cross_entropy

# check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
SOS_idx = 0
EOS_idx = 1
UNK_idx = 2
PAD_idx = 3

USE_CUDA = True

class Preprocessor:
    '''
    class for preprocessing
    '''
    def __init__(self, name):
        '''
        initialize vocab and counter
        '''
        self.name = name
        self.w2idx = {"<sos>" : 0, "<eos>" : 1, "<unk>" : 2, "<pad>" : 3}
        self.counter = {}
        self.idx2w = {0: "<sos>", 1: "<eos>", 2:"<unk>", 3:"<pad>"}
        self.num = 4

    def SentenceAdder(self, sentence):
        '''
        Add a sentence to dataset
        '''
        for word in sentence.split(' '):
            self.WordAdder(word)

    def WordAdder(self, word):
        '''
        Add single word to dataset and update vocab and counter
        '''
        if word in self.w2idx:
            self.counter[word] += 1
        else:
            self.w2idx[word] = self.num
            self.counter[word] = 1
            self.idx2w[self.num] = word
            self.num += 1
            
    def trim(self, min_count=5):
        '''
        Trim to remove non-frequent word
        '''
        keep = []
        for k, v in self.counter.items():
            if v >= min_count: keep.append(k)
        print(self.name+':')
        print('Total words', len(self.w2idx))
        print('After Trimming', len(keep))
        print('Keep Ratio %', 100 * len(keep) / len(self.w2idx))
        self.w2idx = {"<sos>" : 0, "<eos>" : 1, "<unk>" : 2, "<pad>" : 3}
        self.counter = {}
        self.idx2w = {0: "<sos>", 1: "<eos>", 2:"<unk>", 3:"<pad>"}
        self.num = 4
        for w in keep:
            self.WordAdder(w)

In [3]:
def Uni2Ascii(s):
    '''
    transfer from unicode to ascii
    '''
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                    if unicodedata.category(c) != 'Mn')

def StrCleaner(s):
    '''
    trim, delete non-letter and lowercase string
    '''
    s = Uni2Ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

def DataReader(path, lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open(path, encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    #pairs = [[StrCleaner(s) for s in l.split('<------>')] for l in lines]
    pairs = [[s.lower() for s in l.split('<------>')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Preprocessor(lang2)
        output_lang = Preprocessor(lang1)
    else:
        input_lang = Preprocessor(lang1)
        output_lang = Preprocessor(lang2)

    return input_lang, output_lang, pairs

In [4]:
MIN_LENGTH = 10
MAX_LENGTH = 50

def filterPair(p):
    '''
    Filter to get expected pairs with specific length
    '''
    return MIN_LENGTH <= len(p[0].split(' ')) <= MAX_LENGTH and \
        MIN_LENGTH <= len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [5]:
def prepareData(path, lang1, lang2, reverse=True):
    input_lang, output_lang, pairs = DataReader(path, lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.SentenceAdder(pair[0])
        output_lang.SentenceAdder(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.num)
    print(output_lang.name, output_lang.num)
    return input_lang, output_lang, pairs


src, tgt, pairs = prepareData('data/train.txt', 'english', 'chinese')
src.trim()
tgt.trim()
print(random.choice(pairs))

Reading lines...
Read 1800000 sentence pairs
Trimmed to 1428593 sentence pairs
Counting words...
Counted words:
chinese 329850
english 209099
chinese:
Total words 329850
After Trimming 105158
Keep Ratio % 31.880551765954223
english:
Total words 209099
After Trimming 57225
Keep Ratio % 27.367419260732955
['过 成熟 阶段 碳酸盐 生油 岩中 烃类 产物 的 相态 已由油 相 演化 到 湿气 、 干 气相 。 ', 'the phase state of hydrocarbon formed in carbonate source rock at overmatured stage change from oil into wet and dry gas .']


In [6]:
def sentence2idx(preprocessor, sentence):
    '''
    Read sentence and translate into word index plus eos
    '''
    return [SOS_idx] + [preprocessor.w2idx[w] if w in preprocessor.w2idx \
            else UNK_idx for w in sentence.split(' ')] + [EOS_idx]

def pad(seq, max_len):
    '''
    Add padding to sentence with different length
    '''
    seq += [PAD_idx for i in range(max_len - len(seq))]
    return seq

def random_batch(src, tgt, batch_size=5):
    '''
    Randomly generate batch data
    '''
    inputs, target = [], []
    
    # Choose batch randomly
    for _ in range(batch_size):
        pair = random.choice(pairs)
        inputs.append(sentence2idx(src, pair[0]))
        target.append(sentence2idx(tgt, pair[1]))
        
    # Sort by length
    seq_pairs = sorted(zip(inputs, target), key=lambda p: len(p[0]), reverse=True)
    inputs, target = zip(*seq_pairs)
    
    # Obtain length of each sentence and pad
    input_lens = [len(s) for s in inputs]
    input_max = max(input_lens)
    input_padded = [pad(s, input_max) for s in inputs]
    target_lens = [len(s) for s in target]
    target_max = max(target_lens)
    target_padded = [pad(s, target_max) for s in target]

    # Create Variable
    if USE_CUDA:
        input_vars = Variable(torch.LongTensor(input_padded).cuda()).transpose(0, 1)
        input_lens = Variable(torch.LongTensor(input_lens).cuda())
        target_vars = Variable(torch.LongTensor(target_padded).cuda()).transpose(0, 1)
        target_lens = Variable(torch.LongTensor(target_lens).cuda())
    else:
        input_vars = Variable(torch.LongTensor(input_padded)).transpose(0, 1)
        input_lens = Variable(torch.LongTensor(input_lens))
        target_vars = Variable(torch.LongTensor(target_padded)).transpose(0, 1)
        target_lens = Variable(torch.LongTensor(target_lens))

    return input_vars, input_lens, target_vars, target_lens

In [7]:
class Encoder(nn.Module):
    '''
    Define encoder and forward process
    '''
    def __init__(self, dim_input, dim_embed, dim_hidden, num_layers, dropout):
        super(Encoder, self).__init__()
        self.dim_input = dim_input
        self.dim_hidden = dim_hidden
        self.dim_embed = dim_embed
        self.embed = nn.Embedding(dim_input, dim_embed)
        self.cell = nn.GRU(dim_embed, dim_hidden, 
                          num_layers, dropout=dropout, 
                          bidirectional=True)
        
    def forward(self, inputs, inputs_lens, hidden=None):
        '''
        We need to sum the outputs since bi-diretional is used
        '''
        embedded = self.embed(inputs)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, inputs_lens)
        outputs, hidden = self.cell(packed, hidden)
        outputs, output_lengths = nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:, :, :self.dim_hidden] + \
                    outputs[:, :, self.dim_hidden:]
        return outputs, hidden


class Attention(nn.Module):
    '''
    Define attention mechanism
    '''
    def __init__(self, dim_hidden):
        super(Attention, self).__init__()
        self.dim_hidden = dim_hidden
        # 2*dim_hidden is needed since bi-direction is used
        self.attn = nn.Linear(2*self.dim_hidden, dim_hidden)
        self.v = nn.Parameter(torch.rand(dim_hidden))
        stdv = 1. / math.sqrt(self.v.size(0))
        self.v.data.uniform_(-stdv, stdv)

    def forward(self, hidden, encoder_outputs):
        timestep = encoder_outputs.size(0)
        h = hidden.repeat(timestep, 1, 1).transpose(0, 1)
        encoder_outputs = encoder_outputs.transpose(0, 1)
        scores = self.score(h, encoder_outputs)
        return F.relu(scores).unsqueeze(1)

    def score(self, hidden, encoder_outputs):
        e = F.softmax(self.attn(torch.cat([hidden, encoder_outputs], 2)),dim=1)
        e = e.transpose(1, 2)
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)
        e = torch.bmm(v, e)
        return e.squeeze(1)


class Decoder(nn.Module):
    '''
    Define decoder with attention
    '''
    def __init__(self, dim_embed, dim_hidden, dim_output, num_layers, dropout):
        super(Decoder, self).__init__()
        self.dim_embed = dim_embed
        self.dim_hidden = dim_hidden
        self.dim_output = dim_output
        self.num_layers = num_layers

        self.embed = nn.Embedding(dim_output, dim_embed)
        self.dropout = nn.Dropout(dropout, inplace=True)
        self.attention = Attention(dim_hidden)
        self.cell = nn.GRU(dim_hidden + dim_embed, dim_hidden,
                          num_layers, dropout=dropout)
        self.out = nn.Linear(2*dim_hidden, dim_output)

    def forward(self, inputs, last_hidden, encoder_outputs):
        # Get the embedding of the current input word (last output word)
        embedded = self.embed(inputs).unsqueeze(0)  # (1,B,N)
        embedded = self.dropout(embedded)
        # Calculate attention weights and apply to encoder outputs
        attn_weights = self.attention(last_hidden[-1], encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # (B,1,N)
        context = context.transpose(0, 1)  # (1,B,N)
        # Combine embedded input word and attended context, run through RNN
        rnn_input = torch.cat([embedded, context], 2)
        #print(embedded.size())
        #print(context.size())
        #print(rnn_input.size())
        #print(last_hidden.size())
        output, hidden = self.cell(rnn_input, last_hidden)
        output = output.squeeze(0)  # (1,B,N) -> (B,N)
        context = context.squeeze(0)
        # For Debug
        #print(output.size())
        #print(context.size())
        torch.cat([output, context], 1)
        output = self.out(torch.cat([output, context], 1))
        output = F.log_softmax(output, dim=1)
        return output, hidden, attn_weights

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, src_len, tgt, tgt_len, teacher_forcing_ratio=0.5):
        batch_size = src.size(1)
        max_len = tgt.size(0)
        vocab_size = self.decoder.dim_output
        outputs = Variable(torch.zeros(max_len, batch_size, vocab_size).cuda())
        # for debug
        #print(src.size())
        #print(src_len.size())
        encoder_output, hidden = self.encoder(src, src_len)
        hidden = hidden[:self.decoder.num_layers]
        # Put <sos> at first position
        output = Variable(tgt.data[0, :])
        for t in range(1, max_len):
            output, hidden, attn_weights = self.decoder(
                    output, hidden, encoder_output)
            outputs[t] = output
            # Randomly choose whether to use teacher force or not
            is_teacher = random.random() < teacher_forcing_ratio
            top1 = output.data.max(1)[1]
            output = Variable(tgt.data[t].cuda() if is_teacher else top1.cuda())
        return outputs
    
    def inference(self, src, src_len, max_len = MAX_LENGTH):
        pred_idx = []
        batch_size = src.size(1)
        vocab_size = self.decoder.dim_output
        outputs = Variable(torch.zeros(max_len, batch_size, vocab_size)).cuda()

        encoder_output, hidden = self.encoder(src, src_len)
        hidden = hidden[:self.decoder.num_layers]
        # Put <sos> at first position
        output = Variable(src.data[0, :])
        for t in range(1, max_len):
            output, hidden, attn_weights = self.decoder(
                    output, hidden, encoder_output)
            outputs[t] = output
            top1 = output.data.max(1)[1]
            pred_idx.append(top1.item())
            output = Variable(top1).cuda()
            if top1 == EOS_idx: break
        return outputs, pred_idx

In [8]:
batch_size = 3
hidden_size = 5
embed_size = 10
n_layers = 4
encoder_test = Encoder(src.num, embed_size, hidden_size, n_layers, dropout=0.5)
decoder_test = Decoder(embed_size, hidden_size, tgt.num, n_layers, dropout=0.5)

In [9]:
net = Seq2Seq(encoder_test,decoder_test).cuda()
opt = optim.Adam(net.parameters(),lr=0.01)
print(net)

Seq2Seq(
  (encoder): Encoder(
    (embed): Embedding(105162, 10)
    (cell): GRU(10, 5, num_layers=4, dropout=0.5, bidirectional=True)
  )
  (decoder): Decoder(
    (embed): Embedding(57229, 10)
    (dropout): Dropout(p=0.5, inplace)
    (attention): Attention(
      (attn): Linear(in_features=10, out_features=5, bias=True)
    )
    (cell): GRU(15, 5, num_layers=4, dropout=0.5)
    (out): Linear(in_features=10, out_features=57229, bias=True)
  )
)


In [10]:
grad_clip = 10

for step in range(1,500):
    total_loss = 0
    input_batches, input_lengths,\
        target_batches, target_lengths = random_batch(src,tgt,batch_size)
    opt.zero_grad()
    output = net(input_batches, input_lengths, target_batches, target_lengths)

    # For Debug
    #print('target lengths', target_lengths)
    
    loss = masked_cross_entropy.compute_loss(
        output.transpose(0, 1).contiguous(),
        target_batches.transpose(0, 1).contiguous(),
        target_lengths
    )
    print('loss = ', loss.item())
    
  
    
    clip_grad_norm_(net.parameters(), grad_clip)
    loss.backward()
    opt.step()

loss =  10.94837760925293
loss =  10.91183853149414
loss =  10.919384002685547
loss =  10.906928062438965
loss =  10.885910987854004
loss =  10.951333999633789
loss =  10.79529094696045
loss =  10.838907241821289
loss =  10.810870170593262
loss =  10.7531156539917
loss =  10.80179214477539
loss =  10.72758674621582
loss =  10.577681541442871
loss =  10.544478416442871
loss =  10.404194831848145
loss =  10.354537963867188
loss =  10.25487995147705
loss =  10.016528129577637
loss =  9.944005966186523
loss =  9.591464042663574
loss =  9.581192016601562
loss =  9.115730285644531
loss =  9.422119140625
loss =  8.920899391174316
loss =  9.089991569519043
loss =  8.752769470214844
loss =  9.39646053314209
loss =  8.103239059448242
loss =  8.130936622619629
loss =  8.358530044555664
loss =  7.665470123291016
loss =  8.005951881408691
loss =  8.767878532409668
loss =  8.706929206848145
loss =  8.221593856811523
loss =  8.924992561340332
loss =  8.464829444885254
loss =  8.282978057861328
loss =

KeyboardInterrupt: 

In [11]:
_, pred = net.inference(input_batches[:,1].reshape(input_lengths[0].item(),1),input_lengths[0].reshape(1))

In [12]:
input_batches

tensor([[    0,     0,     0],
        [ 3190,  2264,   508],
        [17078,  1687,  1334],
        [  298,    87,   174],
        [  131,  8199,  7091],
        [ 1485,  1198,    87],
        [   21, 22599,    18],
        [ 1906,  7294,    18],
        [    7,   204,  6518],
        [  970, 10458, 20391],
        [ 3489,  7877, 73512],
        [  576,    64, 32072],
        [ 2399,    21,    21],
        [    2, 14009,    38],
        [    7,  7877,  1818],
        [ 2013,   430,     2],
        [35705,  4679,    95],
        [  221,    21,    18],
        [   38,   407,     1],
        [ 2475,    95,     3],
        [ 9731,    18,     3],
        [   21,     1,     3],
        [14859,     3,     3],
        [11183,     3,     3],
        [ 1110,     3,     3],
        [   37,     3,     3],
        [   18,     3,     3],
        [    1,     3,     3]], device='cuda:0')

In [13]:
input_lengths

tensor([28, 22, 19], device='cuda:0')

In [14]:
' '.join([tgt.idx2w[t] for t in pred])

'the the the the the the the the the the the the . . . . . . <eos>'

In [15]:
pred

[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 10, 10, 10, 10, 10, 10, 1]

In [17]:
tgt.idx2w[10]

'.'