# Recurrent Neural Network Example

Build a simple recurrent neural network with TensorFlow.

- Author: Aymeric Damien
- Adapted by: Hadi Daneshmand
- Project: https://github.com/aymericdamien/TensorFlow-Examples/

We would like to implement the following simple RNN in TensorFlow:


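$$ h_t = \alpha(W h_{t-1} + U x_t + b) $$

where $x_t$ is the input at time step $t$, $h_t$ is the hidden state, $W$ maps the previous hidden state, $U$ maps the current input, $b$ is a bias vector, and $\alpha$ is an elementwise nonlinearity (ReLU in the code below).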

<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png" alt="nn" style="width: 600px;"/>
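To make the unrolled recurrence concrete without TensorFlow, here is a minimal plain-Python sketch (standard library only). It assumes $\alpha$ is ReLU, as in the code below. The `while_loop` helper is hypothetical: it only mimics the contract of `tf.while_loop`, where `cond` is checked before every iteration and `body` returns the next loop variables. The tiny parameter values are made up for illustration.

```python
# Minimal plain-Python sketch of h_t = relu(W h_{t-1} + U x_t + b).
# Assumptions: alpha = ReLU; `while_loop` is a hypothetical helper that only
# mimics the cond/body contract of tf.while_loop, not TensorFlow itself.

def matvec(M, v):
    """Matrix (list of rows) times vector (list)."""
    return [sum(m * a for m, a in zip(row, v)) for row in M]

def relu(v):
    return [max(0.0, a) for a in v]

def while_loop(cond, body, loop_vars):
    """Apply body to the loop variables for as long as cond holds."""
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars

def rnn_final_state(W, U, b, xs, h0):
    """Run the recurrence over the sequence xs and return the last hidden state."""
    def step(i, h, x):
        xi = x[i]  # plays the role of tf.gather(x, i)
        s = [wh + ux + bj for wh, ux, bj in zip(matvec(W, h), matvec(U, xi), b)]
        return i + 1, relu(s), x
    _, h_fin, _ = while_loop(lambda i, h, x: i < len(x), step, (0, h0, xs))
    return h_fin

# Tiny made-up example: 2 hidden units, 1 input feature, sequence of length 2.
W = [[1.0, 0.0], [0.0, 1.0]]
U = [[1.0], [-1.0]]
b = [0.0, 0.5]
xs = [[1.0], [2.0]]
h0 = [0.0, 0.0]

h_fin = rnn_final_state(W, U, b, xs, h0)
print(h_fin)                      # [3.0, 0.0]
print(sum(a * a for a in h_fin))  # 9.0, i.e. the sum of squares of h_T
```

Each call to `step` is one copy of the cell in the unrolled figure; the same `W`, `U` and `b` are reused at every time step.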

```python
# Source: exercise 9 of ETH deep learning course 2017
# Written against the TensorFlow 1.x graph API (tf.Session, tf.gradients); under
# TensorFlow 2.x, use the tf.compat.v1 equivalents after tf.compat.v1.disable_eager_execution().
import tensorflow as tf
import numpy as np

data_dim, num_hid = 50, 20
sequence_length = 10  # hard-coded here; in a real program it would vary between training samples

with tf.Session().as_default() as sess:
    # Parameters of h_t = relu(W h_{t-1} + U x_t + b)
    W = tf.Variable(np.random.normal(size=(num_hid, num_hid)).astype(np.float32), name='W')
    U = tf.Variable(np.random.normal(size=(num_hid, data_dim)).astype(np.float32), name='U')
    b = tf.Variable(np.zeros(num_hid).astype(np.float32), name='b')

    init = tf.global_variables_initializer()
    sess.run(init)

    def step(i, h, x):
        xi = tf.gather(x, i)  # get the i-th element of the sequence
        Wh = tf.matmul(W, tf.expand_dims(h, 1))  # hidden-to-hidden transition
        Ux = tf.matmul(U, tf.expand_dims(xi, 1))  # input contribution
        s = tf.squeeze(Wh + Ux, axis=1) + b  # sum of hidden, input and bias contributions
        h = tf.nn.relu(s)  # nonlinearity (alpha = ReLU)
        return i + 1, h, x  # passed to both the stopping condition and the next step

    h_init = np.zeros(num_hid).astype(np.float32)  # initial hidden state h_0
    data_point = np.random.rand(sequence_length, data_dim).astype(np.float32)
    _, h_fin, _ = tf.while_loop(
        lambda i, h, x: tf.less(i, sequence_length),  # checked before each iteration; the loop stops once it is False
        step,
        [0, h_init, data_point])  # initial loop variables
    loss = tf.reduce_sum(tf.square(h_fin))

    grads = tf.gradients(loss, [W, U, b])  # gradients flow back through the while_loop as usual
    print(sess.run(loss))
```


304831920000.0
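The line `grads = tf.gradients(loss, [W, U, b])` asks TensorFlow to differentiate through the `tf.while_loop`, which amounts to backpropagation through time (BPTT). To illustrate what that gradient is, here is a minimal plain-Python sketch for the same loss, $\sum_i (h_T)_i^2$, assuming ReLU as $\alpha$. It is not how TensorFlow computes gradients internally; the helper names (`forward`, `bptt`, `numeric_dW`) and the tiny parameter values are hypothetical, and the result is checked against a central finite-difference estimate.

```python
# Plain-Python sketch of backpropagation through time (BPTT) for
# loss = sum(h_T ** 2), with h_t = relu(W h_{t-1} + U x_t + b).
# Assumptions: alpha = ReLU; all helper names are hypothetical.

def matvec(M, v):
    return [sum(m * a for m, a in zip(row, v)) for row in M]

def forward(W, U, b, xs, h0):
    """Run the recurrence, keeping every state h_t and pre-activation s_t."""
    hs, ss = [h0], []
    for x in xs:
        s = [p + q + r for p, q, r in zip(matvec(W, hs[-1]), matvec(U, x), b)]
        ss.append(s)
        hs.append([max(0.0, a) for a in s])
    return hs, ss

def loss(W, U, b, xs, h0):
    hs, _ = forward(W, U, b, xs, h0)
    return sum(a * a for a in hs[-1])

def bptt(W, U, b, xs, h0):
    """Gradients of loss(...) with respect to W, U and b."""
    hs, ss = forward(W, U, b, xs, h0)
    n, d = len(W), len(U[0])
    dW = [[0.0] * n for _ in range(n)]
    dU = [[0.0] * d for _ in range(n)]
    db = [0.0] * n
    g_h = [2.0 * a for a in hs[-1]]  # dL/dh_T
    for t in reversed(range(len(xs))):
        # Through the ReLU: the gradient passes only where s_t > 0.
        g_s = [g if s > 0 else 0.0 for g, s in zip(g_h, ss[t])]
        for i in range(n):
            db[i] += g_s[i]
            for j in range(n):
                dW[i][j] += g_s[i] * hs[t][j]  # hs[t] is the state fed into step t
            for j in range(d):
                dU[i][j] += g_s[i] * xs[t][j]
        g_h = [sum(W[i][j] * g_s[i] for i in range(n)) for j in range(n)]  # W^T g_s
    return dW, dU, db

def numeric_dW(W, U, b, xs, h0, i, j, eps=1e-6):
    """Central finite-difference estimate of dL/dW[i][j]."""
    Wp = [row[:] for row in W]
    Wm = [row[:] for row in W]
    Wp[i][j] += eps
    Wm[i][j] -= eps
    return (loss(Wp, U, b, xs, h0) - loss(Wm, U, b, xs, h0)) / (2 * eps)

# Tiny made-up example: 2 hidden units, 1 input feature, sequence of length 2.
W = [[1.0, 0.0], [0.0, 1.0]]
U = [[1.0], [-1.0]]
b = [0.0, 0.5]
xs = [[1.0], [2.0]]
h0 = [0.0, 0.0]

dW, dU, db = bptt(W, U, b, xs, h0)
print(dW)  # [[6.0, 0.0], [0.0, 0.0]]
print(dU)  # [[18.0], [0.0]]
print(db)  # [12.0, 0.0]
print(abs(dW[0][0] - numeric_dW(W, U, b, xs, h0, 0, 0)) < 1e-4)  # True
```

The backward loop visits the time steps in reverse and adds each step's contribution to `dW`, `dU` and `db`, because the same parameters are shared by every copy of the cell in the unrolled network.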
