# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A basic decoder that may sample to generate the next input."""
import collections
import tensorflow as tf
from tensorflow_addons.seq2seq import decoder
from tensorflow_addons.seq2seq import sampler as sampler_py
from tensorflow_addons.utils import keras_utils
from typeguard import typechecked
from typing import Optional
class BasicDecoderOutput(
    collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))
):
    """The output emitted by one `tfa.seq2seq.BasicDecoder` step.

    Attributes:
      rnn_output: The step's output tensor. When `tfa.seq2seq.BasicDecoder`
        was constructed with an `output_layer`, this is that layer's output;
        otherwise it is the raw RNN cell output.
      sample_id: Token IDs produced for this step by the `sampler` instance
        that was passed to `tfa.seq2seq.BasicDecoder`.
    """
class BasicDecoder(decoder.BaseDecoder):
    """Basic sampling decoder for training and inference.

    The `tfa.seq2seq.Sampler` instance given at construction time is in charge
    of drawing from the output distribution and of producing the input fed to
    the next decoding step. The decoding loop itself lives in the decoder's
    `__call__` method.

    Example using `tfa.seq2seq.TrainingSampler` for training:

    >>> batch_size = 4
    >>> max_time = 7
    >>> hidden_size = 32
    >>> embedding_size = 48
    >>> input_vocab_size = 128
    >>> output_vocab_size = 64
    >>>
    >>> embedding_layer = tf.keras.layers.Embedding(input_vocab_size, embedding_size)
    >>> decoder_cell = tf.keras.layers.LSTMCell(hidden_size)
    >>> sampler = tfa.seq2seq.TrainingSampler()
    >>> output_layer = tf.keras.layers.Dense(output_vocab_size)
    >>>
    >>> decoder = tfa.seq2seq.BasicDecoder(decoder_cell, sampler, output_layer)
    >>>
    >>> input_ids = tf.random.uniform(
    ...     [batch_size, max_time], maxval=input_vocab_size, dtype=tf.int64)
    >>> input_lengths = tf.fill([batch_size], max_time)
    >>> input_tensors = embedding_layer(input_ids)
    >>> initial_state = decoder_cell.get_initial_state(input_tensors)
    >>>
    >>> output, state, lengths = decoder(
    ...     input_tensors, sequence_length=input_lengths, initial_state=initial_state)
    >>>
    >>> logits = output.rnn_output
    >>> logits.shape
    TensorShape([4, 7, 64])

    Example using `tfa.seq2seq.GreedyEmbeddingSampler` for inference:

    >>> sampler = tfa.seq2seq.GreedyEmbeddingSampler(embedding_layer)
    >>> decoder = tfa.seq2seq.BasicDecoder(
    ...     decoder_cell, sampler, output_layer, maximum_iterations=10)
    >>>
    >>> initial_state = decoder_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    >>> start_tokens = tf.fill([batch_size], 1)
    >>> end_token = 2
    >>>
    >>> output, state, lengths = decoder(
    ...     None, start_tokens=start_tokens, end_token=end_token, initial_state=initial_state)
    >>>
    >>> output.sample_id.shape
    TensorShape([4, 10])
    """

    @typechecked
    def __init__(
        self,
        cell: tf.keras.layers.Layer,
        sampler: sampler_py.Sampler,
        output_layer: Optional[tf.keras.layers.Layer] = None,
        **kwargs,
    ):
        """Initialize BasicDecoder.

        Args:
          cell: A layer implementing the `tf.keras.layers.AbstractRNNCell`
            interface.
          sampler: A `tfa.seq2seq.Sampler` instance.
          output_layer: (Optional) A `tf.keras.layers.Layer` instance, e.g.
            `tf.keras.layers.Dense`, applied to the RNN output before the
            result is stored or sampled from.
          **kwargs: Other keyword arguments of `tfa.seq2seq.BaseDecoder`.
        """
        # Fail fast when `cell` does not quack like an RNN cell.
        keras_utils.assert_like_rnncell("cell", cell)
        # Attributes must be set before the base constructor runs.
        self.cell = cell
        self.sampler = sampler
        self.output_layer = output_layer
        super().__init__(**kwargs)

    def initialize(self, inputs, initial_state=None, **kwargs):
        """Initialize the decoder."""
        # Remember the dtype of the first state component; `output_dtype`
        # reports it as the dtype of every rnn_output entry.
        flat_state = tf.nest.flatten(initial_state)
        self._cell_dtype = flat_state[0].dtype
        return self.sampler.initialize(inputs, **kwargs) + (initial_state,)

    @property
    def batch_size(self):
        # The sampler is the source of truth for the batch size.
        return self.sampler.batch_size

    def _rnn_output_size(self):
        """Per-step output shape (without batch) after `output_layer`, if any."""
        cell_size = tf.TensorShape(self.cell.output_size)
        if self.output_layer is None:
            return cell_size
        # Prefix each output_size entry with an unknown batch dimension,
        # feed it through the layer's compute_output_shape, then strip the
        # batch dimension back off to recover the per-step output size.
        with_batch = tf.nest.map_structure(
            lambda shape: tf.TensorShape([None]).concatenate(shape), cell_size
        )
        projected = self.output_layer.compute_output_shape(with_batch)
        return tf.nest.map_structure(lambda shape: shape[1:], projected)

    @property
    def output_size(self):
        # Pair the (possibly projected) cell output shape with the
        # sampler's sample-id shape.
        return BasicDecoderOutput(self._rnn_output_size(), self.sampler.sample_ids_shape)

    @property
    def output_dtype(self):
        # Every rnn_output entry shares the dtype captured from the initial
        # state in `initialize`; sample ids use the sampler's dtype.
        cell_dtype = self._cell_dtype
        return BasicDecoderOutput(
            rnn_output=tf.nest.map_structure(
                lambda _: cell_dtype, self._rnn_output_size()
            ),
            sample_id=self.sampler.sample_ids_dtype,
        )

    def step(self, time, inputs, state, training=None):
        """Perform a decoding step.

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: A (structure of) state tensors and TensorArrays.
          training: Python boolean.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        raw_outputs, raw_state = self.cell(inputs, state, training=training)
        # Re-pack the new state into the same structure as the incoming one.
        new_state = tf.nest.pack_sequence_as(state, tf.nest.flatten(raw_state))
        if self.output_layer is not None:
            raw_outputs = self.output_layer(raw_outputs)
        sample_ids = self.sampler.sample(
            time=time, outputs=raw_outputs, state=new_state
        )
        # The sampler decides which sequences are finished and what is fed
        # into the next step.
        (finished, next_inputs, next_state) = self.sampler.next_inputs(
            time=time, outputs=raw_outputs, state=new_state, sample_ids=sample_ids
        )
        return (BasicDecoderOutput(raw_outputs, sample_ids), next_state, next_inputs, finished)