# GRU (Gated Recurrent Unit) Cell

### Dependencies

In [1]:
import tensorflow as tf

  return f(*args, **kwds)


### GRU Cell

In [2]:
class GRU(object):
    """A single GRU cell assembled from raw TensorFlow variables.

    The cell owns three parameter groups, created in this order:
    update gate (``w_x_z``, ``w_h_z``, ``b_z``), reset gate
    (``w_x_r``, ``w_h_r``, ``b_r``) and candidate state (``w_x``, ``w_h``,
    ``b``). Each group has an input-to-hidden matrix
    ``[input_size, hidden_size]``, a hidden-to-hidden matrix
    ``[hidden_size, hidden_size]`` and a bias ``[1, hidden_size]``, all
    initialised from ``tf.truncated_normal`` with stddev 0.1.
    """

    def __init__(self, input_size, hidden_size):
        def make(shape):
            # Every parameter (biases included) uses the same initialiser.
            return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

        input_shape = [input_size, hidden_size]
        recurrent_shape = [hidden_size, hidden_size]
        bias_shape = [1, hidden_size]

        # Update gate z.
        self.w_x_z = make(input_shape)
        self.w_h_z = make(recurrent_shape)
        self.b_z = make(bias_shape)

        # Reset gate r.
        self.w_x_r = make(input_shape)
        self.w_h_r = make(recurrent_shape)
        self.b_r = make(bias_shape)

        # Candidate hidden state.
        self.w_x = make(input_shape)
        self.w_h = make(recurrent_shape)
        self.b = make(bias_shape)

    @staticmethod
    def _affine(inputs, hidden, w_in, w_rec, bias):
        """Return ``inputs @ w_in + hidden @ w_rec + bias``."""
        return tf.matmul(inputs, w_in) + tf.matmul(hidden, w_rec) + bias

    def __call__(self, x, hprev):
        """Advance the cell by one time step.

        Args:
          x: input for this step; presumably 2-D ``(batch, input_size)``
            so that it can be matmul'd with the input matrices -- TODO confirm.
          hprev: previous hidden state, presumably ``(batch, hidden_size)``.

        Returns:
          The new hidden state ``(1 - z) * hprev + z * candidate``.
        """
        update = tf.sigmoid(
            self._affine(x, hprev, self.w_x_z, self.w_h_z, self.b_z))
        reset = tf.sigmoid(
            self._affine(x, hprev, self.w_x_r, self.w_h_r, self.b_r))
        # The reset gate scales the previous state before it feeds the
        # candidate's recurrent term.
        candidate = tf.tanh(
            self._affine(x, reset * hprev, self.w_x, self.w_h, self.b))
        # Interpolate between the old state and the candidate via the update gate.
        return (1 - update) * hprev + update * candidate

### Usage

In [3]:
# Demo: unroll a GRU over a fixed-length, time-major input sequence.
num_features = 300
num_classes = num_features
num_unrollings = 10
num_gru_nodes = 256
batch_size = 32

# Time-major inputs: axis 0 is the unrolled step, so each step sees a
# (batch_size, num_features) slice.
xs = tf.placeholder(tf.float32, [num_unrollings, batch_size, num_features])
# Targets; declared for completeness but not consumed in this cell.
y = tf.placeholder(tf.float32, [batch_size, num_classes])

# Zero initial hidden state. A constant suffices: the value is never
# assigned to (``state`` is only rebound in the loop below), so a
# tf.Variable would just add an extra graph variable to initialise.
state = tf.zeros([batch_size, num_gru_nodes])

gru = GRU(num_features, num_gru_nodes)

# Unroll over time by iterating the per-step slices directly;
# tf.unstack splits xs along axis 0 into num_unrollings tensors.
for x_t in tf.unstack(xs, num=num_unrollings, axis=0):
    state = gru(x_t, state)

print('state: {}'.format(state))

state: Tensor("add_69:0", shape=(32, 256), dtype=float32)
