In [1]:
import mxnet as mx
from mxnet import autograd, gluon, nd
from mxnet.gluon import nn, rnn, Block
from mxnet.contrib import text

import collections
import datetime
import pickle
import numpy as np

In [2]:
epochs = 500
epoch_period = 1

learning_rate = 0.01
batch_size = 8

max_seq_len = 144
max_output_len = 144

encoder_num_layers = 2
decoder_num_layers = 4

encoder_drop_prob = 0.1
decoder_drop_prob = 0.1

encoder_hidden_dim = 256
decoder_hidden_dim = 512
alignment_dim = 512

ctx = mx.gpu(2)

In [4]:
X = []
Y = []
for i in range(1,24):
    d1 = np.load('train/%d.npy'%(i))
    d2 = np.load('train/%d.npy'%(i+1))
    for s_id in range(81):
        X.append(d1[:,s_id,:])
        Y.append(d2[:,s_id,:])
X = np.array(X)
Y = np.array(Y)
print(X.shape, Y.shape)
dataset = gluon.data.ArrayDataset(nd.array(X,ctx=ctx), nd.array(Y,ctx=ctx))

(1863, 144, 2) (1863, 144, 2)


-------------------------------------------------

In [19]:
class Encoder(Block):
    def __init__(self, input_dim, hidden_dim, num_layers, drop_prob, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        with self.name_scope():
            self.fc = nn.Dense(input_dim, hidden_dim, flatten=True)
            self.rnn = rnn.GRU(hidden_dim, num_layers, dropout=drop_prob, input_size=hidden_dim)
    
    def forward(self, inputs, state):
        # rnn input shape : seq_len * bs * 2
        inputs = self.fc(inputs).swapaxes(0,1)
        output, state = self.rnn(inputs, state)
        return output, state

    def begin_state(self, *args, **kwargs):
        return self.rnn.begin_state(*args, **kwargs)

In [20]:
class Decoder(Block):
    """含注意力机制的解码器"""
    def __init__(self, hidden_dim, output_dim, num_layers, max_seq_len,
                 drop_prob, alignment_dim, encoder_hidden_dim, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        self.max_seq_len = max_seq_len
        self.encoder_hidden_dim = encoder_hidden_dim
        self.hidden_size = hidden_dim
        self.num_layers = num_layers
        with self.name_scope():
            #self.embedding = nn.Embedding(output_dim, hidden_dim)
            self.dropout = nn.Dropout(drop_prob)
            # 注意力机制。
            self.attention = nn.Sequential()
            with self.attention.name_scope():
                self.attention.add(nn.Dense(
                    alignment_dim, in_units=hidden_dim + encoder_hidden_dim,
                    activation="tanh", flatten=False))
                self.attention.add(nn.Dense(1, in_units=alignment_dim,
                                            flatten=False))

            self.rnn = rnn.GRU(hidden_dim, num_layers, dropout=drop_prob,
                               input_size=hidden_dim)
            self.out = nn.Dense(output_dim, in_units=hidden_dim, flatten=False)
            self.rnn_concat_input = nn.Dense(
                hidden_dim, in_units=hidden_dim + encoder_hidden_dim,
                flatten=False)

    def forward(self, cur_input, state, encoder_outputs):
        # 当RNN为多层时，取最靠近输出层的单层隐含状态。
        single_layer_state = [state[0][-1].expand_dims(0)]
        encoder_outputs = encoder_outputs.reshape((self.max_seq_len, -1,
                                                   self.encoder_hidden_dim))
        # single_layer_state尺寸: [(1, batch_size, decoder_hidden_dim)]
        # hidden_broadcast尺寸: (max_seq_len, batch_size, decoder_hidden_dim)
        hidden_broadcast = nd.broadcast_axis(single_layer_state[0], axis=0,
                                             size=self.max_seq_len)

        # encoder_outputs_and_hiddens尺寸:
        # (max_seq_len, batch_size, encoder_hidden_dim + decoder_hidden_dim)
        encoder_outputs_and_hiddens = nd.concat(encoder_outputs,
                                                hidden_broadcast, dim=2)

        # energy尺寸: (max_seq_len, batch_size, 1)
        energy = self.attention(encoder_outputs_and_hiddens)

        # batch_attention尺寸: (batch_size, 1, max_seq_len)
        batch_attention = nd.softmax(energy, axis=0).transpose(
            (1, 2, 0))

        # batch_encoder_outputs尺寸: (batch_size, max_seq_len, encoder_hidden_dim)
        batch_encoder_outputs = encoder_outputs.swapaxes(0, 1)

        # decoder_context尺寸: (batch_size, 1, encoder_hidden_dim)
        decoder_context = nd.batch_dot(batch_attention, batch_encoder_outputs)

        # cur_input尺寸: (batch_size,)
        # input_and_context尺寸: (batch_size, 1, encoder_hidden_dim + decoder_hidden_dim)
        input_and_context = nd.concat(nd.expand_dims(self.embedding(cur_input), axis=1),
                                      decoder_context, dim=2)
        # concat_input尺寸: (1, batch_size, decoder_hidden_dim)
        concat_input = self.rnn_concat_input(input_and_context).reshape((1, -1, 0))
        concat_input = self.dropout(concat_input)

        # 当RNN为多层时，用单层隐含状态初始化各个层的隐含状态。
        state = [nd.broadcast_axis(single_layer_state[0], axis=0,
                                   size=self.num_layers)]

        output, state = self.rnn(concat_input, state)
        output = self.dropout(output)
        output = self.out(output).reshape((-3, -1))
        # output尺寸: (batch_size, output_size)
        return output, state

    def begin_state(self, *args, **kwargs):
        return self.rnn.begin_state(*args, **kwargs)

In [21]:
class DecoderInitState(Block):
    """解码器隐含状态的初始化"""
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim, **kwargs):
        super(DecoderInitState, self).__init__(**kwargs)
        with self.name_scope():
            self.dense = nn.Dense(decoder_hidden_dim,
                                  in_units=encoder_hidden_dim,
                                  activation="tanh", flatten=False)

    def forward(self, encoder_state):
        return [self.dense(encoder_state)]

In [22]:
encoder = Encoder(144, encoder_hidden_dim, encoder_num_layers,
                  encoder_drop_prob)
decoder = Decoder(decoder_hidden_dim, 2,
                  decoder_num_layers, max_seq_len, decoder_drop_prob,
                  alignment_dim, encoder_hidden_dim)
decoder_init_state = DecoderInitState(encoder_hidden_dim, decoder_hidden_dim)

TypeError: unsupported operand type(s) for +: 'int' and 'str'

In [9]:
encoder.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
decoder.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
decoder_init_state.collect_params().initialize(mx.init.Xavier(), ctx=ctx)

softmax_cross_entropy = gluon.loss.L1Loss()

In [14]:
x.shape

(8, 144, 2)

In [10]:
encoder_optimizer = gluon.Trainer(encoder.collect_params(), 'adam',
                                  {'learning_rate': learning_rate})
decoder_optimizer = gluon.Trainer(decoder.collect_params(), 'adam',
                                  {'learning_rate': learning_rate})
decoder_init_state_optimizer = gluon.Trainer(
    decoder_init_state.collect_params(), 'adam',
    {'learning_rate': learning_rate})

prev_time = datetime.datetime.now()
data_iter = gluon.data.DataLoader(dataset, batch_size, shuffle=True)

total_loss = 0.0
iter_times = 0

for epoch in range(1, epochs + 1):
    for x, y in data_iter:
        real_batch_size = x.shape[0]
        with autograd.record():
            loss = nd.array([0], ctx=ctx)
            valid_length = nd.array([0], ctx=ctx)
            encoder_state = encoder.begin_state(
                func=mx.nd.zeros, batch_size=real_batch_size, ctx=ctx)
            encoder_outputs, encoder_state = encoder(x, encoder_state)

            # encoder_outputs尺寸: (max_seq_len, encoder_hidden_dim)
            encoder_outputs = encoder_outputs.flatten()
            # 解码器的第一个输入为BOS字符。
            decoder_input = nd.array([0] * real_batch_size,
                                     ctx=ctx)
            mask = nd.ones(shape=(real_batch_size,), ctx=ctx)
            decoder_state = decoder_init_state(encoder_state[0])
            for i in range(max_seq_len):
                decoder_output, decoder_state = decoder(
                    decoder_input, decoder_state, encoder_outputs)
                # 解码器使用当前时刻的预测结果作为下一时刻的输入。
                decoder_input = decoder_output
                loss = loss + softmax_cross_entropy(decoder_output, y[:, i])
            loss = loss / 144
            loss = nd.mean(loss)
        loss.backward()

        encoder_optimizer.step(1)
        decoder_optimizer.step(1)
        decoder_init_state_optimizer.step(1)

        total_loss += loss.asscalar() / max_seq_len
        iter_times += 1
        
    if epoch % epoch_period == 0 or epoch == 1:
        cur_time = datetime.datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = 'Time %02d:%02d:%02d' % (h, m, s)
        if epoch == 1:
            print_loss_avg = total_loss / len(data_iter)
        else:
            print_loss_avg = total_loss / epoch_period / len(data_iter)
        loss_str = 'Epoch %d, Loss %f, ' % (epoch, print_loss_avg)
        print(loss_str + time_str)
        if epoch != 1:
            total_loss = 0.0
        prev_time = cur_time

MXNetError: Shape inconsistent, Provided = [789504], inferred shape=(594432,)

In [None]:
def test(x, y):
    nd_x = x
    x = [int(_) for _ in x.asnumpy().tolist()]
    y = [int(_) for _ in y.asnumpy().tolist()]
    print('Input: ',' '.join(input_vocab.to_tokens(x)).split('<eos>')[0])
    #print('Expect: ',' '.join(output_vocab.to_tokens(y)))
    
    encoder_state = encoder.begin_state(func=mx.nd.zeros, batch_size=1, ctx=ctx)
    encoder_outputs, encoder_state = encoder(nd_x.expand_dims(0), encoder_state)
    encoder_outputs = encoder_outputs.flatten()
    
    decoder_input = nd.array([output_vocab.token_to_idx[BOS]], ctx=ctx)
    decoder_state = decoder_init_state(encoder_state[0])
    output_tokens = []
    
    for _ in range(max_output_len):
        decoder_output, decoder_state = decoder(
            decoder_input, decoder_state, encoder_outputs)
        pred_i = int(decoder_output.argmax(axis=1).asnumpy()[0])
        if pred_i == output_vocab.token_to_idx[EOS]:
            break
        else:
            output_tokens.append(output_vocab.idx_to_token[pred_i])
        decoder_input = nd.array([pred_i], ctx=ctx)
    print('Output:', ' '.join(output_tokens))

In [None]:
def test(x):
    print('Input: ', x)
    input_tokens = x.split(' ') + [EOS]
    while len(input_tokens) < max_seq_len:
        input_tokens.append(PAD)
    inputs = nd.array(input_vocab.to_indices(input_tokens), ctx=ctx)
    encoder_state = encoder.begin_state(func=mx.nd.zeros, batch_size=1, ctx=ctx)
    encoder_outputs, encoder_state = encoder(inputs.expand_dims(0), encoder_state)
    encoder_outputs = encoder_outputs.flatten()
    
    decoder_state = decoder_init_state(encoder_state[0])
    decoder_input = nd.array([output_vocab.token_to_idx[BOS]], ctx=ctx)
    output_tokens = []
    
    for _ in range(max_output_len):
        decoder_output, decoder_state = decoder(
            decoder_input, decoder_state, encoder_outputs)
        pred_i = int(decoder_output.argmax(axis=1).asnumpy()[0])
        if pred_i == output_vocab.token_to_idx[EOS]:
            break
        else:
            output_tokens.append(output_vocab.idx_to_token[pred_i])
        decoder_input = nd.array([pred_i], ctx=ctx)
    print('Output: ', ' '.join(output_tokens))
    return ' '.join(output_tokens)

In [None]:
with open('test_data.txt','r') as f:
    test_data = f.readlines()
test_data = [data[:-1] for data in test_data]
ans = []
for data in test_data:
    ans.append(test(data))
with open('ans-none.pickle','wb') as f:
    pickle.dump(ans,f)

In [None]:
for i in range(0,2890):
    test(X[i],Y[i])
    print('\n')

In [None]:
eval_post_resp = []
for i in range(30):
    eval_post_resp.append([' '.join(input_seqs[i]), ' '.join(output_seqs[i])])
beam_size = 1
beam_search_test(encoder, decoder, decoder_init_state, eval_post_resp, ctx, max_seq_len)

In [None]:
def beam_search_test(encoder, decoder, decoder_init_state, eval_post_resp, ctx, max_seq_len):
    for p_r in eval_post_resp:
        print('Input: ', p_r[0])
        input_tokens = p_r[0].split(' ') + [EOS]
        while len(input_tokens) < max_seq_len:
            input_tokens.append(PAD)
        inputs = nd.array(input_vocab.to_indices(input_tokens), ctx = ctx)
        encoder_state = encoder.begin_state(func=mx.nd.zeros, batch_size=1, ctx=ctx)
        encoder_outputs, encoder_state = encoder(inputs.expand_dims(0), encoder_state)
        encoder_outputs = encoder_outputs.flatten()
        
        #decoder_input = nd.array([output_vocab.token_to_idx[BOS]], ctx=ctx)
        decoder_state = decoder_init_state(encoder_state[0])
        
        candidates = [[BOS] for _ in range(beam_size)]
        probs = [0.0 for _ in range(beam_size)]
        
        for _ in range(max_output_len):
            tmp = []
            tmp_cand = []
            for k in range(beam_size):
                decoder_input = nd.array([output_vocab.token_to_idx[candidates[k][-1]]], ctx=ctx)
                decoder_output, decoder_state = decoder(decoder_input, decoder_state, encoder_outputs)
                
                pred_score, pred_index = decoder_output.topk(ret_typ='both',k=beam_size)
                #pred_i = int(pred_index[0].asnumpy()[0])
                for j in range(beam_size):
                    tmp.append(probs[k] + pred_score[0].asnumpy()[j])
                    tmp_cand.append(candidates[k] + [output_vocab.idx_to_token[int(pred_index[0].asnumpy()[j])]])
            top_k_idx = np.argsort(tmp)[-beam_size:]
            for k in range(beam_size):
                candidates[k] = tmp_cand[top_k_idx[k]]
                probs[k] = tmp[top_k_idx[k]]
                #print(' '.join(candidates[k]))
        #print(probs)
        top_idx = np.argsort(probs).tolist()
        #print(top_idx)
        print('Output: ')
        for idx in top_idx:
            print(' '.join(candidates[idx]), probs[idx])
        print('Expect:', p_r[1], '\n')

In [None]:
encoder.collect_params().save('encoder-bless-L118.params')
decoder.collect_params().save('decoder-bless-L118.params')
decoder_init_state.collect_params().save('decoder-init-state-bless-L118.params')