Skip to content

Commit

Permalink
Allow specification of sample_id_shape and sample_id_dtype in seq2seq…
Browse files Browse the repository at this point in the history
….BasicDecoder and add a new InferenceHelper.

PiperOrigin-RevId: 165019969
  • Loading branch information
adarob authored and tensorflower-gardener committed Aug 11, 2017
1 parent 37c54be commit e9a8d75
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 6 deletions.
1 change: 1 addition & 0 deletions tensorflow/contrib/seq2seq/__init__.py
Expand Up @@ -47,6 +47,7 @@
"FinalBeamSearchDecoderOutput",
"gather_tree",
"GreedyEmbeddingHelper",
"InferenceHelper",
"SampleEmbeddingHelper",
"ScheduledEmbeddingTrainingHelper",
"ScheduledOutputTrainingHelper",
Expand Down
166 changes: 166 additions & 0 deletions tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py
Expand Up @@ -23,14 +23,19 @@

from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top

Expand Down Expand Up @@ -500,5 +505,166 @@ def testStepWithScheduledOutputTrainingHelperWithNoSampling(
sampling_probability=0.0, use_next_input_layer=True,
use_auxiliary_inputs=True)

def testStepWithInferenceHelperCategorical(self):
batch_size = 5
vocabulary_size = 7
cell_depth = vocabulary_size
start_token = 0
end_token = 6

start_inputs = array_ops.one_hot(
np.ones(batch_size) * start_token,
vocabulary_size)

# The sample function samples categorically from the logits.
sample_fn = lambda x: categorical.Categorical(logits=x).sample()
# The next inputs are a one-hot encoding of the sampled labels.
next_inputs_fn = (
lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32))
end_fn = lambda sample_ids: math_ops.equal(sample_ids, end_token)

with self.test_session(use_gpu=True) as sess:
with variable_scope.variable_scope(
"testStepWithInferenceHelper",
initializer=init_ops.constant_initializer(0.01)):
cell = rnn_cell.LSTMCell(vocabulary_size)
helper = helper_py.InferenceHelper(
sample_fn, sample_shape=(), sample_dtype=dtypes.int32,
start_inputs=start_inputs, end_fn=end_fn,
next_inputs_fn=next_inputs_fn)
my_decoder = basic_decoder.BasicDecoder(
cell=cell,
helper=helper,
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))
output_size = my_decoder.output_size
output_dtype = my_decoder.output_dtype
self.assertEqual(
basic_decoder.BasicDecoderOutput(cell_depth,
tensor_shape.TensorShape([])),
output_size)
self.assertEqual(
basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32),
output_dtype)

(first_finished, first_inputs, first_state) = my_decoder.initialize()
(step_outputs, step_state, step_next_inputs,
step_finished) = my_decoder.step(
constant_op.constant(0), first_inputs, first_state)
batch_size_t = my_decoder.batch_size

self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
self.assertTrue(
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
self.assertEqual((batch_size,), step_outputs[1].get_shape())
self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

sess.run(variables.global_variables_initializer())
sess_results = sess.run({
"batch_size": batch_size_t,
"first_finished": first_finished,
"first_inputs": first_inputs,
"first_state": first_state,
"step_outputs": step_outputs,
"step_state": step_state,
"step_next_inputs": step_next_inputs,
"step_finished": step_finished
})

sample_ids = sess_results["step_outputs"].sample_id
self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
expected_step_finished = (sample_ids == end_token)
expected_step_next_inputs = np.zeros((batch_size, vocabulary_size))
expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0
self.assertAllEqual(expected_step_finished,
sess_results["step_finished"])
self.assertAllEqual(expected_step_next_inputs,
sess_results["step_next_inputs"])

def testStepWithInferenceHelperMultilabel(self):
batch_size = 5
vocabulary_size = 7
cell_depth = vocabulary_size
start_token = 0
end_token = 6

start_inputs = array_ops.one_hot(
np.ones(batch_size) * start_token,
vocabulary_size)

# The sample function samples independent bernoullis from the logits.
sample_fn = (
lambda x: bernoulli.Bernoulli(logits=x, dtype=dtypes.bool).sample())
# The next inputs are a one-hot encoding of the sampled labels.
next_inputs_fn = math_ops.to_float
end_fn = lambda sample_ids: sample_ids[:, end_token]

with self.test_session(use_gpu=True) as sess:
with variable_scope.variable_scope(
"testStepWithInferenceHelper",
initializer=init_ops.constant_initializer(0.01)):
cell = rnn_cell.LSTMCell(vocabulary_size)
helper = helper_py.InferenceHelper(
sample_fn, sample_shape=[cell_depth], sample_dtype=dtypes.bool,
start_inputs=start_inputs, end_fn=end_fn,
next_inputs_fn=next_inputs_fn)
my_decoder = basic_decoder.BasicDecoder(
cell=cell,
helper=helper,
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))
output_size = my_decoder.output_size
output_dtype = my_decoder.output_dtype
self.assertEqual(
basic_decoder.BasicDecoderOutput(cell_depth, cell_depth),
output_size)
self.assertEqual(
basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool),
output_dtype)

(first_finished, first_inputs, first_state) = my_decoder.initialize()
(step_outputs, step_state, step_next_inputs,
step_finished) = my_decoder.step(
constant_op.constant(0), first_inputs, first_state)
batch_size_t = my_decoder.batch_size

self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple))
self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple))
self.assertTrue(
isinstance(step_outputs, basic_decoder.BasicDecoderOutput))
self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape())
self.assertEqual((batch_size, cell_depth), step_outputs[1].get_shape())
self.assertEqual((batch_size, cell_depth), first_state[0].get_shape())
self.assertEqual((batch_size, cell_depth), first_state[1].get_shape())
self.assertEqual((batch_size, cell_depth), step_state[0].get_shape())
self.assertEqual((batch_size, cell_depth), step_state[1].get_shape())

sess.run(variables.global_variables_initializer())
sess_results = sess.run({
"batch_size": batch_size_t,
"first_finished": first_finished,
"first_inputs": first_inputs,
"first_state": first_state,
"step_outputs": step_outputs,
"step_state": step_state,
"step_next_inputs": step_next_inputs,
"step_finished": step_finished
})

sample_ids = sess_results["step_outputs"].sample_id
self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
expected_step_finished = sample_ids[:, end_token]
expected_step_next_inputs = sample_ids.astype(np.float32)
self.assertAllEqual(expected_step_finished,
sess_results["step_finished"])
self.assertAllEqual(expected_step_next_inputs,
sess_results["step_next_inputs"])


if __name__ == "__main__":
test.main()
9 changes: 4 additions & 5 deletions tensorflow/contrib/seq2seq/python/ops/basic_decoder.py
Expand Up @@ -23,7 +23,6 @@

from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_base
Expand Down Expand Up @@ -54,7 +53,7 @@ def __init__(self, cell, helper, initial_state, output_layer=None):
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
The initial state of the RNNCell.
output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
to storing the result or sampling.
Raises:
Expand Down Expand Up @@ -100,17 +99,17 @@ def output_size(self):
# Return the cell output and the id
return BasicDecoderOutput(
rnn_output=self._rnn_output_size(),
sample_id=tensor_shape.TensorShape([]))
sample_id=self._helper.sample_ids_shape)

@property
def output_dtype(self):
# Assume the dtype of the cell is the output_size structure
# containing the input_state's first component's dtype.
# Return that structure and int32 (the id)
# Return that structure and the sample_ids_dtype from the helper.
dtype = nest.flatten(self._initial_state)[0].dtype
return BasicDecoderOutput(
nest.map_structure(lambda _: dtype, self._rnn_output_size()),
dtypes.int32)
self._helper.sample_ids_dtype)

def initialize(self, name=None):
"""Initialize the decoder.
Expand Down

0 comments on commit e9a8d75

Please sign in to comment.