This repository was archived by the owner on Mar 11, 2026. It is now read-only.

Define seq2seq.Sampler and seq2seq.BasicDecoder in __init__. #603

@npuichigo

Description

Currently, it is inconvenient to create both the seq2seq.Sampler and the seq2seq.BasicDecoder in __init__. We usually need different samplers for training and inference, but the training flag is only passed to call, not to __init__, so the sampler (and therefore the decoder) has to be built inside call:

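```python
import tensorflow as tf
import tensorflow_addons as tfa


class Foo(tf.keras.layers.Layer):
    def __init__(self, num_units, num_outputs, **kwargs):
        super(Foo, self).__init__(**kwargs)
        # The cell and projection do not depend on `training`, so they can
        # be built once here and reuse their weights across calls.
        self.decoder_cell = tf.keras.layers.LSTMCell(num_units)
        self.projection_layer = tf.keras.layers.Dense(num_outputs)

    def call(self, decoder_emb_inp, encoder_state, decoder_lengths, training=None):
        # The sampler depends on `training`, which is only known here.
        if training:
            sampler = tfa.seq2seq.TrainingSampler()
        else:
            # InferenceSampler also expects sample_fn, sample_shape,
            # sample_dtype and end_fn arguments (omitted here), and its
            # initialize() takes different inputs than TrainingSampler's.
            sampler = tfa.seq2seq.InferenceSampler()

        # Decoder: must be rebuilt on every call because it wraps the sampler.
        decoder = tfa.seq2seq.BasicDecoder(
            self.decoder_cell, sampler, output_layer=self.projection_layer)

        # BasicDecoder returns (outputs, final_state, final_sequence_lengths).
        outputs, _, _ = decoder(
            decoder_emb_inp,
            initial_state=encoder_state,
            sequence_length=decoder_lengths)

        return outputs
```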

A workaround is to define a single sampler that can be used for both training and inference and that receives the training flag at call time, so the sampler and the decoder can both be created once in __init__.
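A minimal, framework-agnostic sketch of that workaround, written in plain Python rather than TensorFlow so it runs on its own. SwitchableSampler, Decoder and toy_cell are hypothetical names, not part of tfa.seq2seq; the initialize/sample/next_inputs split only loosely mirrors the tfa Sampler interface. The sampler and decoder are built once, and the training flag passed at call time selects teacher forcing (feed ground-truth inputs) or greedy decoding (feed back the model's own prediction).

```python
from typing import Callable, List, Optional, Sequence, Tuple


def argmax(values: Sequence[float]) -> int:
    return max(range(len(values)), key=lambda i: values[i])


class SwitchableSampler:
    """Hypothetical sampler supporting both training and inference.

    The mode is chosen in initialize(), which runs at call time, so the
    sampler object itself can be constructed once up front.
    """

    def __init__(self, end_token: int, max_steps: int = 20):
        self.end_token = end_token
        self.max_steps = max_steps

    def initialize(self, training: bool, start_token: int,
                   targets: Optional[Sequence[int]] = None) -> int:
        self.training = bool(training)
        if self.training:
            if not targets:
                raise ValueError("training mode needs ground-truth decoder inputs")
            # Teacher forcing: step through the ground-truth inputs.
            self.inputs = list(targets)
            return self.inputs[0]
        return start_token

    def sample(self, logits: Sequence[float]) -> int:
        # Greedy pick of the predicted token id (used in both modes).
        return argmax(logits)

    def next_inputs(self, time: int, sample_id: int) -> Tuple[bool, int]:
        """Return (finished, next_input) after step `time`."""
        if self.training:
            # Ignore the prediction; stop when the ground truth runs out.
            finished = time + 1 >= len(self.inputs)
            next_input = self.inputs[min(time + 1, len(self.inputs) - 1)]
        else:
            # Feed the prediction back; stop at end_token or max_steps.
            finished = sample_id == self.end_token or time + 1 >= self.max_steps
            next_input = sample_id
        return finished, next_input


class Decoder:
    """Hypothetical decoder built once; only `training` varies per call."""

    def __init__(self, cell: Callable[[int], List[float]], sampler: SwitchableSampler):
        self.cell = cell
        self.sampler = sampler

    def __call__(self, training: bool, start_token: int,
                 targets: Optional[Sequence[int]] = None) -> List[int]:
        next_input = self.sampler.initialize(training, start_token, targets)
        outputs: List[int] = []
        time, finished = 0, False
        while not finished:
            sample_id = self.sampler.sample(self.cell(next_input))
            outputs.append(sample_id)
            finished, next_input = self.sampler.next_inputs(time, sample_id)
            time += 1
        return outputs


VOCAB, START, END = 4, 0, 3


def toy_cell(token: int) -> List[float]:
    # Toy "model": always predicts the id (token + 1) % VOCAB.
    return [1.0 if i == (token + 1) % VOCAB else 0.0 for i in range(VOCAB)]


decoder = Decoder(toy_cell, SwitchableSampler(end_token=END))  # built once
print(decoder(training=False, start_token=START))                    # [1, 2, 3]
print(decoder(training=True, start_token=START, targets=[0, 2, 2]))  # [1, 3, 3]
```

In inference mode the toy model stops as soon as it emits END; in training mode it keeps consuming the ground-truth inputs even after predicting END, which is the usual teacher-forcing behavior.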
