Question: Using TrainingHelper vs GreedyEmbeddingHelper during training. #3
Comments
This is generally a good (modeling) question for Stack Overflow rather than a technical question about TF primitives. For future questions, I suggest asking there. For this one, @lmthang may have some suggestions. |
@lmthang any advice on this? I agree that using an auto-regressive decoder during training is a modelling question. However, being unable to use TrainingHelper during training for a "real" task seems like a question about TF primitives to me.
def _build_decoder(self,
                   encoder_outputs,
                   enc_src_lengths,
                   tgt_inputs = None,
                   tgt_lengths = None,
                   GO_SYMBOL = 1,
                   END_SYMBOL = 2,
                   out_layer_activation = None):
    """
    Builds decoder part of the graph, for training and inference
    TODO: add param tensor shapes
    :param encoder_outputs:
    :param enc_src_lengths:
    :param tgt_inputs:
    :param tgt_lengths:
    :param GO_SYMBOL:
    :param END_SYMBOL:
    :param out_layer_activation:
    :return:
    """
    with tf.variable_scope("Decoder"):
        tgt_vocab_size = self.model_params['tgt_vocab_size']
        tgt_emb_size = self.model_params['tgt_emb_size']
        self._tgt_w = tf.get_variable(name='W_tgt_embedding',
                                      shape=[tgt_vocab_size, tgt_emb_size], dtype=getdtype())
        batch_size = self.model_params['batch_size']
        decoder_cell = create_rnn_cell(cell_type=self.model_params['decoder_cell_type'],
                                       cell_params={"num_units": self.model_params['decoder_cell_units']},
                                       num_layers=self.model_params['decoder_layers'],
                                       dp_input_keep_prob=self.model_params['decoder_dp_input_keep_prob'],
                                       dp_output_keep_prob=self.model_params['decoder_dp_output_keep_prob'],
                                       residual_connections=self.model_params['decoder_use_skip_connections'])
        output_layer = layers_core.Dense(tgt_vocab_size, use_bias=False,
                                         activation = out_layer_activation)
        def attn_decoder_custom_fn(inputs, attention):
            # to make shapes equal for skip connections
            if self.model_params['decoder_use_skip_connections']:
                input_layer = layers_core.Dense(self.model_params['attention_layer_size'], dtype=getdtype())
                return input_layer(tf.concat([inputs, attention], -1))
            else:
                return tf.concat([inputs, attention], -1)
        if self._mode == "infer":
            if self._decoder_type == "beam_search":
                self._length_penalty_weight = 1.0 if "length_penalty" not in self.model_params else self.model_params[
                    "length_penalty"]
                # beam_width of 1 should be same as argmax decoder
                self._beam_width = 1 if "beam_width" not in self.model_params else self.model_params["beam_width"]
                tiled_enc_outputs = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=self._beam_width)
                tiled_enc_src_lengths = tf.contrib.seq2seq.tile_batch(enc_src_lengths, multiplier=self._beam_width)
                attention_mechanism = self._build_attention(tiled_enc_outputs, tiled_enc_src_lengths)
                attentive_decoder_cell = attention_wrapper.AttentionWrapper(cell=decoder_cell,
                                                                            attention_mechanism=attention_mechanism,
                                                                            cell_input_fn=attn_decoder_custom_fn)
                batch_size_tensor = tf.constant(batch_size)
                decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=attentive_decoder_cell,
                    embedding=self._tgt_w,
                    start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
                    end_token=END_SYMBOL,
                    initial_state=attentive_decoder_cell.zero_state(dtype=getdtype(),
                                                                    batch_size=batch_size_tensor * self._beam_width),
                    beam_width=self._beam_width,
                    output_layer=output_layer,
                    length_penalty_weight=self._length_penalty_weight)
            else:
                attention_mechanism = self._build_attention(encoder_outputs, enc_src_lengths)
                attentive_decoder_cell = attention_wrapper.AttentionWrapper(cell=decoder_cell,
                                                                            attention_mechanism=attention_mechanism,
                                                                            cell_input_fn=attn_decoder_custom_fn)
                helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    embedding=self._tgt_w,
                    start_tokens=tf.fill([batch_size], GO_SYMBOL),
                    #start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
                    end_token=END_SYMBOL)
                decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=attentive_decoder_cell,
                    helper=helper,
                    initial_state=attentive_decoder_cell.zero_state(batch_size=batch_size, dtype=getdtype()),
                    output_layer=output_layer)
        elif self._mode == "train":
            attention_mechanism = self._build_attention(encoder_outputs, enc_src_lengths)
            attentive_decoder_cell = attention_wrapper.AttentionWrapper(cell=decoder_cell,
                                                                        attention_mechanism=attention_mechanism)
            input_vectors = tf.nn.embedding_lookup(self._tgt_w, tgt_inputs)
            helper = tf.contrib.seq2seq.TrainingHelper(
                inputs = input_vectors,
                sequence_length = tgt_lengths)
            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=attentive_decoder_cell,
                helper=helper,
                output_layer=output_layer,
                initial_state=attentive_decoder_cell.zero_state(batch_size, dtype=getdtype()))
        else:
            raise NotImplementedError("Unknown mode")
        final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
            decoder = decoder,
            #output_time_major=True,
            impute_finished=False,
            maximum_iterations=tf.reduce_max(tgt_lengths) if self._mode == 'train' else self.tgt_max_size)
        return final_outputs, final_state, final_sequence_lengths
|
Closing this as I think I found the issue. @ebrevdo I think it would be a good idea to require TrainingHelper to take GO_SYMBOL; that way it would be more similar to the inference helpers and we wouldn't get confused. Alternatively, the inference helpers/decoders should make it optional. |
I have the same problem: training results are very good with TrainingHelper, but inference with GreedyEmbeddingHelper performs poorly. For every input, it produces only the GO symbol defined in the GreedyEmbeddingHelper, repeated max_iterations times. @okuchaiev How did you solve the problem? |
Can someone clarify how GO symbols are supposed to be used during training as opposed to during inference? |
@tocab , did you find a solution to your problem? |
@tocab If the training example is: 'GO a b c EOS' => 'GO 1 2 3 EOS', then make sure that:
My guess is that you have something wrong in either 3 or 4. |
@okuchaiev , thanks for responding! Can you clarify which of 1 through 4 apply during training and which apply during validation? I'm specifically confused by the difference between 2 and 3: when is one used versus the other? My guess is that 1, 2, and 4 apply during training, and then during validation/inference, one uses 1, 2 |
@RylanSchaeffer all 1,2,3, and 4 apply during training. Let me add some terminology to clarify: During training you have all (a), (b), and (c) ((c) is just (b) shifted by one position). |
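To make the "shifted by one position" relation concrete, here is a minimal sketch in plain Python. It assumes that (b) is the decoder input and (c) is the decoder target, as the later comments in this thread suggest. The helper name `make_decoder_pair` and the token ids 11, 12, 13 are made up for illustration; `GO = 1` and `EOS = 2` mirror the `GO_SYMBOL` and `END_SYMBOL` defaults in the code posted above.

```python
GO, EOS = 1, 2  # assumed ids, mirroring the GO_SYMBOL / END_SYMBOL defaults above


def make_decoder_pair(target_ids):
    """Build (decoder_input, decoder_target) for one target sentence.

    The input starts with GO; the target ends with EOS. Both have the
    same length, and the target is the input shifted left by one step.
    """
    decoder_input = [GO] + list(target_ids)
    decoder_target = list(target_ids) + [EOS]
    return decoder_input, decoder_target


# Hypothetical target sentence with made-up token ids.
dec_in, dec_out = make_decoder_pair([11, 12, 13])
print(dec_in)   # [1, 11, 12, 13]
print(dec_out)  # [11, 12, 13, 2]
```

At every time step the decoder reads `dec_in[t]` and is trained to predict `dec_out[t]`, so GO appears only in the input and EOS only in the target.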
@okuchaiev , thank you for clarifying! How did you find this out? I feel like I must have misread some critical documentation somewhere. |
@okuchaiev Thank you for your answer. But in the end I trained my model in another way: I think for a), the encoder doesn't need a GO symbol. For b) and c) I think the EOS is only part of the output, so the decoder learns to predict an EOS after the last word ('3' in this case). With this setup I reached good performance for training and inference and solved the problem I had before. |
Yes, the encoder doesn't need GO, and you can merge EOS & PAD (as long as sequence lengths are correct). Here is an example with a batch of 2 sentences for training: During inference: we only have (a) + the GO symbol fed to GreedyEmbeddingHelper, so it's important to make sure this GO is the same as the GO in (b). Does that help clarify? I'd love to hear feedback on how to improve the tutorial so people won't make this mistake. Thanks for the discussion! |
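A rough sketch of what such a two-sentence training batch could look like, in plain Python. This is an illustration under assumptions, not the tutorial's code: the function names and token ids are hypothetical, EOS doubles as the padding id (the "merge EOS & PAD" idea above), and `GO = 1`, `EOS = 2` mirror the defaults in the code posted earlier.

```python
GO, EOS = 1, 2  # assumed ids; EOS is also used as padding here


def make_training_batch(targets):
    """Pad decoder inputs/targets to a common length and record true lengths."""
    lengths = [len(t) + 1 for t in targets]  # +1 for GO (input) or EOS (target)
    max_len = max(lengths)
    dec_inputs, dec_targets = [], []
    for t in targets:
        inp = [GO] + list(t)
        out = list(t) + [EOS]
        pad = [EOS] * (max_len - len(inp))
        dec_inputs.append(inp + pad)
        dec_targets.append(out + pad)
    return dec_inputs, dec_targets, lengths


def make_start_tokens(batch_size):
    """Start tokens for inference: must be the same GO id used in training."""
    return [GO] * batch_size


dec_inputs, dec_targets, lengths = make_training_batch([[11, 12, 13], [21, 22]])
print(dec_inputs)   # [[1, 11, 12, 13], [1, 21, 22, 2]]
print(dec_targets)  # [[11, 12, 13, 2], [21, 22, 2, 2]]
print(lengths)      # [4, 3]
```

In this sketch `lengths` plays the role of the `sequence_length` argument passed to `TrainingHelper` in the code above; positions beyond each length are padding and would normally be masked out of the loss. Note that the longest decoder input does not end with EOS, while shorter ones do only because of padding.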
@lmthang , thank you for clarifying! |
@lmthang quick question - the Decoder inputs do not necessarily need to end with EOS? In your example (b), the first of the two Decoder inputs does not. |
@lmthang , some suggestions regarding the tutorial:
|
Turns out the decoder output (target y) should be offset by 1 from the decoder input (exclude <GO>). See tensorflow/nmt#3
We are able to train a simple seq2seq model (our own code, but using the TF v1.2 contrib.seq2seq APIs) to perform simple tasks like string reversal (BLEU 100%).
However, on De-En data our inference seems to get stuck emitting one or two words. During training, the model samples (using TrainingHelper) are almost perfect translations, but inference keeps emitting the same token or two per sentence. (We tried BasicDecoder with GreedyEmbeddingHelper and BeamSearchDecoder for inference.)
Looking at the evaluation loss on De-En (computed using BasicDecoder + GreedyEmbeddingHelper), it looks as if the model overfits very quickly (eval loss shoots up while train loss goes down quickly).
When I try training using BasicDecoder + GreedyEmbeddingHelper, both training loss and eval loss trend down (it needs a few more hours to train, though).
Also, on another toy task (bigger and with longer sequences), test results after training using BasicDecoder + GreedyEmbeddingHelper are much better than after training using BasicDecoder + TrainingHelper. In all our experiments we use Bahdanau attention.
Hence my questions are:
@ebrevdo could you please comment on this?
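The gap described above (near-perfect samples with TrainingHelper, repeated tokens at inference) can be reproduced with a toy stand-in, sketched below in plain Python. Everything here is hypothetical and heavily simplified: the "model" is a stateless lookup table rather than an RNN, an unseen input token is simply echoed back, and the function names are not TensorFlow APIs. It only illustrates how a decoder that never saw GO as an input during training can look fine under teacher forcing and still fail when it is started from GO and fed its own predictions.

```python
GO, EOS = 1, 2  # assumed ids, mirroring the GO_SYMBOL / END_SYMBOL defaults above


def make_step(table):
    """Toy 'model': maps the previous token to the next one.

    Tokens never seen as input are echoed back (an arbitrary choice
    made only for this illustration).
    """
    return lambda tok: table.get(tok, tok)


def teacher_forced_decode(step_fn, decoder_input):
    """Training-style decoding: every step is fed the ground-truth token."""
    return [step_fn(tok) for tok in decoder_input]


def greedy_decode(step_fn, max_iterations):
    """Inference-style decoding: start from GO, feed back own predictions."""
    outputs, prev = [], GO
    for _ in range(max_iterations):
        prev = step_fn(prev)
        outputs.append(prev)
        if prev == EOS:
            break
    return outputs


# 'Trained' without GO in the decoder input: GO is an unseen token.
no_go_model = {11: 12, 12: 13, 13: EOS}
# Same table, but GO was part of the decoder input during training.
with_go_model = {**no_go_model, GO: 11}

print(teacher_forced_decode(make_step(no_go_model), [11, 12, 13]))  # [12, 13, 2]
print(greedy_decode(make_step(no_go_model), 5))                     # [1, 1, 1, 1, 1]
print(greedy_decode(make_step(with_go_model), 5))                   # [11, 12, 13, 2]
```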