# LSTM (Long Short-Term Memory) Cell

### Dependencies

In [1]:
import tensorflow as tf

  return f(*args, **kwds)


### LSTM Cell

In [2]:
# Hyperparameters shared by the LSTM cell definition and the usage example.
batch_size = 32              # rows per minibatch (first dim of each x_t, h, c)
num_unrollings = 10          # time steps the cell is unrolled over
num_lstm_nodes = 256         # width of the hidden state h and cell state c
num_features = 300           # width of each input vector x_t
num_classes = num_features   # width of the target placeholder; tied to the input width

class LSTM(object):
  """One LSTM (Long Short-Term Memory) cell built from raw TensorFlow ops.

  The cell owns four parameter sets: forget gate (f), input gate (i), output
  gate (o) and candidate cell input (ci). Each set has an input-to-hidden
  matrix, a hidden-to-hidden matrix and a bias vector. Calling the instance
  performs one time step:

      f  = sigmoid(x @ W_f_x  + hprev @ W_f_hprev  + W_f_b [+ forget_bias])
      i  = sigmoid(x @ W_i_x  + hprev @ W_i_hprev  + W_i_b)
      o  = sigmoid(x @ W_o_x  + hprev @ W_o_hprev  + W_o_b)
      ci = tanh   (x @ W_ci_x + hprev @ W_ci_hprev + W_ci_b)
      c  = f * cprev + i * ci
      h  = o * tanh(c)

  The same variables are reused on every call, so unrolling the cell over
  time shares weights across steps.
  """

  def __init__(self, input_size=None, num_units=None, init_stddev=0.1,
               forget_bias=0.0):
    """Create the cell's weight and bias variables.

    Args:
      input_size: width of each input row ``x``. ``None`` (the default) uses
        the module-level ``num_features``, looked up when the cell is built.
      num_units: width of the hidden and cell states. ``None`` uses the
        module-level ``num_lstm_nodes``.
      init_stddev: stddev of the truncated-normal initializer applied to
        every weight matrix and bias vector (0.1, as in the original cell).
      forget_bias: Python number added to the forget-gate pre-activation.
        The default of 0.0 reproduces the original cell. 1.0 is a common
        choice that biases the cell towards keeping its state early in
        training.
    """
    # Resolve sizes at construction time (not at class-definition time) so
    # the module-level hyperparameters can still be edited before building.
    if input_size is None:
      input_size = num_features
    if num_units is None:
      num_units = num_lstm_nodes
    self.forget_bias = forget_bias

    # Variables are created in the original order: gates f, i, o, ci, and
    # within each gate the input weights, recurrent weights, then bias. This
    # keeps auto-generated graph names ("Variable", "Variable_1", ...) the same.
    self.W_f_x, self.W_f_hprev, self.W_f_b = self._gate_params(
        input_size, num_units, init_stddev)
    self.W_i_x, self.W_i_hprev, self.W_i_b = self._gate_params(
        input_size, num_units, init_stddev)
    self.W_o_x, self.W_o_hprev, self.W_o_b = self._gate_params(
        input_size, num_units, init_stddev)
    self.W_ci_x, self.W_ci_hprev, self.W_ci_b = self._gate_params(
        input_size, num_units, init_stddev)

  @staticmethod
  def _gate_params(input_size, num_units, stddev):
    """Return new (input weights, recurrent weights, bias) variables for one gate."""
    def make(shape):
      return tf.Variable(tf.truncated_normal(shape, stddev=stddev))
    w_x = make([input_size, num_units])
    w_h = make([num_units, num_units])
    b = make([num_units])
    return w_x, w_h, b

  @staticmethod
  def _preactivation(x, hprev, w_x, w_h, b):
    """Return x @ w_x + hprev @ w_h + b, building ops in the same order as before."""
    return tf.matmul(x, w_x) + tf.matmul(hprev, w_h) + b

  def __call__(self, x, hprev, cprev):
    """Advance the cell by one time step.

    Args:
      x: 2-D input tensor whose last dimension is the cell's input size.
      hprev: 2-D previous hidden state whose last dimension is the number of
        units.
      cprev: previous cell state. It is combined element-wise with the gates,
        so it is expected to have the same shape as ``hprev``.

    Returns:
      A tuple ``(h, c)`` holding the new hidden state and new cell state.
    """
    f_pre = self._preactivation(x, hprev, self.W_f_x, self.W_f_hprev, self.W_f_b)
    # Only add the op when it matters, so the default graph is unchanged.
    if self.forget_bias:
      f_pre = f_pre + self.forget_bias
    f = tf.sigmoid(f_pre)
    i = tf.sigmoid(
        self._preactivation(x, hprev, self.W_i_x, self.W_i_hprev, self.W_i_b))
    o = tf.sigmoid(
        self._preactivation(x, hprev, self.W_o_x, self.W_o_hprev, self.W_o_b))
    ci = tf.tanh(
        self._preactivation(x, hprev, self.W_ci_x, self.W_ci_hprev, self.W_ci_b))

    # Keep part of the old cell state (forget gate) and write new content
    # (input gate); the output gate decides how much of it is exposed as h.
    c = f * cprev + i * ci
    h = o * tf.tanh(c)

    return h, c

### Usage

In [3]:
# Inputs are time-major: xs[t] is the [batch_size, num_features] slice for step t.
xs = tf.placeholder(tf.float32, [num_unrollings, batch_size, num_features])
# Target placeholder; it is not consumed by the code in this cell.
y = tf.placeholder(tf.float32, [batch_size, num_classes])

# Zero initial hidden/cell state. Constant tensors are used instead of
# non-trainable tf.Variables. The loop below only rebinds the Python names, so
# such variables would never be updated and no state would carry over between
# session.run calls. They would only add graph variables needing initialization.
output = tf.zeros([batch_size, num_lstm_nodes])
state = tf.zeros([batch_size, num_lstm_nodes])

lstm = LSTM()

# Unroll the cell over the time axis; the same weights are reused at every step.
for x_t in tf.unstack(xs, axis=0):
  output, state = lstm(x_t, output, state)

print('output: {}'.format(output))
print('state: {}'.format(state))

output: Tensor("mul_29:0", shape=(32, 256), dtype=float32)
state: Tensor("add_89:0", shape=(32, 256), dtype=float32)
