Currently, it's not convenient to create both the seq2seq.Sampler and the seq2seq.BasicDecoder in __init__: we usually need to switch between different samplers for training and inference, but the training flag is only passed to call and is not available in __init__. As a result, the sampler and the decoder end up being built inside call:
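import tensorflow as tf
import tensorflow_addons as tfa

# num_units, num_outputs, decoder_emb_inp, encoder_state and decoder_lengths
# are placeholders that are assumed to be defined elsewhere.
class Foo(tf.keras.layers.Layer):
    def __init__(self):
        super(Foo, self).__init__()

    def call(self, inputs, training=None):
        # The sampler depends on the training flag, which only exists here.
        # Sampler constructor arguments are omitted for brevity.
        if training:
            sampler = tfa.seq2seq.TrainingSampler()
        else:
            sampler = tfa.seq2seq.InferenceSampler()
        # Decoder: has to be built here too, because it needs the sampler.
        decoder_cell = tf.keras.layers.LSTMCell(num_units)
        projection_layer = tf.keras.layers.Dense(num_outputs)
        decoder = tfa.seq2seq.BasicDecoder(
            decoder_cell, sampler, output_layer=projection_layer)
        outputs, _, _ = decoder(
            decoder_emb_inp,
            initial_state=encoder_state,
            sequence_length=decoder_lengths)
        return outputs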
A workaround is to define a single sampler that can be used for both training and inference, with a training flag in its call method.
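To make the idea concrete, here is a rough, dependency-free sketch of such a sampler. Everything in it is an assumption for illustration: SwitchableSampler and set_training are hypothetical names, not part of tfa.seq2seq, and the flag is set through a separate method rather than passed to a call method. It only forwards the methods and properties that the tfa.seq2seq.Sampler interface exposes (initialize, sample, next_inputs, batch_size, sample_ids_shape, sample_ids_dtype) to whichever wrapped sampler is active. To be accepted by BasicDecoder it would presumably also have to subclass tfa.seq2seq.Sampler, and the flag is assumed to be a plain Python bool (or None), not a tensor.

```python
class SwitchableSampler:
    """Hypothetical wrapper that delegates to a training or inference sampler."""

    def __init__(self, train_sampler, infer_sampler):
        self.train_sampler = train_sampler
        self.infer_sampler = infer_sampler
        self.training = False  # default to inference behaviour

    def set_training(self, training):
        # None (Keras' "unspecified" value) is treated as inference.
        self.training = bool(training)

    @property
    def active(self):
        # The wrapped sampler that currently receives all calls.
        return self.train_sampler if self.training else self.infer_sampler

    @property
    def batch_size(self):
        return self.active.batch_size

    @property
    def sample_ids_shape(self):
        return self.active.sample_ids_shape

    @property
    def sample_ids_dtype(self):
        return self.active.sample_ids_dtype

    def initialize(self, inputs, **kwargs):
        return self.active.initialize(inputs, **kwargs)

    def sample(self, time, outputs, state):
        return self.active.sample(time, outputs, state)

    def next_inputs(self, time, outputs, state, sample_ids):
        return self.active.next_inputs(time, outputs, state, sample_ids)
```

With something along these lines, the wrapper and the decoder could both be created once in __init__, and call would only need to run set_training(training) before invoking the decoder.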