Define seq2seq.Sampler and seq2seq.BasicDecoder in __init__. #603
@qlzh727 @guillaumekln Can you help take a look?
Another common practice is to create two separate decoders, one for training and one for inference, inside the layer's __init__:

class Foo(tf.keras.layers.Layer):
    def __init__(self, num_units, num_outputs):
        super(Foo, self).__init__()
        decoder_cell = tf.keras.layers.LSTMCell(num_units)
        projection_layer = tf.keras.layers.Dense(num_outputs)
        train_sampler = tfa.seq2seq.TrainingSampler()
        self._train_decoder = tfa.seq2seq.BasicDecoder(
            decoder_cell, train_sampler, output_layer=projection_layer)
        # A greedy sampler for inference; its embedding matrix is
        # supplied later, when the decoder is invoked.
        inference_sampler = tfa.seq2seq.GreedyEmbeddingSampler()
        self._inference_decoder = tfa.seq2seq.BasicDecoder(
            decoder_cell, inference_sampler, output_layer=projection_layer)

    def call(self, inputs, training=None):
        if training is None:
            training = tf.keras.backend.learning_phase()
        if training:
            ...  # Training: decode with teacher forcing via self._train_decoder.
        else:
            ...  # Inference: decode greedily via self._inference_decoder.
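For concreteness, here is a minimal sketch of what those two elided branches might look like, following the tfa.seq2seq.BasicDecoder call API; the sequence_length, embedding_matrix, start_tokens, and end_token arguments are hypothetical additions to call():

    def call(self, inputs, sequence_length=None, embedding_matrix=None,
             start_tokens=None, end_token=None, initial_state=None, training=None):
        if training is None:
            training = tf.keras.backend.learning_phase()
        if training:
            # Teacher forcing: `inputs` holds the ground-truth embeddings,
            # shaped [batch, time, dim].
            outputs, _, _ = self._train_decoder(
                inputs, initial_state=initial_state,
                sequence_length=sequence_length)
        else:
            # Greedy decoding: each predicted id is looked up in
            # `embedding_matrix` and fed back as the next input.
            outputs, _, _ = self._inference_decoder(
                embedding_matrix, initial_state=initial_state,
                start_tokens=start_tokens, end_token=end_token)
        return outputs.rnn_output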
@kazemnejad Another question: how can we debug in eager mode? If we place the sampler in __init__, we cannot reset the decoder inputs for the sampler (teacher forcing).
The decoder itself re-initializes the sampler upon each invocation. It's also worth mentioning that neither of
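A small sketch of that behavior, with hypothetical shapes (batch 4, 7 steps, 16-dim inputs):

import tensorflow as tf
import tensorflow_addons as tfa

cell = tf.keras.layers.LSTMCell(32)
decoder = tfa.seq2seq.BasicDecoder(
    cell, tfa.seq2seq.TrainingSampler(),
    output_layer=tf.keras.layers.Dense(10))
state = cell.get_initial_state(batch_size=4, dtype=tf.float32)

# First invocation: the decoder calls sampler.initialize() with these inputs.
out1, _, _ = decoder(tf.random.normal([4, 7, 16]), initial_state=state,
                     sequence_length=tf.fill([4], 7))

# Second invocation with different inputs: the sampler is initialized again,
# so the teacher-forcing inputs are effectively reset on every call.
out2, _, _ = decoder(tf.random.normal([4, 7, 16]), initial_state=state,
                     sequence_length=tf.fill([4], 7))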
I think the example from @kazemnejad should work nicely. In general, we recommend users create the instances they need in __init__(), since that runs in the correct context (eager in v2, graph in v1). On the other hand, call() will almost certainly execute in a graph context, which might lead to unexpected behavior or errors. By the way, if you want to debug in eager mode, you could try model.compile(run_eagerly=True), which will cause the whole call() body to run in an eager context. Or you can directly invoke your Foo with numpy data.
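For example (the model and shapes below are hypothetical, and Foo refers to the layer defined earlier in the thread):

import numpy as np
import tensorflow as tf

# Option 1: make fit()/evaluate() run every call() body eagerly.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(optimizer="adam", loss="mse", run_eagerly=True)

# Option 2: invoke the layer directly on concrete data. Outside of a
# tf.function this executes eagerly in TF2, so breakpoints and print()
# inside call() behave as expected.
layer = Foo(num_units=32, num_outputs=10)
_ = layer(np.random.rand(4, 7, 16).astype("float32"), training=True)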
@qlzh727 @kazemnejad Thank you for your advice! It works now. |
Currently, it's not convenient to define both a seq2seq.Sampler and a seq2seq.BasicDecoder in __init__, because we usually need to switch between different samplers for training and testing, but the training flag is unavailable in __init__.
A workaround is to define a single sampler that can be used during both training and inference, controlled by a training flag passed in at call time (see the sketch below).
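A minimal sketch of such a sampler. All names here (SwitchableSampler, embedding_fn, start_tokens, end_token, teacher_forcing) are hypothetical, and the per-call flag relies on the assumption that BasicDecoder forwards extra call kwargs to the sampler's initialize(), e.g. decoder(inputs, sequence_length=..., teacher_forcing=False); treat that plumbing as an assumption to verify against your tensorflow_addons version.

import tensorflow as tf
import tensorflow_addons as tfa

class SwitchableSampler(tfa.seq2seq.Sampler):
    # Hypothetical sampler: teacher forcing when teacher_forcing=True is
    # passed to initialize(), greedy feed-back decoding otherwise.

    def __init__(self, embedding_fn, start_tokens, end_token):
        self._train = tfa.seq2seq.TrainingSampler()
        self._infer = tfa.seq2seq.GreedyEmbeddingSampler(embedding_fn)
        self._start_tokens = start_tokens
        self._end_token = end_token
        self._active = self._train

    @property
    def batch_size(self):
        return self._active.batch_size

    @property
    def sample_ids_dtype(self):
        return self._active.sample_ids_dtype

    @property
    def sample_ids_shape(self):
        return self._active.sample_ids_shape

    def initialize(self, inputs, sequence_length=None, teacher_forcing=True):
        if teacher_forcing:
            self._active = self._train
            return self._train.initialize(inputs, sequence_length=sequence_length)
        self._active = self._infer
        # The embedding argument is ignored here because embedding_fn
        # was already provided in __init__.
        return self._infer.initialize(
            None, start_tokens=self._start_tokens, end_token=self._end_token)

    def sample(self, time, outputs, state):
        return self._active.sample(time, outputs, state)

    def next_inputs(self, time, outputs, state, sample_ids):
        return self._active.next_inputs(time, outputs, state, sample_ids)

Compared with the two-decoder pattern above, this keeps a single decoder object at the cost of maintaining a small custom Sampler subclass.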