Skip to content

Commit

Permalink
LayernormSimpleRNN moved to addons (#841)
Browse files Browse the repository at this point in the history
* LayernormSimpleRNN moved to addons

* code-format run

* use super instead of calling the parent class

* deactivate layernorm's bias term (beta) for centering, and apply the normal self.bias term after scaling with layernorm for centering. docstring with explanatory formulas added to cell's call method

* use_layernorm=True set as default

* import alligned with cell.py, examples in docstring corrected

* import aligned with cell_test.py

* code for LayernormSimpleRNN moved into cell.py and cell_test.py

* pylint errors corrected

* bazel's timeout increased from small to large for cell_test.py

* test with training deactivated

* non-ascii char replaced

* dict syntax for python2 changed

* Renamed to LayerNorm...

* direct parent class call replaced with super

* error due to import change corrected

* uncomment line

* unit test added

* Name change in unit test file

* Still the class name change

* deleted dtype and trainable args for parent class

* remove self for super parent class calls

* compare arrays with assertAllEqual

* use_layernorm removed

* dict removed from return statement

* LayerNormSimpleRNN removed, use kwargs, comments removed

* forward **kwargs to other layers

* a more pythonic dict loop
  • Loading branch information
ulf1 committed Feb 5, 2020
1 parent 52b8079 commit cdb43ff
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow_addons/rnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@

from tensorflow_addons.rnn.cell import LayerNormLSTMCell
from tensorflow_addons.rnn.cell import NASCell
from tensorflow_addons.rnn.cell import LayerNormSimpleRNNCell
229 changes: 229 additions & 0 deletions tensorflow_addons/rnn/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,232 @@ def _create_norm_layer(self, name):
gamma_initializer=self.norm_gamma_initializer,
epsilon=self.norm_epsilon,
name=name)


@tf.keras.utils.register_keras_serializable(package='Addons')
class LayerNormSimpleRNNCell(keras.layers.SimpleRNNCell):
"""Cell class for LayerNormSimpleRNN.
References:
[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton.
"Layer Normalization." ArXiv:1607.06450 [Cs, Stat],
July 21, 2016. http://arxiv.org/abs/1607.06450
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, (default `True`), whether the layer uses a bias
vector.
layernorm_epsilon: Float, (default `1e-5`), Small float added to variance
to avoid dividing by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent
state. Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer. Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer. Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer. Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
for the linear transformation of the recurrent state. Default: 0.
Call arguments:
inputs: A 2D tensor, with shape of `[batch, feature]`.
states: A 2D tensor with shape of `[batch, units]`, which is the state
from the previous time step. For timestep 0, the initial state provided
by the user will be feed to cell.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
Examples:
```python
import numpy as np
import tensorflow.keras as keras
import tensorflow_addons as tfa
inputs = np.random.random([32, 10, 8]).astype(np.float32)
rnn = keras.layers.RNN(tfa.rnn.LayerNormSimpleRNNCell(4))
output = rnn(inputs) # The output has shape `[32, 4]`.
rnn = keras.layers.RNN(
tfa.rnn.LayerNormSimpleRNNCell(4),
return_sequences=True,
return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = rnn(inputs)
```
"""

def __init__(self,
units,
activation='tanh',
use_bias=True,
layernorm_epsilon=1e-05,
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones',
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None,
dropout=0.,
recurrent_dropout=0.,
**kwargs):
super(LayerNormSimpleRNNCell, self).__init__(
units,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
**kwargs)
self.layernorm = keras.layers.LayerNormalization(
axis=-1,
epsilon=layernorm_epsilon,
center=False,
scale=True,
beta_initializer=None,
gamma_initializer=gamma_initializer,
beta_regularizer=None,
gamma_regularizer=gamma_regularizer,
beta_constraint=None,
gamma_constraint=gamma_constraint,
**kwargs)

def build(self, input_shape):
super(LayerNormSimpleRNNCell, self).build(input_shape)
self.layernorm.build((None, self.units))

def call(self, inputs, states, training=None):
"""Formulas.
Notation:
y_t : Cell output at t (`output`)
y_{t-1} : Previous cell output at t-1 (`prev_output`)
x_t : The new input at t (`inputs`)
W_xh : Weight matrix for inputs x_t (`self.kernel`)
W_hh : Weights for prev. outputs y_{t-1} (`self.recurrent_kernel`)
b : Bias term for centering (`self.bias`)
d1 : Dropout function for x_t (`inputs * dp_mask`)
d2 : Dropout function for y_{t-1} (`prev_output * rec_dp_mask`)
ln : Scaling function from layer normalization (`self.layernorm`)
f : Activation function (`self.activation`)
Case 1:
Keras' SimpleRNN. Only with bias and activation
y_t = f(x_t * W_xh + y_{t-1} * W_hh + b)
or
net = x_t * W_xh + y_{t-1} * W_hh
y_t = f(net + b)
Case 2:
addons' LayerNormSimpleRNNCell. Like case 1 but with layer
normalization (only scaling).
y_t = f(ln(x_t * W_xh + y_{t-1} * W_hh) + b)
or
net = x_t * W_xh + y_{t-1} * W_hh
y_t = f(ln(net) + b)
Layer normalization with scaling and centering in one go (see Ba et
al (2016), page 3, formula 4, https://arxiv.org/abs/1607.06450)
is the same as layer normalization only with scaling, and
centering directly afterwards.
Case 3:
Keras' SimpleRNN. with dropout, bias, and activation
y_t = f(d1(x_t) * W_xh + d2(y_{t-1}) * W_hh + b)
or
net = d1(x_t) * W_xh + d2(y_{t-1}) * W_hh
y_t = f(net + b)
Case 4:
addons' LayerNormSimpleRNNCell. Like case 3 but with layer
normalization (only scaling).
y_t = f(ln(d1(x_t) * W_xh + d2(y_{t-1}) * W_hh) + b)
or
net = d1(x_t) * W_xh + d2(y_{t-1}) * W_hh
y_t = f(ln(net) + b)
"""
prev_output = states[0]
dp_mask = self.get_dropout_mask_for_cell(inputs, training)
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
prev_output, training)

if dp_mask is not None:
h = keras.backend.dot(inputs * dp_mask, self.kernel)
else:
h = keras.backend.dot(inputs, self.kernel)

# don't add bias to "h" here
# add bias after scaling with layer normalization to "output"

if rec_dp_mask is not None:
prev_output = prev_output * rec_dp_mask
output = h + keras.backend.dot(prev_output,
self.recurrent_kernel) # "net"

output = self.layernorm(output)

if self.bias is not None:
output = keras.backend.bias_add(output, self.bias)

if self.activation is not None:
output = self.activation(output)

return output, [output]

# use SimpleRNNCell's get_initial_state method

def get_config(self):
cell_config = super(LayerNormSimpleRNNCell, self).get_config()
del cell_config['name']

ln_config = self.layernorm.get_config()
ln_config = {
k:v for k, v in ln_config.items()
if k in ["epsilon", "gamma_initializer",
"gamma_regularizer", "gamma_constraint"]}

ln_config['layernorm_epsilon'] = ln_config.pop("epsilon")
return dict(list(cell_config.items()) + list(ln_config.items()))
61 changes: 61 additions & 0 deletions tensorflow_addons/rnn/cell_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from tensorflow_addons.utils import test_utils
from tensorflow_addons.rnn import cell as rnn_cell
from tensorflow_addons.rnn import LayerNormSimpleRNNCell


@test_utils.run_all_in_graph_and_eager_modes
Expand Down Expand Up @@ -292,5 +293,65 @@ def test_config(self):
self.assertEqual(config, restored_config)


@test_utils.run_all_in_graph_and_eager_modes
class LayerNormSimpleRNNTest(tf.test.TestCase):
def test_constraints_layernorm_rnn(self):
embedding_dim = 4
k_constraint = keras.constraints.max_norm(0.01)
r_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
g_constraint = keras.constraints.max_norm(0.01)
layer = keras.layers.RNN(
LayerNormSimpleRNNCell(
units=5,
kernel_constraint=k_constraint,
recurrent_constraint=r_constraint,
bias_constraint=b_constraint,
gamma_constraint=g_constraint),
input_shape=(None, embedding_dim),
return_sequences=False)
layer.build((None, None, embedding_dim))
self.assertEqual(layer.cell.kernel.constraint, k_constraint)
self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
self.assertEqual(layer.cell.bias.constraint, b_constraint)
self.assertEqual(layer.cell.layernorm.gamma.constraint, g_constraint)

def test_with_masking_layer_layernorm_rnn(self):
inputs = np.random.random((2, 3, 4))
targets = np.abs(np.random.random((2, 3, 5)))
targets /= targets.sum(axis=-1, keepdims=True)
model = keras.models.Sequential()
model.add(keras.layers.Masking(input_shape=(3, 4)))
model.add(
keras.layers.RNN(
LayerNormSimpleRNNCell(units=5),
return_sequences=True,
unroll=False))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)

def test_regularizers_layernorm_rnn(self):
embedding_dim = 4
layer = keras.layers.RNN(
LayerNormSimpleRNNCell(
units=5,
kernel_regularizer=keras.regularizers.l1(0.01),
recurrent_regularizer=keras.regularizers.l1(0.01),
bias_regularizer='l2',
gamma_regularizer='l2'),
input_shape=(None, embedding_dim),
return_sequences=False)
layer.build((None, None, 2))
self.assertEqual(len(layer.losses), 4)

def test_configs_layernorm(self):
config = {'layernorm_epsilon': 1e-6}
cell1 = LayerNormSimpleRNNCell(units=8, **config)
config1 = cell1.get_config()
cell2 = LayerNormSimpleRNNCell(**config1)
config2 = cell2.get_config()
assert config1 == config2


if __name__ == "__main__":
tf.test.main()

0 comments on commit cdb43ff

Please sign in to comment.