From e9a8d75bc470fb410fc54bf78fb4c9ff5d3bdff6 Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Fri, 11 Aug 2017 13:50:47 -0700 Subject: [PATCH] Allow specification of sample_id_shape and sample_id_dtype in seq2seq.BasicDecoder and add a new InferenceHelper. PiperOrigin-RevId: 165019969 --- tensorflow/contrib/seq2seq/__init__.py | 1 + .../python/kernel_tests/basic_decoder_test.py | 166 ++++++++++++++++++ .../seq2seq/python/ops/basic_decoder.py | 9 +- .../contrib/seq2seq/python/ops/helper.py | 109 +++++++++++- 4 files changed, 279 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/seq2seq/__init__.py b/tensorflow/contrib/seq2seq/__init__.py index c4abef268b0bf5..a7279bc339d8a4 100644 --- a/tensorflow/contrib/seq2seq/__init__.py +++ b/tensorflow/contrib/seq2seq/__init__.py @@ -47,6 +47,7 @@ "FinalBeamSearchDecoderOutput", "gather_tree", "GreedyEmbeddingHelper", + "InferenceHelper", "SampleEmbeddingHelper", "ScheduledEmbeddingTrainingHelper", "ScheduledOutputTrainingHelper", diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index c99562555a1a51..2cd2726a6facb3 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -23,14 +23,19 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py from tensorflow.contrib.seq2seq.python.ops import basic_decoder + from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import core as layers_core +from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope +from 
tensorflow.python.ops.distributions import bernoulli +from tensorflow.python.ops.distributions import categorical from tensorflow.python.platform import test # pylint: enable=g-import-not-at-top @@ -500,5 +505,166 @@ def testStepWithScheduledOutputTrainingHelperWithNoSampling( sampling_probability=0.0, use_next_input_layer=True, use_auxiliary_inputs=True) + def testStepWithInferenceHelperCategorical(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size + start_token = 0 + end_token = 6 + + start_inputs = array_ops.one_hot( + np.ones(batch_size) * start_token, + vocabulary_size) + + # The sample function samples categorically from the logits. + sample_fn = lambda x: categorical.Categorical(logits=x).sample() + # The next inputs are a one-hot encoding of the sampled labels. + next_inputs_fn = ( + lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)) + end_fn = lambda sample_ids: math_ops.equal(sample_ids, end_token) + + with self.test_session(use_gpu=True) as sess: + with variable_scope.variable_scope( + "testStepWithInferenceHelper", + initializer=init_ops.constant_initializer(0.01)): + cell = rnn_cell.LSTMCell(vocabulary_size) + helper = helper_py.InferenceHelper( + sample_fn, sample_shape=(), sample_dtype=dtypes.int32, + start_inputs=start_inputs, end_fn=end_fn, + next_inputs_fn=next_inputs_fn) + my_decoder = basic_decoder.BasicDecoder( + cell=cell, + helper=helper, + initial_state=cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size)) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (first_finished, first_inputs, first_state) = my_decoder.initialize() + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), 
first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + sess.run(variables.global_variables_initializer()) + sess_results = sess.run({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = sess_results["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = (sample_ids == end_token) + expected_step_next_inputs = np.zeros((batch_size, vocabulary_size)) + expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0 + self.assertAllEqual(expected_step_finished, + sess_results["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + sess_results["step_next_inputs"]) + + def testStepWithInferenceHelperMultilabel(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size + start_token = 0 + end_token = 6 + + start_inputs = array_ops.one_hot( + np.ones(batch_size) * start_token, + vocabulary_size) + + # The sample function samples independent bernoullis from the logits. 
+ sample_fn = ( + lambda x: bernoulli.Bernoulli(logits=x, dtype=dtypes.bool).sample()) + # The next inputs are the sampled boolean labels cast to float. + next_inputs_fn = math_ops.to_float + end_fn = lambda sample_ids: sample_ids[:, end_token] + + with self.test_session(use_gpu=True) as sess: + with variable_scope.variable_scope( + "testStepWithInferenceHelper", + initializer=init_ops.constant_initializer(0.01)): + cell = rnn_cell.LSTMCell(vocabulary_size) + helper = helper_py.InferenceHelper( + sample_fn, sample_shape=[cell_depth], sample_dtype=dtypes.bool, + start_inputs=start_inputs, end_fn=end_fn, + next_inputs_fn=next_inputs_fn) + my_decoder = basic_decoder.BasicDecoder( + cell=cell, + helper=helper, + initial_state=cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size)) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, cell_depth), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool), + output_dtype) + + (first_finished, first_inputs, first_state) = my_decoder.initialize() + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + 
sess.run(variables.global_variables_initializer()) + sess_results = sess.run({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = sess_results["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = sample_ids[:, end_token] + expected_step_next_inputs = sample_ids.astype(np.float32) + self.assertAllEqual(expected_step_finished, + sess_results["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + sess_results["step_next_inputs"]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index 8ae175b6b59a88..c7c4182f0d9a17 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -23,7 +23,6 @@ from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.contrib.seq2seq.python.ops import helper as helper_py -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as layers_base @@ -54,7 +53,7 @@ def __init__(self, cell, helper, initial_state, output_layer=None): initial_state: A (possibly nested tuple of...) tensors and TensorArrays. The initial state of the RNNCell. output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., - `tf.layers.Dense`. Optional layer to apply to the RNN output prior + `tf.layers.Dense`. Optional layer to apply to the RNN output prior to storing the result or sampling. 
Raises: @@ -100,17 +99,17 @@ def output_size(self): # Return the cell output and the id return BasicDecoderOutput( rnn_output=self._rnn_output_size(), - sample_id=tensor_shape.TensorShape([])) + sample_id=self._helper.sample_ids_shape) @property def output_dtype(self): # Assume the dtype of the cell is the output_size structure # containing the input_state's first component's dtype. - # Return that structure and int32 (the id) + # Return that structure and the sample_ids_dtype from the helper. dtype = nest.flatten(self._initial_state)[0].dtype return BasicDecoderOutput( nest.map_structure(lambda _: dtype, self._rnn_output_size()), - dtypes.int32) + self._helper.sample_ids_dtype) def initialize(self, name=None): """Initialize the decoder. diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index a716dcba738ef4..c1682de041107c 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -26,6 +26,7 @@ from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as layers_base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -45,6 +46,7 @@ "CustomHelper", "ScheduledEmbeddingTrainingHelper", "ScheduledOutputTrainingHelper", + "InferenceHelper", ] _transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access @@ -71,6 +73,22 @@ def batch_size(self): """ raise NotImplementedError("batch_size has not been implemented") + @abc.abstractproperty + def sample_ids_shape(self): + """Shape of tensor returned by `sample`, excluding the batch dimension. + + Returns a `TensorShape`. 
+ """ + raise NotImplementedError("sample_ids_shape has not been implemented") + + @abc.abstractproperty + def sample_ids_dtype(self): + """DType of tensor returned by `sample`. + + Returns a DType. + """ + raise NotImplementedError("sample_ids_dtype has not been implemented") + @abc.abstractmethod def initialize(self, name=None): """Returns `(initial_finished, initial_inputs)`.""" @@ -90,7 +108,8 @@ def next_inputs(self, time, outputs, state, sample_ids, name=None): class CustomHelper(Helper): """Base abstract class that allows the user to customize sampling.""" - def __init__(self, initialize_fn, sample_fn, next_inputs_fn): + def __init__(self, initialize_fn, sample_fn, next_inputs_fn, + sample_ids_shape=None, sample_ids_dtype=None): """Initializer. Args: @@ -100,11 +119,17 @@ def __init__(self, initialize_fn, sample_fn, next_inputs_fn): and emits tensor `sample_ids`. next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)` and emits `(finished, next_inputs, next_state)`. + sample_ids_shape: Either a list of integers, or a 1-D Tensor of type + `int32`, the shape of each value in the `sample_ids` batch. Defaults to + a scalar. + sample_ids_dtype: The dtype of the `sample_ids` tensor. Defaults to int32. 
""" self._initialize_fn = initialize_fn self._sample_fn = sample_fn self._next_inputs_fn = next_inputs_fn self._batch_size = None + self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or []) + self._sample_ids_dtype = sample_ids_dtype or dtypes.int32 @property def batch_size(self): @@ -112,6 +137,14 @@ def batch_size(self): raise ValueError("batch_size accessed before initialize was called") return self._batch_size + @property + def sample_ids_shape(self): + return self._sample_ids_shape + + @property + def sample_ids_dtype(self): + return self._sample_ids_dtype + def initialize(self, name=None): with ops.name_scope(name, "%sInitialize" % type(self).__name__): (finished, next_inputs) = self._initialize_fn() @@ -172,6 +205,14 @@ def __init__(self, inputs, sequence_length, time_major=False, name=None): def batch_size(self): return self._batch_size + @property + def sample_ids_shape(self): + return tensor_shape.TensorShape([]) + + @property + def sample_ids_dtype(self): + return dtypes.int32 + def initialize(self, name=None): with ops.name_scope(name, "TrainingHelperInitialize"): finished = math_ops.equal(0, self._sequence_length) @@ -485,6 +526,14 @@ def __init__(self, embedding, start_tokens, end_token): def batch_size(self): return self._batch_size + @property + def sample_ids_shape(self): + return tensor_shape.TensorShape([]) + + @property + def sample_ids_dtype(self): + return dtypes.int32 + def initialize(self, name=None): finished = array_ops.tile([False], [self._batch_size]) return (finished, self._start_inputs) @@ -562,3 +611,61 @@ def sample(self, time, outputs, state, name=None): sample_ids = sample_id_sampler.sample(seed=self._seed) return sample_ids + + +class InferenceHelper(Helper): + """A helper to use during inference with a custom sampling function.""" + + def __init__(self, sample_fn, sample_shape, sample_dtype, + start_inputs, end_fn, next_inputs_fn=None): + """Initializer. 
+ + Args: + sample_fn: A callable that takes `outputs` and emits tensor `sample_ids`. + sample_shape: Either a list of integers, or a 1-D Tensor of type `int32`, + the shape of each sample in the batch returned by `sample_fn`. + sample_dtype: the dtype of the sample returned by `sample_fn`. + start_inputs: The initial batch of inputs. + end_fn: A callable that takes `sample_ids` and emits a `bool` vector + shaped `[batch_size]` indicating whether each sample is an end token. + next_inputs_fn: (Optional) A callable that takes `sample_ids` and returns + the next batch of inputs. If not provided, `sample_ids` is used as the + next batch of inputs. + """ + self._sample_fn = sample_fn + self._end_fn = end_fn + self._sample_shape = tensor_shape.TensorShape(sample_shape) + self._sample_dtype = sample_dtype + self._next_inputs_fn = next_inputs_fn + self._batch_size = array_ops.shape(start_inputs)[0] + self._start_inputs = ops.convert_to_tensor( + start_inputs, name="start_inputs") + + @property + def batch_size(self): + return self._batch_size + + @property + def sample_ids_shape(self): + return self._sample_shape + + @property + def sample_ids_dtype(self): + return self._sample_dtype + + def initialize(self, name=None): + finished = array_ops.tile([False], [self._batch_size]) + return (finished, self._start_inputs) + + def sample(self, time, outputs, state, name=None): + del time, state # unused by sample + return self._sample_fn(outputs) + + def next_inputs(self, time, outputs, state, sample_ids, name=None): + del time, outputs # unused by next_inputs + if self._next_inputs_fn is None: + next_inputs = sample_ids + else: + next_inputs = self._next_inputs_fn(sample_ids) + finished = self._end_fn(sample_ids) + return (finished, next_inputs, state)