Define seq2seq.Sampler and seq2seq.BasicDecoder in __init__. #603

Closed
npuichigo opened this issue Oct 17, 2019 · 6 comments

npuichigo commented Oct 17, 2019

Currently, it's not convenient to define both seq2seq.Sampler and seq2seq.BasicDecoder in __init__, because we usually need to switch between different samplers for training and testing, but the training flag is unavailable in __init__.

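import tensorflow as tf
import tensorflow_addons as tfa

# num_units, num_outputs, decoder_emb_inp, encoder_state and decoder_lengths
# are assumed to be defined elsewhere.
class Foo(tf.keras.layers.Layer):
  def __init__(self):
    super(Foo, self).__init__()

  def call(self, inputs, training=None):
    # The sampler is chosen here because `training` is only known in call().
    if training:
      sampler = tfa.seq2seq.TrainingSampler()
    else:
      # Required constructor arguments omitted for brevity.
      sampler = tfa.seq2seq.InferenceSampler()

    # Decoder
    decoder_cell = tf.keras.layers.LSTMCell(num_units)
    projection_layer = tf.keras.layers.Dense(num_outputs)
    decoder = tfa.seq2seq.BasicDecoder(
        decoder_cell, sampler, output_layer=projection_layer)

    outputs, _, _ = decoder(
        decoder_emb_inp,
        initial_state=encoder_state,
        sequence_length=decoder_lengths)

    return outputs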

A workaround is to define a sampler that can be used during both training and inference, with a training flag in its call method.
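To make the workaround above concrete, here is a minimal, library-free sketch of the dispatch idea only. `SwitchableSampler`, `set_training`, and the two stand-in sampler classes are hypothetical names, not part of tfa; a real version would presumably need to subclass tfa.seq2seq.Sampler and implement its full interface, which this sketch does not attempt.

```python
class FakeTrainingSampler:
    """Stand-in for a teacher-forcing sampler (hypothetical, not tfa)."""

    def sample(self, step_output):
        return "teacher-forced"


class FakeInferenceSampler:
    """Stand-in for an inference-time sampler (hypothetical, not tfa)."""

    def sample(self, step_output):
        return "sampled-from-output"


class SwitchableSampler:
    """Forwards every sampler call to whichever wrapped sampler is active."""

    def __init__(self, train_sampler, inference_sampler):
        self._train_sampler = train_sampler
        self._inference_sampler = inference_sampler
        self._active = train_sampler

    def set_training(self, training):
        # Pick the wrapped sampler based on the training flag.
        self._active = (
            self._train_sampler if training else self._inference_sampler)

    def __getattr__(self, name):
        # Only reached for attributes not defined on the wrapper itself,
        # so sampler methods are delegated to the active sampler.
        return getattr(self._active, name)


sampler = SwitchableSampler(FakeTrainingSampler(), FakeInferenceSampler())
sampler.set_training(True)
train_result = sampler.sample(None)
sampler.set_training(False)
inference_result = sampler.sample(None)
```

With something along these lines, the wrapper could be created once in __init__ and switched inside call(), where the training flag is known.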

npuichigo (Author) commented Oct 18, 2019

@qlzh727 @guillaumekln Can you help to take a look?

kazemnejad (Contributor) commented Oct 18, 2019

Another common practice is to create two separate decoders, one for training and one for inference, inside the __init__(..) method, and let them share the decoder_cell and projection_layer. Here is an example:

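import tensorflow as tf
import tensorflow_addons as tfa

# num_units and num_outputs are assumed to be defined elsewhere.
class Foo(tf.keras.layers.Layer):
  def __init__(self):
    super(Foo, self).__init__()

    # The cell and projection layer are shared by both decoders.
    decoder_cell = tf.keras.layers.LSTMCell(num_units)
    projection_layer = tf.keras.layers.Dense(num_outputs)

    train_sampler = tfa.seq2seq.TrainingSampler()
    self._train_decoder = tfa.seq2seq.BasicDecoder(
        decoder_cell, train_sampler, output_layer=projection_layer)

    # Required constructor arguments omitted for brevity.
    inference_sampler = tfa.seq2seq.InferenceSampler()
    self._inference_decoder = tfa.seq2seq.BasicDecoder(
        decoder_cell, inference_sampler, output_layer=projection_layer)

  def call(self, inputs, training=None):
    if training is None:
      training = tf.keras.backend.learning_phase()

    if training:
      pass  # Training stuff: run self._train_decoder
    else:
      pass  # Inference stuff: run self._inference_decoder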

npuichigo (Author) commented Oct 21, 2019

@kazemnejad Another question: how can we debug in eager mode? If we create the sampler in __init__, we cannot reset the decoder inputs for the sampler (teacher forcing).

kazemnejad (Contributor) commented Oct 21, 2019

The decoder itself re-initializes the sampler on each invocation. It's also worth mentioning that neither the Decoder nor the Sampler accepts any input data or tensors in its __init__ method; instead, they receive their inputs in their call method.
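The point above can be illustrated with a toy model in plain Python. `ToySampler` and `ToyDecoder` are hypothetical stand-ins, not the tfa classes; the sketch only assumes the behavior described in this comment, namely that the objects are built without data and the sampler is re-initialized with fresh inputs on every decoder call.

```python
class ToySampler:
    """Stand-in sampler: holds no data until initialize() is called."""

    def __init__(self):
        self.inputs = None

    def initialize(self, inputs):
        # Replace whatever inputs were stored by a previous call.
        self.inputs = list(inputs)


class ToyDecoder:
    """Stand-in decoder: built once, fed new inputs on every call."""

    def __init__(self, sampler):
        self.sampler = sampler  # no tensors or data are passed here

    def __call__(self, inputs):
        # Re-initialize the sampler with the inputs of this invocation.
        self.sampler.initialize(inputs)
        return [x * 2 for x in self.sampler.inputs]


decoder = ToyDecoder(ToySampler())  # e.g. created once in __init__
first = decoder([1, 2, 3])          # first batch of decoder inputs
second = decoder([10, 20])          # a later call with different inputs
```

Because the inputs arrive at call time, creating the decoder and sampler once in __init__ does not tie them to a particular batch.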

qlzh727 (Member) commented Oct 21, 2019

I think the example from @kazemnejad should work nicely. In general, we recommend that users create the instances they need in __init__(), since it should run in the correct context (eager in v2, graph in v1). On the other hand, call() will almost certainly be executed in a graph context, which might lead to unexpected behavior or errors.

By the way, if you want to debug in eager mode, you could try model.compile(run_eagerly=True), which causes the whole call() body to run in an eager context. Alternatively, you can invoke your Foo directly with NumPy data.
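Both debugging options might look roughly like the sketch below, assuming TensorFlow 2.x. The `Foo` here is a trivial placeholder layer rather than the seq2seq layer from this thread, and the optimizer and loss are arbitrary choices made only so that compile() has something to work with.

```python
import numpy as np
import tensorflow as tf


class Foo(tf.keras.layers.Layer):
    """Placeholder layer that records whether call() ran eagerly."""

    def __init__(self):
        super().__init__()
        self.eager_flags = []

    def call(self, inputs, training=None):
        # In eager mode, breakpoints and print() work here as usual.
        self.eager_flags.append(tf.executing_eagerly())
        return inputs * 2.0


# Option 1: invoke the layer directly with NumPy data.
layer = Foo()
out = layer(np.ones((2, 3), dtype=np.float32))

# Option 2: ask Keras to run call() eagerly during fit/evaluate/predict.
model = tf.keras.Sequential([Foo()])
model.compile(optimizer="sgd", loss="mse", run_eagerly=True)
```

Calling the layer directly is usually the quicker way to step through call() with a debugger, since no training loop is involved.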

npuichigo (Author) commented Oct 22, 2019

@qlzh727 @kazemnejad Thank you for your advice! It works now.
