# tf.nn.dynamic_rnn in depth

In this article, we will discuss in depth what the output of `tf.nn.dynamic_rnn()` looks like in various scenarios.

The original code comes from this [Stack Exchange question](https://stats.stackexchange.com/questions/330176/what-is-the-output-of-a-tf-nn-dynamic-rnn) and the [Learning TensorFlow book](http://shop.oreilly.com/product/0636920063698.do).

I have added a more thorough investigation.

## One Layer with Basic RNN cell 

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, placeholders, sessions)
import numpy as np

tf.reset_default_graph()

input_dim = 3
num_steps = 2
num_units = 5

# None represents the batch size.
# We put None here, since the batch size can be different for each batch.
X = tf.placeholder(tf.float32, [None, num_steps, input_dim], 'inputs')
seq_length = tf.placeholder(tf.int32, [None], 'seq_length')

basic_cell = tf.contrib.rnn.BasicRNNCell(num_units)

# Build the zero initial state from the runtime batch size instead of
# hard-coding 4, so that batches of any size can be fed to X.
initial_state = basic_cell.zero_state(tf.shape(X)[0], tf.float32)

outputs, final_states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length,
                                          initial_state=initial_state, dtype=tf.float32)

# Create a batch of training examples
# shape (4, 2, 3)
X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [5, 1, 9]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])

print('X batch shape:', X_batch.shape)

# Only the first step of instance 1 is valid; the others use both steps.
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs_val, states_val = sess.run([outputs, final_states],
                                       feed_dict={X: X_batch, seq_length: seq_length_batch})

    print('outputs with shape: ', outputs_val.shape)
    print(outputs_val)
    print('==================================================================')
    print('h states with shape: ', states_val.shape)
    print(states_val)
```

```
X batch shape: (4, 2, 3)
outputs with shape:  (4, 2, 5)
[[[ 0.0982992   0.91937888  0.03790487 -0.54347765 -0.00692416]
  [ 0.99998456  1.          0.99177212  0.99987334  0.99939644]]

 [[ 0.97107625  0.99995422  0.75952923  0.60435373  0.87907934]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.9994756   1.          0.96049631  0.96466023  0.99186504]
  [ 0.97324771  0.99998188  0.81755435  0.99450946  0.97564179]]

 [[ 0.99778575  0.98422712  0.97891462  0.99996006  0.99962366]
  [ 0.12011302  0.96819985  0.18604073  0.9268809   0.67240906]]]
h states with shape:  (4, 5)
[[ 0.99998456  1.          0.99177212  0.99987334  0.99939644]
 [ 0.97107625  0.99995422  0.75952923  0.60435373  0.87907934]
 [ 0.97324771  0.99998188  0.81755435  0.99450946  0.97564179]
 [ 0.12011302  0.96819985  0.18604073  0.9268809   0.67240906]]
```


**Point 1: From the code above, we know that:**
* input `X_batch` has shape `(batch_size, num_steps, input_dim)` => (4, 2, 3)
* `outputs_val` has shape `(batch_size, num_steps, output_size)` => (4, 2, 5)
* `states_val` has shape `(batch_size, state_size)` => (4, 5)

**Point 2: More specifically, we observe that `tf.nn.dynamic_rnn` returns:**
* `outputs_val`, which contains the hidden states for each sample in a batch over <b>every time step</b>. In this particular example:
    * 2 steps
    * a batch of 4 examples, each with 5 dimensions.

* `states_val`, which contains the hidden states from the last time step
    * the final state only involves one time step (i.e., the last one)
    * a batch of 4 states, one for each example.

**Point 3: Moreover:**
* `state_size` is determined by `num_units`
* `output_size` is determined by `num_units` of the last RNN cell (last layer)
    * In this particular example, there is only one cell (i.e., one layer); therefore `output_size == state_size`, which is 5.
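Points 1 to 3 can be made concrete without TensorFlow. Below is a minimal NumPy sketch of a single-layer basic RNN unrolled over time.

It rests on a few assumptions:
* The cell computes `h_t = tanh([x_t, h_{t-1}] W + b)`, which is my understanding of `BasicRNNCell`.
* The weights are hypothetical random values, so the numbers will not match the TensorFlow output above. Only the shapes and the relationship between outputs and state matter.
* `basic_rnn_forward` is an illustrative helper, not a TensorFlow API.

```python
import numpy as np

def basic_rnn_forward(X, W, b, h0):
    """Unroll a BasicRNNCell-style cell over time (NumPy sketch, not TF code).

    X:  (batch_size, num_steps, input_dim)
    W:  (input_dim + num_units, num_units), applied to [x_t, h_{t-1}]
    b:  (num_units,)
    h0: (batch_size, num_units) initial state
    """
    batch_size, num_steps, _ = X.shape
    h = h0
    outputs = []
    for t in range(num_steps):
        # For a basic RNN cell, the output at step t IS the new hidden state.
        h = np.tanh(np.concatenate([X[:, t, :], h], axis=1) @ W + b)
        outputs.append(h)
    # Stack per-step outputs into (batch_size, num_steps, num_units).
    return np.stack(outputs, axis=1), h

rng = np.random.default_rng(0)
input_dim, num_units = 3, 5
X_batch = np.array([
    [[0, 1, 2], [9, 8, 7]],
    [[3, 4, 5], [5, 1, 9]],
    [[6, 7, 8], [6, 5, 4]],
    [[9, 0, 1], [3, 2, 1]],
], dtype=np.float32)

# Hypothetical random weights; the real cell learns its own.
W = rng.normal(scale=0.1, size=(input_dim + num_units, num_units))
b = np.zeros(num_units)
h0 = np.zeros((X_batch.shape[0], num_units))  # analogous to zero_state(batch_size)

outputs, final_state = basic_rnn_forward(X_batch, W, b, h0)
print(outputs.shape)      # (4, 2, 5)
print(final_state.shape)  # (4, 5)
# Without sequence_length, the final state is simply the last step's output.
print(np.allclose(final_state, outputs[:, -1, :]))  # True
```

The basic cell's output is its new state, so `output_size` and `state_size` are both `num_units` here.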

**Point 4: sequence_length of tf.nn.dynamic_rnn**

`sequence_length`: (optional) An int32/int64 vector sized [batch_size]. Used to copy-through state and zero-out outputs when past a batch element's sequence length. So it's more for correctness than performance.

In the above code, we define:
```python
seq_length_batch = np.array([2, 1, 2, 2])
```
which means that `tf.nn.dynamic_rnn` only processes:
* the first two time steps of the first example in the batch,
* the first time step of the second example,
* the first two time steps of the third example, and
* the first two time steps of the fourth example.

> Therefore, the values for the second time step of the second example are zeros in `outputs`. More importantly, the first time step is the last one processed for the second example. Its final state therefore holds the state from the first time step rather than the second.

The state is a convenient tensor that holds the last actual RNN state of each example, ignoring the zero-padded steps. The output tensor holds the outputs of every time step, so it does not skip the zeros. That is why `tf.nn.dynamic_rnn` returns both of them.
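The copy-through and zero-out behaviour can be sketched in NumPy by masking each step with `t < seq_len`. This is a hypothetical re-implementation for illustration, not TensorFlow's internal code. It uses random weights and the same assumed basic-RNN cell formula; `rnn_with_sequence_length` is an illustrative name.

```python
import numpy as np

def rnn_with_sequence_length(X, seq_len, W, b, h0):
    """Basic RNN unroll that honours a per-example sequence length.

    For rows whose sequence has ended (t >= seq_len), the output is set to
    zero and the previous state is copied through unchanged.
    """
    batch_size, num_steps, _ = X.shape
    h = h0
    outputs = np.zeros((batch_size, num_steps, h0.shape[1]))
    for t in range(num_steps):
        h_new = np.tanh(np.concatenate([X[:, t, :], h], axis=1) @ W + b)
        valid = (t < seq_len)[:, None]                   # (batch_size, 1) mask
        outputs[:, t, :] = np.where(valid, h_new, 0.0)  # zero-out outputs
        h = np.where(valid, h_new, h)                    # copy-through state
    return outputs, h

rng = np.random.default_rng(1)
input_dim, num_units = 3, 5
X_batch = np.array([
    [[0, 1, 2], [9, 8, 7]],
    [[3, 4, 5], [5, 1, 9]],
    [[6, 7, 8], [6, 5, 4]],
    [[9, 0, 1], [3, 2, 1]],
], dtype=np.float32)
seq_length_batch = np.array([2, 1, 2, 2])

# Hypothetical random weights for illustration only.
W = rng.normal(scale=0.1, size=(input_dim + num_units, num_units))
b = np.zeros(num_units)
h0 = np.zeros((X_batch.shape[0], num_units))

outputs, final_state = rnn_with_sequence_length(X_batch, seq_length_batch, W, b, h0)
print(np.all(outputs[1, 1] == 0.0))                # True: step 2 of example 1 is zeroed
print(np.allclose(final_state[1], outputs[1, 0]))  # True: its state stops at step 1
print(np.allclose(final_state[0], outputs[0, 1]))  # True: full-length example
```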

## One Layer with LSTM RNN cell 

* We replace `BasicRNNCell` with `BasicLSTMCell`, which keeps both a memory (cell) state and a hidden state.

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, placeholders, sessions)
import numpy as np

tf.reset_default_graph()

input_dim = 3
num_steps = 2
num_units = 5

# None represents the batch size.
# We put None here, since the batch size can be different for each batch.
X = tf.placeholder(tf.float32, [None, num_steps, input_dim], 'inputs')
seq_length = tf.placeholder(tf.int32, [None], 'seq_length')

# NOTE: we are using BasicLSTMCell here instead of BasicRNNCell
basic_cell = tf.contrib.rnn.BasicLSTMCell(num_units)

# Build the zero initial state (an LSTMStateTuple of c and h) from the
# runtime batch size instead of hard-coding 4.
initial_state = basic_cell.zero_state(tf.shape(X)[0], tf.float32)

outputs, final_states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length,
                                          initial_state=initial_state, dtype=tf.float32)

# Create a batch of training examples
# shape (4, 2, 3)
X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [5, 1, 9]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])

print('X batch shape:', X_batch.shape)

seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs_val, states_val = sess.run([outputs, final_states],
                                       feed_dict={X: X_batch, seq_length: seq_length_batch})

    print('outputs with shape: ', outputs_val.shape)
    print(outputs_val)
    print('==================================================================')
    # states_val is an LSTMStateTuple: index 0 is c (memory), index 1 is h (hidden)
    print('c states with shape: ', states_val[0].shape)
    print('h states with shape: ', states_val[1].shape)
    print('c states:')
    print(states_val[0])
    print('h states:')
    print(states_val[1])
    print('------------------------------------------------------------------')
    print('detail info: ', states_val)
```

```
X batch shape: (4, 2, 3)
outputs with shape:  (4, 2, 5)
[[[ -2.67375764e-02  -8.80622715e-02   1.61942363e-01  -1.06723178e-02
     2.22548679e-01]
  [  7.19351619e-02  -3.60482305e-01   3.10523093e-01   8.92076496e-05
     6.73704088e-01]]

 [[  3.88698220e-01  -2.82720029e-01   3.20574433e-01   2.17292667e-03
     4.33897674e-01]
  [  0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00
     0.00000000e+00]]

 [[  4.13269788e-01  -3.30032617e-01   3.21764827e-01   6.56747579e-05
     5.19900203e-01]
  [  3.43061060e-01  -4.45168853e-01   3.73977542e-01   3.55326873e-03
     7.03240216e-01]]

 [[  3.24896909e-02  -8.82205814e-02   6.35052145e-01   1.20023526e-01
     4.86850947e-01]
  [  1.28061742e-01  -2.86177725e-01   3.25985193e-01   1.07136607e-01
     5.41655004e-01]]]
c states with shape:  (4, 5)
h states with shape:  (4, 5)
c states:
[[ 0.07384271 -0.38438019  1.06630743  0.08183971  0.82870448]
 [ 0.49419126 -0.35714117  0.8639577   0.10446893  0.51287788]
 [ 0.3
...
```

**Point 1: From the code above, we know that:**
* input `X_batch` has shape `(batch_size, num_steps, input_dim)` => (4, 2, 3)
* `outputs_val` has shape `(batch_size, num_steps, output_size)` => (4, 2, 5)
* <b style="color:red">NOTE</b> `states_val` is an `LSTMStateTuple` containing two elements:
    * `states_val[0]` holds the internal memory states (`c`), with shape `(batch_size, state_size)` => (4, 5)
    * `states_val[1]` holds the last-step hidden states (`h`), with shape `(batch_size, state_size)` => (4, 5)

**Point 2: More specifically, we observe that `tf.nn.dynamic_rnn` returns:**
* `outputs_val`, which contains the hidden states for each sample in a batch over <b>every time step</b>. In this particular example:
    * 2 steps
    * a batch of 4 examples, each with 5 dimensions.
* <b style="color:red">NOTE</b> `states_val[0]`, which contains the memory states from the last time step (normally we do not use this)
* <b style="color:red">NOTE</b> `states_val[1]`, which contains the hidden states from the last time step
    * the final state only involves one time step (i.e., the last one)
    * a batch of 4 states, one for each example.
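The same idea applies to an LSTM, shown here as a NumPy sketch without `sequence_length`. The point is that only `h` is written to the outputs, while the final state carries both `c` and `h`.

The sketch makes these assumptions:
* The gate layout follows my understanding of `BasicLSTMCell`: one matrix product split into `i, j, f, o`, with `forget_bias=1.0`.
* The weights are hypothetical random values.
* `LSTMState` is a stand-in namedtuple for `LSTMStateTuple`.

```python
import numpy as np
from collections import namedtuple

# Stand-in for tf.contrib.rnn.LSTMStateTuple, which is also a (c, h) namedtuple.
LSTMState = namedtuple("LSTMState", ["c", "h"])

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(X, W, b, forget_bias=1.0):
    """Unroll a BasicLSTMCell-style cell; returns (outputs, final LSTMState)."""
    batch_size, num_steps, _ = X.shape
    num_units = b.shape[0] // 4
    c = np.zeros((batch_size, num_units))  # memory (cell) state
    h = np.zeros((batch_size, num_units))  # hidden state
    outputs = []
    for t in range(num_steps):
        z = np.concatenate([X[:, t, :], h], axis=1) @ W + b
        # input gate, candidate values, forget gate, output gate
        i, j, f, o = np.split(z, 4, axis=1)
        c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
        h = np.tanh(c) * sigmoid(o)
        outputs.append(h)  # only h is emitted as the per-step output
    return np.stack(outputs, axis=1), LSTMState(c=c, h=h)

rng = np.random.default_rng(2)
input_dim, num_units = 3, 5
X_batch = np.array([
    [[0, 1, 2], [9, 8, 7]],
    [[3, 4, 5], [5, 1, 9]],
    [[6, 7, 8], [6, 5, 4]],
    [[9, 0, 1], [3, 2, 1]],
], dtype=np.float32)

# Hypothetical random weights: one matrix producing all four gates at once.
W = rng.normal(scale=0.1, size=(input_dim + num_units, 4 * num_units))
b = np.zeros(4 * num_units)

outputs, state = lstm_forward(X_batch, W, b)
print(outputs.shape, state.c.shape, state.h.shape)  # (4, 2, 5) (4, 5) (4, 5)
print(np.allclose(state.h, outputs[:, -1, :]))      # True: the last output is h
```

Indexing `state[0]` and `state[1]` gives `c` and `h`, matching how `states_val[0]` and `states_val[1]` are used above.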


## Multiple Layers with LSTM RNN cells

* We construct a two-layer RNN with LSTM cells.

```python
def build_cell(num_units):
    # One LSTM layer with `num_units` hidden units
    lstm = tf.contrib.rnn.BasicLSTMCell(num_units)
    return lstm
```

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, placeholders, sessions)
import numpy as np

tf.reset_default_graph()

input_dim = 3
num_steps = 2
num_units = [4, 5]
num_layers = 2

# None represents the batch size.
# We put None here, since the batch size can be different for each batch.
X = tf.placeholder(tf.float32, [None, num_steps, input_dim], 'inputs')
seq_length = tf.placeholder(tf.int32, [None], 'seq_length')

# Stack two BasicLSTMCell layers (4 and 5 units) with MultiRNNCell
basic_cell = tf.contrib.rnn.MultiRNNCell([build_cell(num_units[idx]) for idx in range(num_layers)])

# Build the zero initial state (one LSTMStateTuple per layer) from the
# runtime batch size instead of hard-coding 4.
initial_state = basic_cell.zero_state(tf.shape(X)[0], tf.float32)

outputs, final_states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length,
                                          initial_state=initial_state, dtype=tf.float32)

# Create a batch of training examples
# shape (4, 2, 3)
X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [5, 1, 9]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])

print('X batch shape:', X_batch.shape)

seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs_val, states_val = sess.run([outputs, final_states],
                                       feed_dict={X: X_batch, seq_length: seq_length_batch})

    print('outputs with shape: ', outputs_val.shape)
    print(outputs_val)
    print('==================================================================')
    print('First Layer c states with shape: ', states_val[0][0].shape)
    print('First Layer h states with shape: ', states_val[0][1].shape)
    print('First Layer c states:')
    print(states_val[0][0])
    print('First Layer h states:')
    print(states_val[0][1])
    print('------------------------------------------------------------------')
    print('Second (last) Layer c states with shape: ', states_val[1][0].shape)
    print('Second (last) Layer h states with shape: ', states_val[1][1].shape)
    print('Last Layer c states:')
    print(states_val[1][0])
    print('Last Layer h states:')
    print(states_val[1][1])
    print('------------------------------------------------------------------')
    print('detail info: ', states_val)
```

```
X batch shape: (4, 2, 3)
outputs with shape:  (4, 2, 5)
[[[ 0.0065024  -0.017124   -0.01359156  0.00859271  0.00505634]
  [ 0.03826391 -0.09002061 -0.05047503  0.06010268  0.02036073]]

 [[ 0.02109389 -0.05707591 -0.03460117  0.03494828  0.02418501]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.02550699 -0.06566202 -0.03883148  0.04253773  0.02695743]
  [ 0.04090386 -0.12728092 -0.07308862  0.08742567  0.04232676]]

 [[ 0.01300943 -0.0356161  -0.02754521  0.02426844  0.01893609]
  [ 0.01473426 -0.0594526  -0.07691813  0.0522584   0.04810004]]]
First Layer c states with shape:  (4, 4)
First Layer h states with shape:  (4, 4)
First Layer c states:
[[-0.72883856  0.19821936  0.94745511 -0.10293541]
 [-0.38816008  0.09462205  0.69282103 -0.0132799 ]
 [-0.91657662  0.17569107  1.29657149 -0.00464094]
 [-1.31906509 -0.35433936  0.99494326  0.00202722]]
First Layer h states:
[[-0.15535693  0.05520145  0.73585266 -0.06237983]
 [-0.21082501  0.04196973  0.58008361 -0.0
...
```

**Point 1: From the code above, we know that:**
* `outputs_val` has the same shape as in the previous two scenarios, `(batch_size, num_steps, output_size)` => (4, 2, 5) (with different values, of course)
* <b style="color:red">NOTE</b> `states_val` is a tuple of tuples. The outer tuple has one element per layer (2 in this particular example). Each element is an `LSTMStateTuple` that in turn contains two elements, one for the memory states and one for the hidden states:
    * `states_val[0]` is the first layer's `LSTMStateTuple`
        * `states_val[0][0]` is the first layer's memory states, with shape `(batch_size, state_size)` => (4, 4)
        * `states_val[0][1]` is the first layer's last-step hidden states, with shape `(batch_size, state_size)` => (4, 4)
    * `states_val[1]` is the second layer's `LSTMStateTuple`
        * `states_val[1][0]` is the second layer's memory states, with shape `(batch_size, state_size)` => (4, 5)
        * `states_val[1][1]` is the second layer's last-step hidden states, with shape `(batch_size, state_size)` => (4, 5)
        
**Point 2: Moreover:**
* The memory states and hidden states in a layer have the same `state_size`, which is determined by `num_units` of the LSTM cell in that layer.
    * In the first layer, the `state_size` is 4, while
    * in the second layer, the `state_size` is 5.
* <b style="color:red">NOTE</b> `output_size` is determined by `num_units` of the last LSTM cell (last layer)
    * In this particular example, the `num_units` of the last (second) layer is 5; therefore, the `output_size` is 5.
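Stacking can be sketched the same way. At every time step the input flows up through the layers, the top layer's `h` becomes the output, and each layer keeps its own `(c, h)` state.

This NumPy sketch is my approximation of what `MultiRNNCell` does:
* It uses hypothetical random weights.
* It reuses the same assumed `BasicLSTMCell` gate layout.
* `lstm_step` and `stacked_lstm_forward` are illustrative helpers, not TensorFlow APIs.

```python
import numpy as np
from collections import namedtuple

LSTMState = namedtuple("LSTMState", ["c", "h"])  # stand-in for LSTMStateTuple

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, state, W, b, forget_bias=1.0):
    """One LSTM step; returns (output, new_state) where output is new h."""
    z = np.concatenate([x, state.h], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)
    c = state.c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    h = np.tanh(c) * sigmoid(o)
    return h, LSTMState(c=c, h=h)

def stacked_lstm_forward(X, params):
    """params: list of (W, b) per layer, bottom layer first."""
    batch_size, num_steps, _ = X.shape
    states = [LSTMState(c=np.zeros((batch_size, b.shape[0] // 4)),
                        h=np.zeros((batch_size, b.shape[0] // 4)))
              for _, b in params]
    outputs = []
    for t in range(num_steps):
        x = X[:, t, :]
        for layer, (W, b) in enumerate(params):
            # The output of one layer is the input of the next layer.
            x, states[layer] = lstm_step(x, states[layer], W, b)
        outputs.append(x)  # only the top layer's output is collected
    # Final state: one (c, h) tuple per layer, like MultiRNNCell's tuple of tuples.
    return np.stack(outputs, axis=1), tuple(states)

rng = np.random.default_rng(3)
input_dim, num_units = 3, [4, 5]
X_batch = np.array([
    [[0, 1, 2], [9, 8, 7]],
    [[3, 4, 5], [5, 1, 9]],
    [[6, 7, 8], [6, 5, 4]],
    [[9, 0, 1], [3, 2, 1]],
], dtype=np.float32)

# Hypothetical random weights; layer k consumes the h of layer k - 1.
params, in_size = [], input_dim
for n in num_units:
    params.append((rng.normal(scale=0.1, size=(in_size + n, 4 * n)), np.zeros(4 * n)))
    in_size = n

outputs, states = stacked_lstm_forward(X_batch, params)
print(outputs.shape)                         # (4, 2, 5)
print(states[0].c.shape, states[0].h.shape)  # (4, 4) (4, 4)
print(states[1].c.shape, states[1].h.shape)  # (4, 5) (4, 5)
print(np.allclose(states[1].h, outputs[:, -1, :]))  # True: outputs come from the last layer
```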
        

## Summary

* `outputs` contains the hidden states for each sample (in a batch) over <b>every time step</b> of the <b>last layer</b>
* `states` contains the hidden states (and memory states) for each sample (in a batch) for <b>every layer</b> at the <b>last time step</b> (the last valid step given by `sequence_length`)
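A practical consequence of this summary is that the final hidden state equals the output at each example's last valid step, not necessarily at the last time step. A small NumPy check confirms this using the values printed by the one-layer `BasicRNNCell` example. `last_relevant` is a hypothetical helper name.

```python
import numpy as np

# Values copied from the printed output of the one-layer BasicRNNCell example.
outputs_val = np.array([
    [[0.0982992, 0.91937888, 0.03790487, -0.54347765, -0.00692416],
     [0.99998456, 1.0, 0.99177212, 0.99987334, 0.99939644]],
    [[0.97107625, 0.99995422, 0.75952923, 0.60435373, 0.87907934],
     [0.0, 0.0, 0.0, 0.0, 0.0]],
    [[0.9994756, 1.0, 0.96049631, 0.96466023, 0.99186504],
     [0.97324771, 0.99998188, 0.81755435, 0.99450946, 0.97564179]],
    [[0.99778575, 0.98422712, 0.97891462, 0.99996006, 0.99962366],
     [0.12011302, 0.96819985, 0.18604073, 0.9268809, 0.67240906]],
])
states_val = np.array([
    [0.99998456, 1.0, 0.99177212, 0.99987334, 0.99939644],
    [0.97107625, 0.99995422, 0.75952923, 0.60435373, 0.87907934],
    [0.97324771, 0.99998188, 0.81755435, 0.99450946, 0.97564179],
    [0.12011302, 0.96819985, 0.18604073, 0.9268809, 0.67240906],
])
seq_length_batch = np.array([2, 1, 2, 2])

def last_relevant(outputs, seq_len):
    """Pick outputs[i, seq_len[i] - 1, :] for every example i."""
    return outputs[np.arange(outputs.shape[0]), seq_len - 1]

print(np.allclose(last_relevant(outputs_val, seq_length_batch), states_val))  # True
# Naively taking the last time step fails for the shorter example (row 1 is zeros).
print(np.allclose(outputs_val[:, -1, :], states_val))  # False
```

For `BasicLSTMCell` the same relationship holds for the `h` part of the state. With `MultiRNNCell` it holds for the last layer's `h`. Inside a TensorFlow graph, the same per-example indexing can be expressed with `tf.gather_nd`.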