From cdb43ffd7e4b89d6ce8cdadcd62fec46c7f0f7fa Mon Sep 17 00:00:00 2001 From: Ulf Hamster <554c46@gmail.com> Date: Wed, 5 Feb 2020 19:32:25 +0100 Subject: [PATCH] LayernormSimpleRNN moved to addons (#841) * LayernormSimpleRNN moved to addons * code-format run * use super instead of calling the parent class * deactivate layernorm's bias term (beta) for centering, and apply the normal self.bias term after scaling with layernorm for centering. docstring with explanatory formulas added to cell's call method * use_layernorm=True set as default * import alligned with cell.py, examples in docstring corrected * import aligned with cell_test.py * code for LayernormSimpleRNN moved into cell.py and cell_test.py * pylint errors corrected * bazel's timeout increased from small to large for cell_test.py * test with training deactivated * non-ascii char replaced * dict syntax for python2 changed * Renamed to LayerNorm... * direct parent class call replaced with super * error due to import change corrected * uncomment line * unit test added * Name change in unit test file * Still the class name change * deleted dtype and trainable args for parent class * remove self for super parent class calls * compare arrays with assertAllEqual * use_layernorm removed * dict removed from return statement * LayerNormSimpleRNN removed, use kwargs, comments removed * forward **kwargs to other layers * a more pythonic dict loop --- tensorflow_addons/rnn/__init__.py | 1 + tensorflow_addons/rnn/cell.py | 229 +++++++++++++++++++++++++++++ tensorflow_addons/rnn/cell_test.py | 61 ++++++++ 3 files changed, 291 insertions(+) diff --git a/tensorflow_addons/rnn/__init__.py b/tensorflow_addons/rnn/__init__.py index abb082f7ba..a2502052c6 100644 --- a/tensorflow_addons/rnn/__init__.py +++ b/tensorflow_addons/rnn/__init__.py @@ -16,3 +16,4 @@ from tensorflow_addons.rnn.cell import LayerNormLSTMCell from tensorflow_addons.rnn.cell import NASCell +from tensorflow_addons.rnn.cell import LayerNormSimpleRNNCell diff --git a/tensorflow_addons/rnn/cell.py b/tensorflow_addons/rnn/cell.py index ca82108526..83d3d966df 100644 --- a/tensorflow_addons/rnn/cell.py +++ b/tensorflow_addons/rnn/cell.py @@ -363,3 +363,232 @@ def _create_norm_layer(self, name): gamma_initializer=self.norm_gamma_initializer, epsilon=self.norm_epsilon, name=name) + + +@tf.keras.utils.register_keras_serializable(package='Addons') +class LayerNormSimpleRNNCell(keras.layers.SimpleRNNCell): + """Cell class for LayerNormSimpleRNN. + + References: + [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. + "Layer Normalization." ArXiv:1607.06450 [Cs, Stat], + July 21, 2016. http://arxiv.org/abs/1607.06450 + + Arguments: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use. + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, (default `True`), whether the layer uses a bias + vector. + layernorm_epsilon: Float, (default `1e-5`), Small float added to variance + to avoid dividing by zero. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. Default: + `glorot_uniform`. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation of the recurrent + state. Default: `orthogonal`. + bias_initializer: Initializer for the bias vector (`use_bias=True`). + Default: `zeros`. 
+      gamma_initializer: Initializer for the gamma vector of the layer
+        normalization layer. Default: `ones`.
+      kernel_regularizer: Regularizer function applied to the `kernel` weights
+        matrix. Default: `None`.
+      recurrent_regularizer: Regularizer function applied to the
+        `recurrent_kernel` weights matrix. Default: `None`.
+      bias_regularizer: Regularizer function applied to the bias vector
+        (`use_bias=True`). Default: `None`.
+      gamma_regularizer: Regularizer function applied to the gamma vector
+        of the layer normalization layer. Default: `None`.
+      kernel_constraint: Constraint function applied to the `kernel` weights
+        matrix. Default: `None`.
+      recurrent_constraint: Constraint function applied to the
+        `recurrent_kernel` weights matrix. Default: `None`.
+      bias_constraint: Constraint function applied to the bias vector
+        (`use_bias=True`). Default: `None`.
+      gamma_constraint: Constraint function applied to the gamma vector
+        of the layer normalization layer. Default: `None`.
+      dropout: Float between 0 and 1. Fraction of the units to drop for the
+        linear transformation of the inputs. Default: 0.
+      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
+        for the linear transformation of the recurrent state. Default: 0.
+
+    Call arguments:
+      inputs: A 2D tensor with shape `[batch, feature]`.
+      states: A 2D tensor with shape `[batch, units]`, which is the state
+        from the previous time step. For timestep 0, the initial state
+        provided by the user will be fed to the cell.
+      training: Python boolean indicating whether the layer should behave in
+        training mode or in inference mode. Only relevant when `dropout` or
+        `recurrent_dropout` is used.
+
+    Examples:
+
+    ```python
+    import numpy as np
+    import tensorflow.keras as keras
+    import tensorflow_addons as tfa
+
+    inputs = np.random.random([32, 10, 8]).astype(np.float32)
+    rnn = keras.layers.RNN(tfa.rnn.LayerNormSimpleRNNCell(4))
+
+    output = rnn(inputs)  # The output has shape `[32, 4]`.
+
+    rnn = keras.layers.RNN(
+        tfa.rnn.LayerNormSimpleRNNCell(4),
+        return_sequences=True,
+        return_state=True)
+
+    # whole_sequence_output has shape `[32, 10, 4]`.
+    # final_state has shape `[32, 4]`.
+ whole_sequence_output, final_state = rnn(inputs) + ``` + """ + + def __init__(self, + units, + activation='tanh', + use_bias=True, + layernorm_epsilon=1e-05, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + gamma_initializer='ones', + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + gamma_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + gamma_constraint=None, + dropout=0., + recurrent_dropout=0., + **kwargs): + super(LayerNormSimpleRNNCell, self).__init__( + units, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + **kwargs) + self.layernorm = keras.layers.LayerNormalization( + axis=-1, + epsilon=layernorm_epsilon, + center=False, + scale=True, + beta_initializer=None, + gamma_initializer=gamma_initializer, + beta_regularizer=None, + gamma_regularizer=gamma_regularizer, + beta_constraint=None, + gamma_constraint=gamma_constraint, + **kwargs) + + def build(self, input_shape): + super(LayerNormSimpleRNNCell, self).build(input_shape) + self.layernorm.build((None, self.units)) + + def call(self, inputs, states, training=None): + """Formulas. + + Notation: + y_t : Cell output at t (`output`) + y_{t-1} : Previous cell output at t-1 (`prev_output`) + x_t : The new input at t (`inputs`) + W_xh : Weight matrix for inputs x_t (`self.kernel`) + W_hh : Weights for prev. outputs y_{t-1} (`self.recurrent_kernel`) + b : Bias term for centering (`self.bias`) + d1 : Dropout function for x_t (`inputs * dp_mask`) + d2 : Dropout function for y_{t-1} (`prev_output * rec_dp_mask`) + ln : Scaling function from layer normalization (`self.layernorm`) + f : Activation function (`self.activation`) + + Case 1: + Keras' SimpleRNN. Only with bias and activation + y_t = f(x_t * W_xh + y_{t-1} * W_hh + b) + or + net = x_t * W_xh + y_{t-1} * W_hh + y_t = f(net + b) + + Case 2: + addons' LayerNormSimpleRNNCell. Like case 1 but with layer + normalization (only scaling). + y_t = f(ln(x_t * W_xh + y_{t-1} * W_hh) + b) + or + net = x_t * W_xh + y_{t-1} * W_hh + y_t = f(ln(net) + b) + + Layer normalization with scaling and centering in one go (see Ba et + al (2016), page 3, formula 4, https://arxiv.org/abs/1607.06450) + is the same as layer normalization only with scaling, and + centering directly afterwards. + + Case 3: + Keras' SimpleRNN. with dropout, bias, and activation + y_t = f(d1(x_t) * W_xh + d2(y_{t-1}) * W_hh + b) + or + net = d1(x_t) * W_xh + d2(y_{t-1}) * W_hh + y_t = f(net + b) + + Case 4: + addons' LayerNormSimpleRNNCell. Like case 3 but with layer + normalization (only scaling). 
+ y_t = f(ln(d1(x_t) * W_xh + d2(y_{t-1}) * W_hh) + b) + or + net = d1(x_t) * W_xh + d2(y_{t-1}) * W_hh + y_t = f(ln(net) + b) + """ + prev_output = states[0] + dp_mask = self.get_dropout_mask_for_cell(inputs, training) + rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( + prev_output, training) + + if dp_mask is not None: + h = keras.backend.dot(inputs * dp_mask, self.kernel) + else: + h = keras.backend.dot(inputs, self.kernel) + + # don't add bias to "h" here + # add bias after scaling with layer normalization to "output" + + if rec_dp_mask is not None: + prev_output = prev_output * rec_dp_mask + output = h + keras.backend.dot(prev_output, + self.recurrent_kernel) # "net" + + output = self.layernorm(output) + + if self.bias is not None: + output = keras.backend.bias_add(output, self.bias) + + if self.activation is not None: + output = self.activation(output) + + return output, [output] + + # use SimpleRNNCell's get_initial_state method + + def get_config(self): + cell_config = super(LayerNormSimpleRNNCell, self).get_config() + del cell_config['name'] + + ln_config = self.layernorm.get_config() + ln_config = { + k:v for k, v in ln_config.items() + if k in ["epsilon", "gamma_initializer", + "gamma_regularizer", "gamma_constraint"]} + + ln_config['layernorm_epsilon'] = ln_config.pop("epsilon") + return dict(list(cell_config.items()) + list(ln_config.items())) diff --git a/tensorflow_addons/rnn/cell_test.py b/tensorflow_addons/rnn/cell_test.py index 51cac6a9e3..aac62d8d5d 100644 --- a/tensorflow_addons/rnn/cell_test.py +++ b/tensorflow_addons/rnn/cell_test.py @@ -20,6 +20,7 @@ from tensorflow_addons.utils import test_utils from tensorflow_addons.rnn import cell as rnn_cell +from tensorflow_addons.rnn import LayerNormSimpleRNNCell @test_utils.run_all_in_graph_and_eager_modes @@ -292,5 +293,65 @@ def test_config(self): self.assertEqual(config, restored_config) +@test_utils.run_all_in_graph_and_eager_modes +class LayerNormSimpleRNNTest(tf.test.TestCase): + def test_constraints_layernorm_rnn(self): + embedding_dim = 4 + k_constraint = keras.constraints.max_norm(0.01) + r_constraint = keras.constraints.max_norm(0.01) + b_constraint = keras.constraints.max_norm(0.01) + g_constraint = keras.constraints.max_norm(0.01) + layer = keras.layers.RNN( + LayerNormSimpleRNNCell( + units=5, + kernel_constraint=k_constraint, + recurrent_constraint=r_constraint, + bias_constraint=b_constraint, + gamma_constraint=g_constraint), + input_shape=(None, embedding_dim), + return_sequences=False) + layer.build((None, None, embedding_dim)) + self.assertEqual(layer.cell.kernel.constraint, k_constraint) + self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint) + self.assertEqual(layer.cell.bias.constraint, b_constraint) + self.assertEqual(layer.cell.layernorm.gamma.constraint, g_constraint) + + def test_with_masking_layer_layernorm_rnn(self): + inputs = np.random.random((2, 3, 4)) + targets = np.abs(np.random.random((2, 3, 5))) + targets /= targets.sum(axis=-1, keepdims=True) + model = keras.models.Sequential() + model.add(keras.layers.Masking(input_shape=(3, 4))) + model.add( + keras.layers.RNN( + LayerNormSimpleRNNCell(units=5), + return_sequences=True, + unroll=False)) + model.compile(loss='categorical_crossentropy', optimizer='rmsprop') + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1) + + def test_regularizers_layernorm_rnn(self): + embedding_dim = 4 + layer = keras.layers.RNN( + LayerNormSimpleRNNCell( + units=5, + kernel_regularizer=keras.regularizers.l1(0.01), + 
recurrent_regularizer=keras.regularizers.l1(0.01), + bias_regularizer='l2', + gamma_regularizer='l2'), + input_shape=(None, embedding_dim), + return_sequences=False) + layer.build((None, None, 2)) + self.assertEqual(len(layer.losses), 4) + + def test_configs_layernorm(self): + config = {'layernorm_epsilon': 1e-6} + cell1 = LayerNormSimpleRNNCell(units=8, **config) + config1 = cell1.get_config() + cell2 = LayerNormSimpleRNNCell(**config1) + config2 = cell2.get_config() + assert config1 == config2 + + if __name__ == "__main__": tf.test.main()
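
For reviewers who want to exercise the new cell end to end, below is a minimal usage sketch. It wires `LayerNormSimpleRNNCell` into a small Keras model and repeats the config round-trip from `cell_test.py`. It assumes a TensorFlow 2.x environment with this patch applied to `tensorflow_addons`; the data shapes, layer sizes, and training settings are arbitrary illustration values and not part of the patch itself.

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# Toy sequence batch: 32 samples, 10 timesteps, 8 features (illustrative only).
x = np.random.random((32, 10, 8)).astype(np.float32)
y = np.random.random((32, 4)).astype(np.float32)

# The cell drops into keras.layers.RNN exactly like SimpleRNNCell does.
model = tf.keras.Sequential([
    tf.keras.layers.RNN(
        tfa.rnn.LayerNormSimpleRNNCell(units=4, layernorm_epsilon=1e-5),
        input_shape=(10, 8)),
    tf.keras.layers.Dense(4),
])
model.compile(loss="mse", optimizer="adam")
model.fit(x, y, epochs=1, batch_size=8, verbose=0)

# Config round-trip, mirroring test_configs_layernorm in cell_test.py.
cell = tfa.rnn.LayerNormSimpleRNNCell(units=4, layernorm_epsilon=1e-6)
restored = tfa.rnn.LayerNormSimpleRNNCell(**cell.get_config())
assert cell.get_config() == restored.get_config()
```

Because `get_config()` folds the wrapped `LayerNormalization`'s epsilon and gamma settings back under the cell's own constructor argument names, the round-trip above covers serialization of both the recurrent weights and the normalization scale.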