Skip to content

Commit

Permalink
Weight normalization for RNN Cells.
Browse files Browse the repository at this point in the history
The current RNN implementation executes a user defined
function (the call() method of subclasses of RNNCells) inside a
tf.while() loop. Weight normalisation requires a one-time
 normalization of the transition matrices prior to
entering the while loop. The following 2 edits have been made in
tensorflow/python/ops to enable this functionality:

 - RNNCell now has a prepare() method. It does nothing as implemented
in the base class
 - A call to cell.prepare() has been added just before entering
 _dynamic_rnn_loop()

Subclasses of RNNCell may implement normalization in the cell's
prepare() method. One implementation with BasicLSTMCell
and associated tests have been added to contrib. Note that any
wrappers to be used with a weight-normalized cell need to be
appropriately subclassed, as illustrated with the
PrepareableMultiRNNCell example in contrib.
  • Loading branch information
discoveredcheck committed Aug 8, 2017
1 parent a5066f6 commit e61e29e
Show file tree
Hide file tree
Showing 2 changed files with 377 additions and 0 deletions.
128 changes: 128 additions & 0 deletions tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.training import adam
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
Expand Down Expand Up @@ -1366,5 +1368,131 @@ def benchmarkDynamicRNNWithMultiLSTMCell(self):
benchmark_results["wall_time"]]]))


class WeightNormBasicLSTMCellTest(test.TestCase):
"""Compared cell output with pre-calculated values."""

def _cell_output(self, cell):
"""Calculate cell output"""

with self.test_session() as sess:
init = init_ops.constant_initializer(0.5)
with variable_scope.variable_scope("root",
initializer=init):
x = array_ops.zeros([1, 2])
c0 = array_ops.zeros([1, 2])
h0 = array_ops.zeros([1, 2])

#cell = contrib_rnn_cell.WeightNormLSTMCell(2, norm, peep)
#cell = rnn_cell.LSTMCell(2, use_peepholes=True)
state0 = rnn_cell.LSTMStateTuple(c0, h0)

xout, sout = cell()(x, state0)

sess.run([variables.global_variables_initializer()])
res = sess.run([xout, sout], {
x.name: np.array([[1., 1.]]),
c0.name: 0.1 * np.asarray([[0, 1]]),
h0.name: 0.1 * np.asarray([[2, 3]]),
})

actual_state_c = res[1].c
actual_state_h = res[1].h

return actual_state_c, actual_state_h

def testBasicCell(self):
"""Tests cell w/o peepholes and w/o normalisation"""

cell = lambda: contrib_rnn_cell.WeightNormLSTMCell(2,
norm=False,
use_peepholes=False)

actual_c, actual_h = self._cell_output(cell)

expected_c = np.array([[0.65937078, 0.74983585]])
expected_h = np.array([[0.44923624, 0.49362513]])

self.assertAllClose(expected_c, actual_c, 1e-5)
self.assertAllClose(expected_h, actual_h, 1e-5)

def testNonbasicCell(self):
"""Tests cell with peepholes and w/o normalisation"""
cell = lambda: contrib_rnn_cell.WeightNormLSTMCell(2,
norm=False,
use_peepholes=True)
actual_c, actual_h = self._cell_output(cell)

expected_c = np.array([[0.65937084, 0.7574988]])
expected_h = np.array([[0.4792085, 0.53470564]])

self.assertAllClose(expected_c, actual_c, 1e-5)
self.assertAllClose(expected_h, actual_h, 1e-5)


def testBasicCellWithNorm(self):
"""Tests cell w/o peepholes and with normalisation"""

cell = lambda: contrib_rnn_cell.WeightNormLSTMCell(2,
norm=True,
use_peepholes=False)

actual_c, actual_h = self._cell_output(cell)

expected_c = np.array([[0.50125383, 0.58805949]])
expected_h = np.array([[0.32770363, 0.37397948]])

self.assertAllClose(expected_c, actual_c, 1e-5)
self.assertAllClose(expected_h, actual_h, 1e-5)

def testNonBasicCellWithNorm(self):
"""Tests cell with peepholes and with normalisation"""

cell = lambda: contrib_rnn_cell.WeightNormLSTMCell(2,
norm=True,
use_peepholes=True)

actual_c, actual_h = self._cell_output(cell)

expected_c = np.array([[0.50125383, 0.59587258]])
expected_h = np.array([[0.35041603, 0.40873795]])

self.assertAllClose(expected_c, actual_c, 1e-5)
self.assertAllClose(expected_h, actual_h, 1e-5)

def testBackProp(self):
"""Test a fully-featured cell with backprop.
Only a smoke test, no calculations are checked here."""

batch_size = 20
time_size = 30
num_units = 40
input_dim = 100
output_dim = 10

with self.test_session() as sess:
init = init_ops.constant_initializer(0.5)
with variable_scope.variable_scope("root", initializer=init):
func = lambda: contrib_rnn_cell.WeightNormLSTMCell(num_units,
norm=True,
use_peepholes=True,
cell_clip=0.1,
num_proj=output_dim)
cell = rnn_cell.MultiRNNCell([func() for _ in range(2)])
x = array_ops.constant(np.random.randn(batch_size,
time_size,
input_dim),
dtype=dtypes.float32)
y = array_ops.constant(np.random.randn(batch_size,
time_size,
output_dim))
x_out, _ = rnn.dynamic_rnn(cell,
inputs=x,
dtype=dtypes.float32,
swap_memory=True)
loss = losses_impl.mean_squared_error(x_out, y)
opt = adam.AdamOptimizer(0.001).minimize(loss)
sess.run(variables.global_variables_initializer())
sess.run(opt)

if __name__ == "__main__":
test.main()
Loading

0 comments on commit e61e29e

Please sign in to comment.