diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py
index 7a6ce96fb..4c63ce8ba 100644
--- a/tensor2tensor/models/common_layers.py
+++ b/tensor2tensor/models/common_layers.py
@@ -292,7 +292,7 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
   """Conditional conv_fn making kernel 1d or 2d depending on inputs shape."""
   static_shape = inputs.get_shape()
   if not static_shape or len(static_shape) != 4:
-    raise ValueError("Inputs to conv must have statically known rank 4.")
+    raise ValueError("Inputs to conv must have statically known rank 4. Shape: " + str(static_shape))
   # Add support for left padding.
   if "padding" in kwargs and kwargs["padding"] == "LEFT":
     dilation_rate = (1, 1)
@@ -1378,3 +1378,128 @@ def smoothing_cross_entropy(logits, labels, vocab_size, confidence):
   xentropy = tf.nn.softmax_cross_entropy_with_logits(
       logits=logits, labels=soft_targets)
   return xentropy - normalizing
+
+
+def global_pool_1d(inputs, pooling_type="MAX", mask=None):
+  """Pools elements across the sequence dimension.
+
+  Useful to collapse a sequence of vectors into a single vector, e.g. to get
+  a representation of a set.
+
+  Args:
+    inputs: A tensor of shape [batch_size, sequence_length, input_dims]
+      containing the sequences of input vectors.
+    pooling_type: The pooling type to use, "MAX" or "AVR" (average).
+    mask: A tensor of shape [batch_size, sequence_length] containing a mask
+      for the inputs with 1's for existing elements, and 0's elsewhere.
+
+  Returns:
+    output: A tensor of shape [batch_size, input_dims] containing the pooled
+      vectors.
+  """
+  with tf.name_scope("global_pool", [inputs]):
+    if mask is not None:
+      mask = tf.expand_dims(mask, axis=2)
+      inputs = tf.multiply(inputs, mask)
+
+    if pooling_type == "MAX":
+      # A tf.pool can be used here, but reduce is cleaner.
+      # Note: masked positions are zeroed out above, so they can still win
+      # the max if all real elements are negative.
+      output = tf.reduce_max(inputs, axis=1)
+    elif pooling_type == "AVR":
+      if mask is not None:
+        # Some elements are dummy elements, so we can't just take
+        # reduce_mean; divide by the number of real elements instead.
+        output = tf.reduce_sum(inputs, axis=1)
+        num_elems = tf.reduce_sum(mask, axis=1, keep_dims=True)
+        output = tf.div(output, num_elems)
+        # N.B.: this will produce NaN if a batch entry contains no elements.
+      else:
+        output = tf.reduce_mean(inputs, axis=1)
+
+  return output
+
+
+def linear_set_layer(layer_size,
+                     inputs,
+                     context=None,
+                     activation_fn=tf.nn.relu,
+                     dropout=0.0,
+                     name=None):
+  """Basic layer type for building set networks.
+
+  Applies a linear transformation to each element in the input set.
+  If a context is supplied, it is concatenated with the inputs, e.g. one can
+  use global_pool_1d to get a representation of the set, which can then be
+  used as the context for the next layer.
+
+  Args:
+    layer_size: Dimension to transform the input vectors to.
+    inputs: A tensor of shape [batch_size, sequence_length, input_dims]
+      containing the sequences of input vectors.
+    context: A tensor of shape [batch_size, context_dims] containing a global
+      statistic about the set.
+    activation_fn: The activation function to use.
+    dropout: Dropout probability.
+    name: An optional string name for the variable scope.
+
+  Returns:
+    output: A tensor of shape [batch_size, sequence_length, output_dims]
+      containing the sequences of transformed vectors.
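+
+  Example (illustrative only; layer size and shapes are arbitrary):
+    # x has shape [batch_size, sequence_length, input_dims].
+    context = global_pool_1d(x)  # [batch_size, input_dims]
+    y = linear_set_layer(128, x, context=context)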
+ """ + + with tf.variable_scope(name, "linear_set_layer", [inputs]): + # Apply 1D convolution to apply linear filter to each element along the 2nd + # dimension + #in_size = inputs.get_shape().as_list()[-1] + outputs = conv1d(inputs, layer_size, 1, activation=None, name="set_conv") + + # Apply the context if it exists + if context is not None: + # Unfortunately tf doesn't support broadcasting via concat, but we can + # simply add the transformed context to get the same effect + context = tf.expand_dims(context, axis=1) + #context_size = context.get_shape().as_list()[-1] + cont_tfm = conv1d(context, layer_size, 1, + activation=None, name="cont_conv") + outputs += cont_tfm + + if activation_fn is not None: + outputs = activation_fn(outputs) + + if dropout != 0.0: + output = tf.nn.dropout(output, 1.0 - dropout) + + return outputs + + +def ravanbakhsh_set_layer(layer_size, + inputs, + mask=None, + activation_fn=tf.nn.tanh, + dropout=0.0, + name=None): + """ + Layer from Deep Sets paper: https://arxiv.org/abs/1611.04500 + More parameter-efficient verstion of a linear-set-layer with context. + + + Args + layer_size: Dimension to transform the input vectors to. + inputs: A tensor of dimensions batch_size x sequence_length x vector + containing the sequences of input vectors. + mask: A tensor of dimensions batch_size x sequence_length containing a + mask for the inputs with 1's for existing elements, and 0's elsewhere. + activation_fn: The activation function to use. + Outputs + output: A tensor of dimensions batch_size x sequence_length x vector + dimension containing the sequences of transformed vectors. + """ + + with tf.variable_scope(name, "ravanbakhsh_set_layer", [inputs]): + output = linear_set_layer( + layer_size, + inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1), + activation_fn=activation_fn, + name=name) + + return output + + diff --git a/tensor2tensor/models/common_layers_test.py b/tensor2tensor/models/common_layers_test.py index 8d2b4dec1..04d428884 100644 --- a/tensor2tensor/models/common_layers_test.py +++ b/tensor2tensor/models/common_layers_test.py @@ -50,7 +50,7 @@ def testSaturatingSigmoid(self): self.assertAllClose(res, [0.0, 0.0, 0.5, 1.0, 1.0]) def testFlatten4D3D(self): - x = np.random.random_integers(1, high=8, size=(3, 5, 2)) + x = np.random.randint(1, 9, size=(3, 5, 2)) with self.test_session() as session: y = common_layers.flatten4d3d(common_layers.embedding(x, 10, 7)) session.run(tf.global_variables_initializer()) @@ -58,7 +58,7 @@ def testFlatten4D3D(self): self.assertEqual(res.shape, (3, 5 * 2, 7)) def testEmbedding(self): - x = np.random.random_integers(1, high=8, size=(3, 5)) + x = np.random.randint(1, 9, size=(3, 5)) with self.test_session() as session: y = common_layers.embedding(x, 10, 16) session.run(tf.global_variables_initializer()) @@ -81,6 +81,14 @@ def testConv(self): session.run(tf.global_variables_initializer()) res = session.run(y) self.assertEqual(res.shape, (5, 5, 1, 13)) + + def testConv1d(self): + x = np.random.rand(5, 7, 11) + with self.test_session() as session: + y = common_layers.conv1d(tf.constant(x, dtype=tf.float32), 13, 1) + session.run(tf.global_variables_initializer()) + res = session.run(y) + self.assertEqual(res.shape, (5, 7, 13)) def testSeparableConv(self): x = np.random.rand(5, 7, 1, 11) @@ -293,6 +301,66 @@ def testDeconvStride2MultiStep(self): session.run(tf.global_variables_initializer()) actual = session.run(a) self.assertEqual(actual.shape, (5, 32, 1, 16)) + + def testGlobalPool1d(self): + shape = (5, 4) + x1 = 
+    x1 = np.random.rand(5, 4, 11)
+    no_mask = np.ones((5, 4))  # All elements present.
+    full_mask = np.zeros((5, 4))  # All elements masked out.
+
+    with self.test_session() as session:
+      x1_ = tf.Variable(x1, dtype=tf.float32)
+      no_mask_ = tf.Variable(no_mask, dtype=tf.float32)
+      full_mask_ = tf.Variable(full_mask, dtype=tf.float32)
+
+      none_mask_max = common_layers.global_pool_1d(x1_)
+      no_mask_max = common_layers.global_pool_1d(x1_, mask=no_mask_)
+      result1 = tf.reduce_sum(none_mask_max - no_mask_max)
+
+      full_mask_max = common_layers.global_pool_1d(x1_, mask=full_mask_)
+      result2 = tf.reduce_sum(full_mask_max)
+
+      none_mask_avr = common_layers.global_pool_1d(x1_, "AVR")
+      no_mask_avr = common_layers.global_pool_1d(x1_, "AVR", no_mask_)
+      result3 = tf.reduce_sum(none_mask_avr - no_mask_avr)
+
+      full_mask_avr = common_layers.global_pool_1d(x1_, "AVR", full_mask_)
+      result4 = tf.reduce_sum(full_mask_avr)
+
+      session.run(tf.global_variables_initializer())
+      actual = session.run([result1, result2, result3, result4])
+    # N.B.: the last result is NaN (average over an empty set), so only the
+    # first three are checked.
+    self.assertAllEqual(actual[:3], [0.0, 0.0, 0.0])
+
+  def testLinearSetLayer(self):
+    x1 = np.random.rand(5, 4, 11)
+    cont = np.random.rand(5, 13)
+    with self.test_session() as session:
+      x1_ = tf.Variable(x1, dtype=tf.float32)
+      cont_ = tf.Variable(cont, dtype=tf.float32)
+
+      simple_ff = common_layers.linear_set_layer(32, x1_)
+      cont_ff = common_layers.linear_set_layer(32, x1_, context=cont_)
+
+      session.run(tf.global_variables_initializer())
+      actual = session.run([simple_ff, cont_ff])
+    self.assertEqual(actual[0].shape, (5, 4, 32))
+    self.assertEqual(actual[1].shape, (5, 4, 32))
+
+  def testRavanbakhshSetLayer(self):
+    x1 = np.random.rand(5, 4, 11)
+    with self.test_session() as session:
+      x1_ = tf.Variable(x1, dtype=tf.float32)
+
+      layer = common_layers.ravanbakhsh_set_layer(32, x1_)
+
+      session.run(tf.global_variables_initializer())
+      actual = session.run(layer)
+    self.assertEqual(actual.shape, (5, 4, 32))
 
 
 if __name__ == "__main__":
diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py
new file mode 100644
index 000000000..90fea6139
--- /dev/null
+++ b/tensor2tensor/models/transformer_alternative.py
@@ -0,0 +1,190 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Alternative transformer network.
+
+Uses different layer types to demonstrate alternatives to self-attention.
+The code is mostly copied from the original Transformer source.
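+
+The encoder replaces each self-attention sublayer with a residual composite
+set layer; the decoder interleaves encoder-decoder attention with the same
+composite layers.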
+ +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +# Dependency imports + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensor2tensor.models import common_attention +from tensor2tensor.models import common_hparams +from tensor2tensor.models import common_layers +from tensor2tensor.models import transformer +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +@registry.register_model +class TransformerAlt(t2t_model.T2TModel): + + def model_fn_body(self, features): + # + + # Remove dropout if not training + hparams = copy.copy(self._hparams) + targets = features["targets"] + inputs = features.get("inputs") + target_space = features.get("target_space_id") + + inputs = common_layers.flatten4d3d(inputs) + targets = common_layers.flatten4d3d(targets) + + (encoder_input, encoder_attention_bias, _) = (transformer.\ + transformer_prepare_encoder(inputs, target_space, hparams) ) + (decoder_input, decoder_self_attention_bias) = transformer.\ + transformer_prepare_decoder(targets, hparams) + + # We need masks of the form batch size x input sequences + # Biases seem to be of the form batch_size x 1 x input sequences x vec dim + # Squeeze out dim one, and get the first element of each vector + encoder_mask = tf.squeeze(encoder_attention_bias, [1])[:,:,0] + decoder_mask = tf.squeeze(decoder_self_attention_bias, [1])[:,:,0] + + def residual_fn(x, y): + return common_layers.layer_norm(x + tf.nn.dropout( + y, 1.0 - hparams.residual_dropout)) + + encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout) + decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout) + encoder_output = alt_transformer_encoder( + encoder_input, residual_fn, encoder_mask, hparams) + + decoder_output = alt_transformer_decoder( + decoder_input, encoder_output, residual_fn, decoder_mask, + encoder_attention_bias, hparams) + + decoder_output = tf.expand_dims(decoder_output, 2) + + return decoder_output + + + +def composite_layer(inputs, mask, hparams): + x = inputs + + # Applies ravanbakhsh on top of each other + if hparams.composite_layer_type == "ravanbakhsh": + for layer in xrange(hparams.layers_per_layer): + with tf.variable_scope(".%d" % layer): + x = common_layers.ravanbakhsh_set_layer( + hparams.hidden_size, + x, + mask=mask, + dropout=0.0) + + # Transforms elements to get a context, and then uses this in a final layer + elif hparams.composite_layer_type == "reembedding": + initial_elems = x + # Transform elements n times and then pool + for layer in xrange(hparams.layers_per_layer): + with tf.variable_scope(".%d" % layer): + x = common_layers.linear_set_layer( + hparams.hidden_size, + x, + dropout=0.0) + context = common_layers.global_pool_1d(x, mask=mask) + + #Final layer + x = common_layers.linear_set_layer( + hparams.hidden_size, + x, + context=context, + dropout=0.0) + + return x + + + +def alt_transformer_encoder(encoder_input, + residual_fn, + mask, + hparams, + name="encoder"): + + x = encoder_input + + # Summaries don't work in multi-problem setting yet. 
+ summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 + + with tf.variable_scope(name): + for layer in xrange(hparams.num_hidden_layers): + with tf.variable_scope("layer_%d" % layer): + x = residual_fn(x, composite_layer(x, mask, hparams)) + + return x + + +def alt_transformer_decoder(decoder_input, + encoder_output, + residual_fn, + mask, + encoder_decoder_attention_bias, + hparams, + name="decoder"): + + x = decoder_input + + # Summaries don't work in multi-problem setting yet. + summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 + with tf.variable_scope(name): + for layer in xrange(hparams.num_hidden_layers): + with tf.variable_scope("layer_%d" % layer): + + x_ = common_attention.multihead_attention( + x, + encoder_output, + encoder_decoder_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + summaries=summaries, + name="encdec_attention") + + x_ = residual_fn(x_, composite_layer(x_, mask, hparams)) + x = residual_fn(x, x_) + + return x + + + + + +@registry.register_hparams +def transformer_alt(): + """Set of hyperparameters.""" + hparams = transformer.transformer_base() + hparams.batch_size = 64 + hparams.add_hparam("layers_per_layer", 4) + #hparams.add_hparam("composite_layer_type", "ravanbakhsh") #ravanbakhsh or reembedding + hparams.add_hparam("composite_layer_type", "reembedding") + return hparams + diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 9535558a4..52c1d1ba5 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -24,6 +24,7 @@ from tensor2tensor.data_generators import problem_hparams from tensor2tensor.models import transformer +from tensor2tensor.models import transformer_alternative import tensorflow as tf