In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

## XLSTM 网络搭建

[paper](https://arxiv.org/abs/1709.08073)

说明：

原文是建立三个模型，对三个模型的中间一个LSTM层进行特征交叉，每个模型的输入是不同的（体重，步数，睡眠）。本文为了简便，三个模型输入的是相同的数据。即都是 手写字体数据。

网络结构图：


In [2]:
file_dir = '/Users/wangruidong/Documents/MachineLearning/Dataset/MNIST/'
mnist = input_data.read_data_sets(file_dir, one_hot=True)
train_img = mnist.train.images
train_labels = mnist.train.labels
print(train_img.shape)
print(train_labels.shape)
print('Load data finish.')

Extracting /Users/wangruidong/Documents/MachineLearning/Dataset/MNIST/train-images-idx3-ubyte.gz
Extracting /Users/wangruidong/Documents/MachineLearning/Dataset/MNIST/train-labels-idx1-ubyte.gz
Extracting /Users/wangruidong/Documents/MachineLearning/Dataset/MNIST/t10k-images-idx3-ubyte.gz
Extracting /Users/wangruidong/Documents/MachineLearning/Dataset/MNIST/t10k-labels-idx1-ubyte.gz
(55000, 784)
(55000, 10)
Load data finish.


In [4]:
input_dim = 28
time_step = 28
output_dim = 10
batch_size = 15
n_hidden = 128 # LSTM 输出维度
n_epoch = 20 # 总训练轮数
n_batches = 10
display_step = 2

POS_1 = int(n_hidden / 3)
POS_2 = int(POS_1 * 2)
POS_3 = int(n_hidden - POS_1 - POS_2)

with tf.name_scope('parameters'):
    weights = {
        'w_out': tf.Variable(tf.random_normal([n_hidden, output_dim], stddev=0.1), name='w_out'),
    }

    bias = {
        'b_out': tf.Variable(tf.random_normal([output_dim], stddev=0.1), name='b_out'),
    }
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, time_step, input_dim], name='x_input')
    y = tf.placeholder(tf.float32, [None, output_dim], name='y_input')

In [5]:
def dynamic_rnn(x, weights, bias):
    # initial x shape = [batch_size, time_step, features]
    # ==========> First Layer
    with tf.name_scope('first_layer'):
        with tf.variable_scope('layer1_m1') as scope:
            rnn1_m1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0, reuse=tf.AUTO_REUSE)
            o1_m1, s1_m1 = tf.nn.dynamic_rnn(rnn1_m1,x, dtype=tf.float32, time_major=False)

        with tf.variable_scope('layer1_m2') as scope:
            rnn1_m2 = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0, reuse=tf.AUTO_REUSE)
            o1_m2, s1_m2 = tf.nn.dynamic_rnn(rnn1_m2,x, dtype=tf.float32, time_major=False)

        with tf.variable_scope('layer1_m3') as scope:
            rnn1_m3 = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0, reuse=tf.AUTO_REUSE)
            o1_m3, s1_m3 = tf.nn.dynamic_rnn(rnn1_m3,x, dtype=tf.float32, time_major=False)
    
    with tf.name_scope('cross_data'):
        with tf.name_scope('split_1'):
            split1_o1_m1, split2_o1_m1, split2_o1_m1 = tf.split(o1_m1, [POS_1, POS_2, POS_3], axis=2, name='split_1')
        with tf.name_scope('split_2'):
            split1_o1_m2, split2_o1_m2, split3_o1_m2 = tf.split(o1_m2, [POS_1, POS_2, POS_3], axis=2, name='split_2')
        with tf.name_scope('split_3'):
            split1_o1_m3, split2_o1_m3, split3_o1_m3 = tf.split(o1_m3, [POS_1, POS_2, POS_3], axis=2, name='split_3')
        with tf.name_scope('concat_1'):
            new_o1_m1 = tf.concat([split1_o1_m1, split2_o1_m2, split3_o1_m3], axis=2, name='new_o1_m1')
        with tf.name_scope('concat_2'):
            new_o1_m2 = tf.concat([split1_o1_m2, split2_o1_m3, split2_o1_m1], axis=2, name='new_o1_m2')
        with tf.name_scope('concat_3'):
            new_o1_m3 = tf.concat([split1_o1_m3, split2_o1_m1, split3_o1_m2], axis=2, name='new_o1_m3')
    
    # ==========> Second Layer Cross 
    with tf.name_scope('second_layer'):
        with tf.variable_scope('layer2_m1') as scope:
            rnn2_m1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0, reuse=tf.AUTO_REUSE)
            o2_m1, s2_m1 = tf.nn.dynamic_rnn(rnn2_m1,new_o1_m1, dtype=tf.float32, time_major=False)

        with tf.variable_scope('layer2_m2') as scope:
            rnn2_m2 = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0, reuse=tf.AUTO_REUSE)
            o2_m2, s2_m2 = tf.nn.dynamic_rnn(rnn2_m2,new_o1_m2, dtype=tf.float32, time_major=False)

        with tf.variable_scope('layer2_m3') as scope:
            rnn2_m3 = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0, reuse=tf.AUTO_REUSE)
            o2_m3, s2_m3 = tf.nn.dynamic_rnn(rnn2_m3,new_o1_m3, dtype=tf.float32, time_major=False)
    
    # ==========> Third Layer
    with tf.name_scope('third_layer'):
        with tf.variable_scope('layer3_m1') as scope:
            rnn3_m1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0, reuse=tf.AUTO_REUSE)
            o3_m1, s3_m1 = tf.nn.dynamic_rnn(rnn3_m1,o2_m1, dtype=tf.float32, time_major=False)

        with tf.variable_scope('layer3_m2') as scope:
            rnn3_m2 = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0, reuse=tf.AUTO_REUSE)
            o3_m2, s3_m2 = tf.nn.dynamic_rnn(rnn3_m2,o2_m2, dtype=tf.float32, time_major=False)

        with tf.variable_scope('layer3_m3') as scope:
            rnn3_m3 = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden, forget_bias=1.0, reuse=tf.AUTO_REUSE)
            o3_m3, s3_m3 = tf.nn.dynamic_rnn(rnn3_m3,o2_m3, dtype=tf.float32, time_major=False)
    
    with tf.name_scope('merge_3_model'):
        out = tf.add(o3_m1[:,-1,:], o3_m2[:,-1,:])
        out = tf.add(out, o3_m3[:,-1,:])
    
    with tf.name_scope('full_connection'):
        out = tf.add(tf.matmul(out, weights['w_out']), bias['b_out'])
    return out

In [6]:
# ==========> loss and accuracy
with tf.name_scope('predict'):
    pre = dynamic_rnn(x, weights, bias)
with tf.name_scope('loss'):
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pre))
with tf.name_scope('train'):
    optm = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost)

with tf.name_scope('accuracy'):
    _equal = tf.equal(tf.argmax(pre, axis=1), tf.argmax(y, axis=1))
    accuracy = tf.reduce_mean(tf.cast(_equal, tf.float32))

sess = tf.Session()
train_writer = tf.summary.FileWriter(logdir='logs/', graph=sess.graph)
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
    avg_loss, avg_acc = 0.0, 0.0
    for i in range(n_batches):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        batch_x = batch_x.reshape(batch_size, time_step, input_dim)
        feed = {x:batch_x, y:batch_y}
        sess.run(optm, feed_dict=feed)
        avg_loss += sess.run(cost, feed_dict=feed)
        avg_acc += sess.run(accuracy, feed_dict=feed)
    avg_loss /= n_batches
    avg_acc /= n_batches
    if epoch % display_step == 0 and epoch != 0:
        print('Epoch %d/%d \t Train acc = %.4f loss = %.4f' % (epoch, n_epoch, avg_acc, avg_loss))
print('Finish')

Epoch 2/20 	 Train acc = 0.1333 loss = 2.2831
Epoch 4/20 	 Train acc = 0.1600 loss = 2.2750
Epoch 6/20 	 Train acc = 0.1000 loss = 2.2829
Epoch 8/20 	 Train acc = 0.2867 loss = 2.2678
Epoch 10/20 	 Train acc = 0.2933 loss = 2.2573
Epoch 12/20 	 Train acc = 0.2800 loss = 2.2444
Epoch 14/20 	 Train acc = 0.3733 loss = 2.2378
Epoch 16/20 	 Train acc = 0.3733 loss = 2.2122
Epoch 18/20 	 Train acc = 0.3400 loss = 2.2040
Finish
