
Question: Using TrainingHelper vs GreedyEmbeddingHelper during training. #3

Closed
okuchaiev opened this issue Jul 13, 2017 · 15 comments

@okuchaiev

We are able to train a simple seq2seq model (our own code, but using the TF v1.2 contrib.seq2seq APIs) to perform simple tasks like string reversal (BLEU 100%).

However, on De-En data our inference gets stuck emitting one or two words. During training the model samples (using TrainingHelper) almost perfect translations, BUT inference keeps emitting the same one or two tokens per sentence. (We tried BasicDecoder with GreedyEmbeddingHelper and BeamSearchDecoder for inference.)
Looking at the evaluation loss on De-En (computed using BasicDecoder + GreedyEmbeddingHelper), it looks as if the model overfits very quickly (eval loss shoots up while train loss goes down).
When I try training using BasicDecoder + GreedyEmbeddingHelper, both training loss and eval loss trend down (it needs a few more hours to train, though).
Also, on another toy task (bigger and with longer sequences), test results after training with BasicDecoder + GreedyEmbeddingHelper are much better than after training with BasicDecoder + TrainingHelper. In all our experiments we use Bahdanau attention.

Hence my questions are:

  1. Why not use GreedyEmbeddingHelper during training? This would make the decoder auto-regressive during training, which is more similar to what happens during inference.
  2. Any ideas why TrainingHelper fails to work on bigger tasks? (It looks like overfitting, with the model only learning a target-language LM, but the encoder and attention weights/gradients aren't zero.)
    @ebrevdo could you please comment on this?
@ebrevdo
Contributor

ebrevdo commented Jul 13, 2017

This is generally a good (modeling) question for Stack Overflow rather than a technical question about TF primitives. For future questions, I suggest asking there. For this one, @lmthang may have some suggestions.

@okuchaiev
Author

okuchaiev commented Jul 14, 2017

@lmthang any advice on this? I agree that an auto-regressive decoder during training is a modeling question. However, an inability to use TrainingHelper during training on a "real" task seems like a question about TF primitives to me.
Here is our code:
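 # Assumed context (not included in the post): this is a method on a larger model class.
 # `layers_core` presumably comes from `from tensorflow.python.layers import core as layers_core`,
 # `attention_wrapper` from `from tensorflow.contrib.seq2seq.python.ops import attention_wrapper`,
 # and `create_rnn_cell`, `getdtype`, and `self._build_attention` are project helpers not shown here.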

 def _build_decoder(self,
                     encoder_outputs,
                     enc_src_lengths,
                     tgt_inputs = None,
                     tgt_lengths = None,
                     GO_SYMBOL = 1,
                     END_SYMBOL = 2,
                     out_layer_activation = None):
    """
    Builds decoder part of the graph, for training and inference
    TODO: add param tensor shapes
    :param encoder_outputs:
    :param enc_src_lengths:
    :param tgt_inputs:
    :param tgt_lengths:
    :param GO_SYMBOL:
    :param END_SYMBOL:
    :param out_layer_activation:
    :return:
    """
    with tf.variable_scope("Decoder"):
      tgt_vocab_size = self.model_params['tgt_vocab_size']
      tgt_emb_size = self.model_params['tgt_emb_size']
      self._tgt_w = tf.get_variable(name='W_tgt_embedding',
                                    shape=[tgt_vocab_size, tgt_emb_size], dtype=getdtype())
      batch_size = self.model_params['batch_size']

      decoder_cell = create_rnn_cell(cell_type=self.model_params['decoder_cell_type'],
                                     cell_params={"num_units": self.model_params['decoder_cell_units']},
                                     num_layers=self.model_params['decoder_layers'],
                                     dp_input_keep_prob=self.model_params['decoder_dp_input_keep_prob'],
                                     dp_output_keep_prob=self.model_params['decoder_dp_output_keep_prob'],
                                     residual_connections=self.model_params['decoder_use_skip_connections'])

      output_layer = layers_core.Dense(tgt_vocab_size, use_bias=False,
                                       activation = out_layer_activation)

      def attn_decoder_custom_fn(inputs, attention):
          # to make shapes equal for skip connections
          if self.model_params['decoder_use_skip_connections']:
              input_layer = layers_core.Dense(self.model_params['attention_layer_size'], dtype=getdtype())
              return input_layer(tf.concat([inputs, attention], -1))
          else:
              return tf.concat([inputs, attention], -1)

      if self._mode == "infer":
        if self._decoder_type == "beam_search":
          self._length_penalty_weight = (1.0 if "length_penalty" not in self.model_params
                                         else self.model_params["length_penalty"])
          # beam_width of 1 should be same as argmax decoder
          self._beam_width = 1 if "beam_width" not in self.model_params else self.model_params["beam_width"]
          tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=self._beam_width)
          tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(enc_src_lengths, multiplier=self._beam_width)
          attention_mechanism = self._build_attention(tiled_enc_outputs, tiled_enc_src_lengths)
          attentive_decoder_cell = attention_wrapper.AttentionWrapper(cell=decoder_cell,
                                                                      attention_mechanism=attention_mechanism,
                                                                      cell_input_fn=attn_decoder_custom_fn)
          batch_size_tensor = tf.constant(batch_size)
          decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell=attentive_decoder_cell,
            embedding=self._tgt_w,
            start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
            end_token=END_SYMBOL,
            initial_state=attentive_decoder_cell.zero_state(dtype=getdtype(),
                                                            batch_size=batch_size_tensor * self._beam_width),
            beam_width=self._beam_width,
            output_layer=output_layer,
            length_penalty_weight=self._length_penalty_weight)
        else:
          attention_mechanism = self._build_attention(encoder_outputs, enc_src_lengths)
          attentive_decoder_cell = attention_wrapper.AttentionWrapper(cell=decoder_cell,
                                                                      attention_mechanism=attention_mechanism,
                                                                      cell_input_fn=attn_decoder_custom_fn)
          helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=self._tgt_w,
            start_tokens=tf.fill([batch_size], GO_SYMBOL),
            #start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
            end_token=END_SYMBOL)
          decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=attentive_decoder_cell,
            helper=helper,
            initial_state=attentive_decoder_cell.zero_state(batch_size=batch_size, dtype=getdtype()),
            output_layer=output_layer)
      elif self._mode == "train":
        attention_mechanism = self._build_attention(encoder_outputs, enc_src_lengths)
        attentive_decoder_cell = attention_wrapper.AttentionWrapper(cell=decoder_cell,
                                                                    attention_mechanism=attention_mechanism)
        input_vectors = tf.nn.embedding_lookup(self._tgt_w, tgt_inputs)
        helper = tf.contrib.seq2seq.TrainingHelper(
          inputs = input_vectors,
          sequence_length = tgt_lengths)

        decoder = tf.contrib.seq2seq.BasicDecoder(
          cell=attentive_decoder_cell,
          helper=helper,
          output_layer=output_layer,
          initial_state=attentive_decoder_cell.zero_state(batch_size, dtype=getdtype()))
      else:
        raise NotImplementedError("Unknown mode")

      final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
        decoder = decoder,
        #output_time_major=True,
        impute_finished=False,
        maximum_iterations=tf.reduce_max(tgt_lengths) if self._mode == 'train' else self.tgt_max_size)
    return final_outputs, final_state, final_sequence_lengths

@okuchaiev
Author

okuchaiev commented Jul 14, 2017

Closing this, as I think I found the issue.
The problem was that we were training so that the decoder was supposed to start by predicting GO_SYMBOL. Since the inference APIs instead expect GO_SYMBOL to be fed as the first decoder input, our inference did not work correctly.

@ebrevdo I think it would be a good idea to require TrainingHelper to take GO_SYMBOL; that way it would be more similar to the inference helpers and we wouldn't get confused. Alternatively, the inference helpers/decoders could make it optional.
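
For context, a minimal sketch of the asymmetry being described, using the TF 1.2 contrib.seq2seq helpers; the names decoder_input_embeddings, decoder_lengths, and tgt_embedding are placeholders, not taken from the code above:

  import tensorflow as tf  # TF 1.x

  # TrainingHelper is fed the full embedded decoder input, GO included:
  train_helper = tf.contrib.seq2seq.TrainingHelper(
      inputs=decoder_input_embeddings,        # embeddings of 'GO 1 2 3 ...'
      sequence_length=decoder_lengths)

  # GreedyEmbeddingHelper is instead told the GO id explicitly and embeds it itself:
  infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
      embedding=tgt_embedding,
      start_tokens=tf.fill([batch_size], GO_SYMBOL),
      end_token=END_SYMBOL)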

@tocab

tocab commented Jul 29, 2017

I am running into the same problem: training results are very good with TrainingHelper, but inference with GreedyEmbeddingHelper delivers poor performance. It produces only the GO symbol defined in the GreedyEmbeddingHelper, repeated max_iterations times, for every input.

@okuchaiev How did you solve the problem?

@RylanSchaeffer

Can someone clarify how GO symbols are supposed to be used during training as opposed to during inference?

@RylanSchaeffer

@tocab , did you find a solution to your problem?

@okuchaiev
Author

@tocab If a training example is 'GO a b c EOS' => 'GO 1 2 3 EOS', then make sure that:

  1. 'GO a b c EOS' is fed to the encoder
  2. 'GO 1 2 3 EOS' is fed to the decoder
  3. targets are '1, 2, 3, EOS, PAD'
  4. loss is masked for positions >= targets.indexof(PAD)

My guess is that you have something wrong in either 3 or 4.
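
A minimal sketch of point 4 with the TF 1.x contrib APIs (the names logits, tgt_ids, and tgt_lengths are placeholders: decoder logits, padded target ids as in point 3, and per-example lengths counting tokens up to and including EOS):

  import tensorflow as tf  # TF 1.x

  # Zero out every position at or after the first PAD so it contributes nothing to the loss.
  mask = tf.sequence_mask(tgt_lengths, maxlen=tf.shape(tgt_ids)[1], dtype=tf.float32)
  loss = tf.contrib.seq2seq.sequence_loss(logits=logits, targets=tgt_ids, weights=mask)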

@RylanSchaeffer

RylanSchaeffer commented Aug 13, 2017

@okuchaiev , thanks for responding! Can you clarify which of 1 through 4 apply for training and which apply for validation? I'm specifically confused by the difference between 2 and 3 - when is one used versus the other?

My guess is that 1, 2, and 4 apply during training, and then during validation/inference one uses 1, 2 if known (for calculating the loss), and 4, but I'm not sure how 3 comes into play.

@okuchaiev
Author

okuchaiev commented Aug 13, 2017

@RylanSchaeffer all of 1, 2, 3, and 4 apply during training. Let me add some terminology to clarify:
a) 'GO a b c EOS' is the encoder input
b) 'GO 1 2 3 EOS' is the decoder input
c) '1, 2, 3, EOS, PAD' is the target, i.e. the expected decoder output, which is fed only to the loss function to compute the loss

During training you have all of (a), (b), and (c) ((c) is just (b) shifted by one position).
During inference you only have (a); (b) and (c) do not exist during inference - the task is to predict (c).
That is, you feed (a) to the encoder, and GreedyEmbeddingHelper (or similar) will require you to provide the GO symbol to the decoder. Then dynamic_decode will construct the decoder output in an auto-regressive manner.
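
As a tiny plain-Python illustration of "(c) is just (b) shifted by one position" (token strings stand in for real ids):

  decoder_input = ['GO', '1', '2', '3', 'EOS']   # (b), fed to the decoder during training
  target = decoder_input[1:] + ['PAD']           # (c), fed only to the loss
  # target == ['1', '2', '3', 'EOS', 'PAD']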

@RylanSchaeffer

RylanSchaeffer commented Aug 13, 2017

@okuchaiev , thank you for clarifying! How did you find this out? I feel like I must have misread some critical documentation somewhere.

@tocab

tocab commented Aug 13, 2017

@okuchaiev Thank you for your answer. But in the end I trained my model in a different way:
a) Encoder input: 'a b c EOS PAD'
b) Decoder input: 'GO 1 2 3 PAD'
c) Target: '1 2 3 EOS PAD'

I think for a), the encoder doesn't need a GO symbol. For b) and c), I think the EOS is only part of the output, so the decoder learns to predict an EOS after the last word ('3' in this case).

With this setup I reached good performance for both training and inference and solved the problem I had before.

@lmthang
Contributor

lmthang commented Aug 13, 2017

Yes, the encoder doesn't need GO & you can merge EOS & PAD (as long as sequence lengths are correct). Here is an example with a batch of 2 sentences for training:
(a) Encoder inputs (encoder lengths = [3, 2]):
a b c EOS
d e EOS EOS
(b) Decoder inputs (decoder lengths = [4, 3]):
GO 1 2 3
GO 4 5 EOS
(c) Decoder outputs (shift-by-1 of decoder inputs):
1 2 3 EOS
4 5 EOS EOS
(the first EOS is part of the loss)
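
A minimal sketch (assuming TF 1.x) of the loss weights implied by the decoder lengths [4, 3] above, which keeps the first EOS of each sentence in the loss and masks out the trailing EOS/PAD:

  import tensorflow as tf  # TF 1.x

  decoder_lengths = tf.constant([4, 3])
  weights = tf.sequence_mask(decoder_lengths, maxlen=4, dtype=tf.float32)
  # weights == [[1., 1., 1., 1.],
  #             [1., 1., 1., 0.]]   # the second sentence's trailing EOS is excluded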

During inference: we only have (a) + GO symbol fed to GreedyEmbeddingHelper, so it's important to make sure this GO is the same as the GO in (b).

Does that help clarify? I'd love to hear feedback on how to improve the tutorial so people won't make this mistake.

Thanks for the discussion!

@RylanSchaeffer

RylanSchaeffer commented Aug 13, 2017

@lmthang , thank you for clarifying!

@RylanSchaeffer

RylanSchaeffer commented Aug 13, 2017

@lmthang quick question - the Decoder inputs do not necessarily need to end with EOS? In your example (b), the first of the two Decoder inputs does not.

@RylanSchaeffer

RylanSchaeffer commented Aug 17, 2017

@lmthang , some suggestions regarding the tutorial:

  1. I think adding a concrete example of the format of encoder_inputs, decoder_inputs and decoder_outputs (like you did above) would be helpful.
  2. The seq2seq module uses the variables alignments and alignment_history, but the tutorial doesn't make any mention of alignment beyond the picture. I think explaining how "attention" and "alignment" differ would be helpful in understanding the module.
  3. I personally would appreciate an explanation of why the terminology of memory networks was adopted, as I was confused by terms like "memory depth" while reading the seq2seq code, but that might be tangential to the tutorial.

gundamMC added a commit to gundamMC/animius that referenced this issue Sep 8, 2018
Turns out the decoder output (target y) should be offset by 1 from the decoder input (exclude <GO>).
See tensorflow/nmt#3