# TensorFlow Dynamic RNN Loop Implementation

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Discard whatever ops the default graph already holds, so the graph is built
# from scratch each time the cells are re-run.
tf.reset_default_graph()
# Graph-level seed; set after the reset so it applies to the new default graph.
# Presumably here to keep the randomly initialised GRU weights repeatable.
tf.set_random_seed(1)

In [3]:
# Toy `tf.while_loop`: keep adding `b` and `c` onto `a` until the mean of `a`
# is no longer below 5. `b` and `c` are passed through each iteration unchanged.
a, b, c = (tf.ones((3, 3)) for _ in range(3))

ar, br, cr = tf.while_loop(
  cond=lambda total, _b, _c: tf.reduce_mean(total) < 5,
  body=lambda total, inc_b, inc_c: (total + inc_b + inc_c, inc_b, inc_c),
  loop_vars=(a, b, c))

In [4]:
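# Disabled code, kept for reference only (not executed). It relies on `max_len`,
# `state`, `seq_len` and `toi32`, which are defined in the next cell.
# final_t, final_state, _ = tf.while_loop(
#   cond=lambda t, st, sl: t < max_len,
#   body=lambda t, st, sl: (t + 1, st + 2.0 * tf.one_hot(toi32(t < seq_len) * t - toi32(t >= seq_len), st.shape[1]), sl),
#   loop_vars=(0, state, seq_len))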

In [5]:
# Batch of 3 sequences, 5 time steps each, 1 feature per step: shape (3, 5, 1).
# In every row the number of leading 1s equals that row's entry in `seq_len`
# (3, 2, 1); the 8s occupy the remaining steps.
inputs = tf.constant([
  [[1], [1], [1], [8], [8]],
  [[1], [1], [8], [8], [8]],
  [[1], [8], [8], [8], [8]],
], dtype=tf.float32, name='inputs')

# Initial state, one row of 2 values per sequence: shape (3, 2). Used both as the
# starting state of the hand-written loop and as `initial_state` of `dynamic_rnn`.
state = tf.constant([
  [10, 10], 
  [10, 10], 
  [10, 10]
], dtype=tf.float32, name='state')

# Per-sequence lengths. `reduce` only updates a row while `t < seq_len`, and the
# same tensor is passed to `dynamic_rnn` as `sequence_length`.
seq_len = tf.constant([3, 2, 1])
# max_len = tf.reduce_max(seq_len)
# Static size of axis 1 of `inputs` (5 here): the loop visits every time step
# instead of stopping at the longest valid length (the alternative commented out above).
max_len = inputs.shape[1]

def toi32(x):
  """Cast `x` to a 32-bit integer tensor."""
  return tf.cast(x, dtype=tf.int32)

# One GRU cell with 2 units, used by both the hand-written loop and the
# `dynamic_rnn` reference so that they share the same cell object.
cell = tf.nn.rnn_cell.GRUCell(num_units=2)

def f(x, st):
  """Run `cell` for one step on input `x` from state `st`; return only the new state."""
  _, next_state = cell(x, st)
  return next_state

def reduce(xs, st, sts, t, sl):
  """Advance the hand-written RNN loop by one time step.

  Args:
    xs: inputs; `xs[:, t]` is taken as the step-`t` slice for every row.
    st: state carried into this step, one row per sequence.
    sts: per-step results collected so far, stacked along axis 1.
    t: index of the current time step.
    sl: per-sequence lengths; a row is updated only while `t < sl`.

  Returns:
    `(new_st, new_sts)`: the state after this step, and `sts` with one more
    entry appended along axis 1. Rows whose sequence has already ended keep
    their previous state and contribute zeros to `new_sts`.
  """
  valid = t < sl
  comp = f(xs[:, t], st)
  # Finished rows carry their old state forward untouched.
  new_st = tf.where(valid, comp, st)
  # `zeros_like` / `expand_dims` take shape and dtype from the tensors
  # themselves, so this also works when the batch size is not statically known.
  append = tf.where(valid, comp, tf.zeros_like(st))
  new_sts = tf.concat([sts, tf.expand_dims(append, axis=1)], axis=1)
  return new_st, new_sts

# Hand-written RNN: visit every time index below `max_len`, threading
# (inputs, state, collected states, step counter, lengths) through the loop.
_, final_state, all_states, final_t, _ = tf.while_loop(
  cond=lambda _xs, _st, _sts, step, _sl: step < max_len,
  body=lambda xs, st, sts, step, sl: (xs, *reduce(xs, st, sts, step, sl), step + 1, sl),
  loop_vars=(
    inputs,
    state,
    tf.zeros((state.shape[0], 0, state.shape[1])),  # nothing collected yet
    0,
    seq_len,
  ),
  shape_invariants=(
    inputs.shape,
    state.shape,
    # the collected states grow along axis 1, so that dimension is left open
    tf.TensorShape([state.shape[0], None, state.shape[1]]),
    tf.TensorShape([]),
    seq_len.shape,
  ))

# Reference result: the same cell, inputs, initial state and lengths, run
# through TensorFlow's built-in dynamic RNN.
all_states_true, final_state_true = tf.nn.dynamic_rnn(
  cell=cell,
  inputs=inputs,
  sequence_length=seq_len,
  initial_state=state)

In [6]:
# Op that initialises the graph's global variables. It is created after the
# graph construction above, presumably so that the GRU cell's variables already
# exist at this point.
init = tf.global_variables_initializer()

In [7]:
with tf.Session() as sess:
  sess.run(init)

  # Toy while-loop example.
  print('while loop:')
  av, bv, cv = sess.run([ar, br, cr])
  print(av)
  print(bv)
  print(cv)

  # Hand-written RNN loop.
  print('rnn loop:')
  stv1, stsv1, tv = sess.run([final_state, all_states, final_t])
  print('t: {}'.format(tv))
  print('st:\n{}'.format(stv1))
  print('sts:\n{}'.format(stsv1))

  # Built-in `dynamic_rnn` reference.
  print('tf rnn loop:')
  stv2, stsv2 = sess.run([final_state_true, all_states_true])
  print('st:\n{}'.format(stv2))
  print('sts:\n{}'.format(stsv2))

  # The two results come from differently built graphs, so compare the floats
  # with a tolerance instead of exact `==`. Unlike a bare `assert`, these checks
  # are not stripped under `python -O` and they report where the arrays differ.
  np.testing.assert_allclose(stv1, stv2, rtol=1e-5, atol=1e-6)
  np.testing.assert_allclose(stsv1, stsv2, rtol=1e-5, atol=1e-6)

while loop:
[[ 5.  5.  5.]
 [ 5.  5.  5.]
 [ 5.  5.  5.]]
[[ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]]
[[ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]]
rnn loop:
t: 5
st:
[[ 0.99985451  9.29173756]
 [ 1.0003612   9.45623398]
 [ 1.04323864  9.61685848]]
sts:
[[[ 1.04323864  9.61685848]
  [ 1.0003612   9.45623398]
  [ 0.99985451  9.29173756]
  [ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 1.04323864  9.61685848]
  [ 1.0003612   9.45623398]
  [ 0.          0.        ]
  [ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 1.04323864  9.61685848]
  [ 0.          0.        ]
  [ 0.          0.        ]
  [ 0.          0.        ]
  [ 0.          0.        ]]]
tf rnn loop:
st:
[[ 0.99985451  9.29173756]
 [ 1.0003612   9.45623398]
 [ 1.04323864  9.61685848]]
sts:
[[[ 1.04323864  9.61685848]
  [ 1.0003612   9.45623398]
  [ 0.99985451  9.29173756]
  [ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 1.04323864  9.61685848]
  [ 1.0003612   9.45623398]
  [ 0.          