Add support for masked input in TrainingSampler #546

Merged
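For context, a minimal usage sketch of the new mask argument, assuming TF 2.x eager execution; the shapes and the sampler_py import alias are illustrative:

import tensorflow as tf
from tensorflow_addons.seq2seq import sampler as sampler_py

# Batch of 2 sequences, max_time=4, input_depth=3 (illustrative sizes).
inputs = tf.random.normal([2, 4, 3])
# Right-padded boolean mask: True marks valid time steps.
mask = tf.constant([[True, True, False, False],
                    [True, True, True, False]])

sampler = sampler_py.TrainingSampler(time_major=False)
# Equivalent to passing sequence_length=[2, 3]; per this PR the two
# arguments are mutually exclusive.
finished, next_inputs = sampler.initialize(inputs, mask=mask)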
111 changes: 111 additions & 0 deletions tensorflow_addons/seq2seq/basic_decoder_test.py
@@ -122,6 +122,102 @@ def testStepWithTrainingHelperOutputLayer(self, use_output_layer):
np.argmax(eval_result["step_outputs"].rnn_output, -1),
eval_result["step_outputs"].sample_id)

@parameterized.named_parameters(("sequence_length_only", False),
("mask_only", True), ("no_mask", None))
def testStepWithTrainingHelperMaskedInput(self, use_mask):
batch_size = 5
max_time = 8
sequence_length = [max_time] * batch_size if use_mask is None \
else [3, 4, 3, 1, 0]
sequence_length = np.array(sequence_length, dtype=np.int32)
mask = [[True] * l + [False] * (max_time - l) for l in sequence_length]
input_depth = 7
cell_depth = 10
output_layer_depth = 3

with self.cached_session(use_gpu=True):
inputs = np.random.randn(batch_size, max_time,
input_depth).astype(np.float32)
input_t = tf.constant(inputs)
cell = tf.keras.layers.LSTMCell(cell_depth)
sampler = sampler_py.TrainingSampler(time_major=False)
output_layer = tf.keras.layers.Dense(
output_layer_depth, use_bias=False)
expected_output_depth = output_layer_depth
initial_state = cell.get_initial_state(
batch_size=batch_size, dtype=tf.float32)
my_decoder = basic_decoder.BasicDecoder(
cell=cell, sampler=sampler, output_layer=output_layer)

if use_mask is None:
(first_finished, first_inputs,
first_state) = my_decoder.initialize(
input_t, initial_state=initial_state)
elif use_mask:
(first_finished, first_inputs,
first_state) = my_decoder.initialize(
input_t, initial_state=initial_state, mask=mask)
else:
(first_finished, first_inputs,
first_state) = my_decoder.initialize(
input_t,
initial_state=initial_state,
sequence_length=sequence_length)

output_size = my_decoder.output_size
output_dtype = my_decoder.output_dtype
self.assertEqual(
basic_decoder.BasicDecoderOutput(expected_output_depth,
tf.TensorShape([])),
output_size)
self.assertEqual(
basic_decoder.BasicDecoderOutput(tf.float32, tf.int32),
output_dtype)

(step_outputs, step_state, step_next_inputs,
step_finished) = my_decoder.step(
tf.constant(0), first_inputs, first_state)
batch_size_t = my_decoder.batch_size

self.assertLen(first_state, 2)
self.assertLen(step_state, 2)
self.assertIsInstance(step_outputs,
basic_decoder.BasicDecoderOutput)
self.assertEqual((batch_size, expected_output_depth),
step_outputs[0].get_shape())
self.assertEqual((batch_size,), step_outputs[1].get_shape())
self.assertEqual((batch_size, cell_depth),
first_state[0].get_shape())
self.assertEqual((batch_size, cell_depth),
first_state[1].get_shape())
self.assertEqual((batch_size, cell_depth),
step_state[0].get_shape())
self.assertEqual((batch_size, cell_depth),
step_state[1].get_shape())

self.assertLen(output_layer.variables, 1)

eval_result = self.evaluate({
"batch_size": batch_size_t,
"first_finished": first_finished,
"first_inputs": first_inputs,
"first_state": first_state,
"step_outputs": step_outputs,
"step_state": step_state,
"step_next_inputs": step_next_inputs,
"step_finished": step_finished
})

self.assertAllEqual(sequence_length == 0,
eval_result["first_finished"])
self.assertAllEqual((np.maximum(sequence_length - 1, 0) == 0),
eval_result["step_finished"])
self.assertEqual(output_dtype.sample_id,
eval_result["step_outputs"].sample_id.dtype)
self.assertAllEqual(
np.argmax(eval_result["step_outputs"].rnn_output, -1),
eval_result["step_outputs"].sample_id)

def testStepWithGreedyEmbeddingHelper(self):
batch_size = 5
vocabulary_size = 7
@@ -704,6 +800,21 @@ def testBasicDecoderWithAttentionWrapper(self):
decoder = basic_decoder.BasicDecoder(
cell, sampler, output_layer=output_layer)

def testRightPaddedSequenceAssertion(self):
right_padded_sequence = [[True, True, False, False],
[True, True, True, False]]
left_padded_sequence = [[False, False, True, True],
[False, True, True, True]]

assertion = sampler_py._check_sequence_is_right_padded(
right_padded_sequence, False)
self.evaluate(assertion)

with self.assertRaises(tf.errors.InvalidArgumentError):
assertion = sampler_py._check_sequence_is_right_padded(
left_padded_sequence, False)
self.evaluate(assertion)


if __name__ == "__main__":
tf.test.main()
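A side note on the semantics the tests above exercise: a right-padded mask carries the same information as sequence_length, and the new code converts between the two with a reduce_sum and tf.sequence_mask. A minimal sketch of that round trip (TF 2.x eager assumed):

import numpy as np
import tensorflow as tf

sequence_length = tf.constant([3, 4, 3, 1, 0], dtype=tf.int32)
max_time = 8

# sequence_length -> mask, as in the test above.
mask = tf.sequence_mask(sequence_length, maxlen=max_time)

# mask -> sequence_length, as TrainingSampler.initialize now does internally.
recovered = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
np.testing.assert_array_equal(sequence_length.numpy(), recovered.numpy())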
88 changes: 72 additions & 16 deletions tensorflow_addons/seq2seq/sampler.py
@@ -181,7 +181,8 @@ def __init__(self, time_major=False):
major. If `False` (default), they are assumed to be batch major.

Raises:
ValueError: if `sequence_length` is not a 1D tensor.
ValueError: if `sequence_length` is not a 1D tensor or `mask` is
not a 2D boolean tensor.
"""
self.time_major = time_major
self._batch_size = None
@@ -201,12 +202,13 @@ def sample_ids_shape(self):
def sample_ids_dtype(self):
return tf.int32

def initialize(self, inputs, sequence_length=None):
def initialize(self, inputs, sequence_length=None, mask=None):
"""Initialize the TrainSampler.

Args:
inputs: A (structure of) input tensors.
sequence_length: An int32 vector tensor.
mask: A boolean 2D tensor where True marks valid (non-padded) time
steps; the mask must be right padded.

Returns:
(finished, next_inputs), a tuple of two items. The first item is a
@@ -218,21 +220,48 @@ def initialize(self, inputs, sequence_length=None):
if not self.time_major:
inputs = tf.nest.map_structure(_transpose_batch_time, inputs)

self._batch_size = tf.shape(tf.nest.flatten(inputs)[0])[1]

self.input_tas = tf.nest.map_structure(_unstack_ta, inputs)
if sequence_length is None:
raise ValueError("sequence_length is required for TrainingSampler")
self.sequence_length = tf.convert_to_tensor(
sequence_length, name="sequence_length")
if self.sequence_length.get_shape().ndims != 1:
raise ValueError(
"Expected sequence_length to be vector, but received shape: %s"
% self.sequence_length.get_shape())
if sequence_length is not None and mask is not None:
raise ValueError("sequence_length and mask can't be provided "
"at the same time.")
if sequence_length is not None:
self.sequence_length = tf.convert_to_tensor(
sequence_length, name="sequence_length")
if self.sequence_length.get_shape().ndims != 1:
raise ValueError(
"Expected sequence_length to be vector, but received "
"shape: %s" % self.sequence_length.get_shape())
elif mask is not None:
mask = tf.convert_to_tensor(mask)
if mask.get_shape().ndims != 2:
raise ValueError(
"Expected mask to be a 2D tensor, but received shape: %s" %
mask.get_shape())
if not mask.dtype.is_bool:
raise ValueError(
"Expected mask to be a boolean tensor, but received "
"dtype: %s" % repr(mask.dtype))

axis = 1 if not self.time_major else 0
with tf.control_dependencies(
[_check_sequence_is_right_padded(mask, self.time_major)] # pylint: disable=bad-continuation
):
self.sequence_length = tf.math.reduce_sum(
tf.cast(mask, tf.int32), axis=axis, name="sequence_length")
else:
# As the input tensor has been converted to time major,
# the maximum sequence length should be inferred from
# the first dimension.
max_seq_len = tf.shape(tf.nest.flatten(inputs)[0])[0]
self.sequence_length = tf.fill([self.batch_size],
max_seq_len,
name="sequence_length")

self.zero_inputs = tf.nest.map_structure(
lambda inp: tf.zeros_like(inp[0, :]), inputs)

self._batch_size = tf.size(self.sequence_length)

finished = tf.equal(0, self.sequence_length)
all_finished = tf.reduce_all(finished)
next_inputs = tf.cond(
@@ -305,7 +334,11 @@ def __init__(self,
super(ScheduledEmbeddingTrainingSampler,
self).__init__(time_major=time_major)

def initialize(self, inputs, sequence_length=None, embedding=None):
def initialize(self,
inputs,
sequence_length=None,
mask=None,
embedding=None):
if self.embedding_fn is None:
if embedding is None:
raise ValueError(
@@ -314,7 +347,7 @@ def initialize(self, inputs, sequence_length=None, embedding=None):
self.embedding_fn = (
lambda ids: tf.nn.embedding_lookup(embedding, ids))
return super(ScheduledEmbeddingTrainingSampler, self).initialize(
inputs, sequence_length=sequence_length)
inputs, sequence_length=sequence_length, mask=mask)

def sample(self, time, outputs, state):
del state
@@ -397,7 +430,11 @@ def __init__(self,
super(ScheduledOutputTrainingSampler,
self).__init__(time_major=time_major)

def initialize(self, inputs, sequence_length=None, auxiliary_inputs=None):
def initialize(self,
inputs,
sequence_length=None,
mask=None,
auxiliary_inputs=None):
if auxiliary_inputs is None:
maybe_concatenated_inputs = inputs
else:
@@ -415,7 +452,9 @@ def initialize(self, inputs, sequence_length=None, auxiliary_inputs=None):
self._auxiliary_input_tas = None

return super(ScheduledOutputTrainingSampler, self).initialize(
maybe_concatenated_inputs, sequence_length=sequence_length)
maybe_concatenated_inputs,
sequence_length=sequence_length,
mask=mask)

def sample(self, time, outputs, state):
del state
@@ -759,3 +798,20 @@ def _unstack_ta(inp):
dtype=inp.dtype,
size=tf.shape(inp)[0],
element_shape=inp.get_shape()[1:]).unstack(inp)


def _check_sequence_is_right_padded(mask, time_major):
"""Returns an Assert operation checking that if the mask tensor is right
padded."""
if time_major:
mask = tf.transpose(mask)
sequence_length = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
max_seq_length = tf.shape(mask)[1]
right_padded_mask = tf.sequence_mask(
sequence_length, maxlen=max_seq_length, dtype=tf.bool)
all_equal = tf.math.equal(mask, right_padded_mask)

condition = tf.math.reduce_all(all_equal)
error_message = "The input sequence should be right padded."

return tf.Assert(condition, [error_message])
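To see why this assertion rejects left padding, a short worked example (TF 2.x eager assumed): the mask is reduced to a sequence length, re-expanded into a canonical right-padded mask, and compared elementwise against the original.

import tensorflow as tf

mask = tf.constant([[False, False, True, True]])  # left padded
sequence_length = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)  # [2]
reconstructed = tf.sequence_mask(sequence_length, maxlen=tf.shape(mask)[1])
# reconstructed is [[True, True, False, False]], which differs from mask,
# so the reduce_all condition is False and tf.Assert raises.
print(tf.math.reduce_all(tf.math.equal(mask, reconstructed)))  # False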