In [1]:
from __future__ import print_function, division
import numpy as np
import tensorflow as tf

num_epochs = 20
total_series_length = 50000
truncated_backprop_length = 15
state_size = 4
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//batch_size//truncated_backprop_length

In [2]:
def generateData(data_length):
    x = np.array(np.random.choice(2, data_length, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0

    x = x.reshape((batch_size, -1))  # The first index changing slowest, subseries as rows
    y = y.reshape((batch_size, -1))

    return (x, y)

In [3]:
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)

W2 = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b2 = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)

# Unpack columns
inputs_series = tf.split(batchX_placeholder, num_or_size_splits=truncated_backprop_length, axis=1)
labels_series = tf.unstack(batchY_placeholder, axis=1)

# Forward passes
cell = tf.nn.rnn_cell.BasicLSTMCell(state_size, state_is_tuple=True)
states_series, current_state = tf.nn.static_rnn(cell, inputs_series, init_state)

logits_series = [tf.matmul(state, W2) + b2 for state in states_series] #Broadcasted addition
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) for logits, labels in zip(logits_series,labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)

In [4]:
data_length = 50000
x,y = generateData(data_length)

In [5]:
num_batches = data_length//batch_size//truncated_backprop_length

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        _current_cell_state = np.zeros((batch_size, state_size))
        _current_hidden_state = np.zeros((batch_size, state_size))

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    batchX_placeholder: batchX,
                    batchY_placeholder: batchY,
                    cell_state: _current_cell_state,
                    hidden_state: _current_hidden_state

                })

            _current_cell_state, _current_hidden_state = _current_state

            loss_list.append(_total_loss)

        print("Epoch ", epoch_idx, "Loss", _total_loss)
        
    # Generate some new data and see if it can recreate Y
    data_length = 100
    x,y = generateData(data_length)
    batchX = x[:,:truncated_backprop_length]
    batchY = y[:,:truncated_backprop_length]
    preds = sess.run(predictions_series, feed_dict={
        batchX_placeholder:batchX,
        init_state:_current_state})

plt.ioff()
plt.show()

Instructions for updating:
Use `tf.global_variables_initializer` instead.


<Figure size 432x288 with 0 Axes>

Epoch  0 Loss 0.004889549
Epoch  1 Loss 0.0013187904
Epoch  2 Loss 0.0007845162
Epoch  3 Loss 0.0005686203
Epoch  4 Loss 0.000452318
Epoch  5 Loss 0.0003792254
Epoch  6 Loss 0.00032909805
Epoch  7 Loss 0.00029256393
Epoch  8 Loss 0.00026483115
Epoch  9 Loss 0.00024398217
Epoch  10 Loss 0.00022341142
Epoch  11 Loss 0.00020974713
Epoch  12 Loss 0.00019917589
Epoch  13 Loss 0.00018659605
Epoch  14 Loss 0.00017736731
Epoch  15 Loss 0.00016689439
Epoch  16 Loss 0.00015812786
Epoch  17 Loss 0.00015087395
Epoch  18 Loss 0.00014450982
Epoch  19 Loss 0.00013894013


In [6]:
# See if "preds" matches "batchY". Ignore the first 3 indeces: Y has no "X" to be based on
idx = int(input("Enter an index value"))

Enter an index value8


In [7]:
print(batchY[:,idx])
print(np.argmax(preds[idx],axis=1))

[1 0 1 1 0]
[1 0 1 1 0]
