This repository has been archived by the owner on Dec 11, 2023. It is now read-only.

Clarifying How to Correctly Access Alignment History #36

Closed
RylanSchaeffer opened this issue Jul 29, 2017 · 9 comments

Comments

@RylanSchaeffer

RylanSchaeffer commented Jul 29, 2017

I have a problem with the Seq2Seq library, and I'm trying to use this tutorial to find my bug. My model's alignment values are initially uniformly distributed across the encoder outputs, and they stay that way even after training, even though the model's accuracy climbs from chance (25%) to 100%. My task is structured so that the only way to do well at it is for the decoder to learn to pay attention.

Copying and pasting from the earlier GitHub issue:

As background, both my inputs and my labeled outputs at each time step are vectors of shape (4,). I run my encoder for 500 steps, i.e., inputs have shape (minibatch size, 500, 4), and my decoder runs for approximately 40-41 steps, i.e., the final output has shape (minibatch size, 41, 4). Each output label depends on roughly 12 consecutive inputs: for example, the first output depends on inputs 1-12, the second on inputs 13-24, and so on.

I don't use embeddings since doing so isn't applicable for my problem.

I reduced my model to a single-layer encoder and a single-layer decoder to rule out any mistake I might be making with multi-layer architectures. The encoder is a bidirectional RNN.

At the start of training, my alignment_history has roughly uniform weights. Its shape is (41, minibatch size, 500), although I could transpose it from time-major to batch-major. Its values lie between 0.001739 and 0.002241, which makes sense: randomly initialized attention should be around 1/500 = 0.002. Additionally, my model performs at chance (25% classification accuracy).
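The 1/500 = 0.002 baseline comes from the softmax that turns attention scores into alignment weights: when an untrained scorer produces nearly equal scores for all 500 encoder steps, every weight lands close to 1/500. A minimal pure-Python sketch, where the small random scores are a hypothetical stand-in for an untrained attention layer's outputs:

```python
import math
import random

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical untrained scores: tiny random values for 500 encoder steps.
random.seed(0)
num_source_steps = 500
scores = [random.uniform(-0.1, 0.1) for _ in range(num_source_steps)]
alignment = softmax(scores)

print(1 / num_source_steps)            # 0.002
print(min(alignment), max(alignment))  # both close to 0.002
```

With scores confined to [-0.1, 0.1], every weight is bounded between exp(-0.2)/500 and exp(0.2)/500, so near-uniform alignments are exactly what an attention layer that hasn't learned anything should produce.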

During training, my model converges to 100% classification accuracy on both training and validation data, as shown below.

[Screenshot: training and validation classification accuracy during training]

The model never sees the same training data twice, so I'm 99% confident that it isn't memorizing the training data. However, after training, the values of alignment_history have effectively not changed; they now look randomly drawn from between 0.00185 and 0.00219.
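One way to quantify whether the alignments have changed is to compare each alignment row's entropy with that of the uniform distribution, ln(500) ≈ 6.2146 nats; a row that had learned to focus on about 12 encoder steps would sit near ln(12) ≈ 2.4849. A pure-Python sketch with hypothetical rows (in practice it would be applied to each row of the batch-major alignment history):

```python
import math

def entropy(row):
    """Shannon entropy (in nats) of one alignment distribution."""
    return -sum(p * math.log(p) for p in row if p > 0)

num_source_steps = 500
uniform_row = [1.0 / num_source_steps] * num_source_steps

# Hypothetical "focused" row: all mass spread evenly over 12 consecutive steps.
focused_row = [0.0] * num_source_steps
for s in range(12, 24):
    focused_row[s] = 1.0 / 12

print(entropy(uniform_row))  # ~6.2146, i.e. ln(500)
print(entropy(focused_row))  # ~2.4849, i.e. ln(12)
```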

My code is relatively straightforward. I have a class encapsulating my model. One method instantiates an RNN cell:

@staticmethod
def _create_lstm_cell(cell_size):
    """
    Creates an RNN cell. If lstm_or_gru is True (default), create a
    layer-normalized LSTM cell if layer_norm is True (default), or a
    vanilla LSTM cell otherwise. If lstm_or_gru is False, create a Gated
    Recurrent Unit cell.
    """

    if tf.flags.FLAGS.lstm_or_gru:
        if tf.flags.FLAGS.layer_norm:
            return LayerNormBasicLSTMCell(cell_size)
        else:
            return BasicLSTMCell(cell_size)
    else:
        return GRUCell(cell_size)

I have one method for building the encoder:

def _define_encoder(self):
    """
    Construct an encoder RNN using a bidirectional layer.
    """

    with tf.variable_scope('define_encoder'):

        encoder_outputs, encoder_final_states = bidirectional_dynamic_rnn(
            cell_fw=self._create_lstm_cell(ENCODER_SINGLE_DIRECTION_SIZE),
            cell_bw=self._create_lstm_cell(ENCODER_SINGLE_DIRECTION_SIZE),
            inputs=self.x,
            dtype=tf.float32,
            sequence_length=self.x_lengths,
            time_major=False  # default
        )

        # concatenate forward and backwards encoder outputs
        encoder_outputs = tf.concat(encoder_outputs, axis=-1)

        # concatenate forward and backwards cell states
        new_c = tf.concat([encoder_final_states[0].c, encoder_final_states[1].c], axis=1)
        new_h = tf.concat([encoder_final_states[0].h, encoder_final_states[1].h], axis=1)
        encoder_final_states = (LSTMStateTuple(c=new_c, h=new_h),)

    return encoder_outputs, encoder_final_states
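Because the forward and backward states are concatenated, the tuple returned here has c and h of size 2 * ENCODER_SINGLE_DIRECTION_SIZE, and the decoder below clones it into a cell built with DECODER_SIZE units, so the two sizes have to agree. A small pure-Python sketch of that shape bookkeeping; the function names and concrete sizes are hypothetical:

```python
def encoder_shapes(batch_size, source_len, single_direction_size):
    """Shapes produced by the bidirectional encoder after concatenation."""
    both = 2 * single_direction_size
    return {
        "encoder_outputs": (batch_size, source_len, both),  # fw and bw joined on the last axis
        "state_c": (batch_size, both),
        "state_h": (batch_size, both),
    }

def check_decoder_size(single_direction_size, decoder_size):
    """The decoder LSTM's state size must equal the concatenated encoder state size."""
    if 2 * single_direction_size != decoder_size:
        raise ValueError(
            f"decoder_size={decoder_size}, but the concatenated encoder state has "
            f"{2 * single_direction_size} units")

# Hypothetical sizes for illustration.
print(encoder_shapes(batch_size=32, source_len=500, single_direction_size=64))
check_decoder_size(single_direction_size=64, decoder_size=128)  # passes silently
```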

I similarly have another method for building the decoder:

def _define_decoder(self, encoder_outputs, encoder_final_states):
    """
    Construct a decoder complete with an attention mechanism. The encoder's
    final states will be used as the decoder's initial states.
    """

    with tf.variable_scope('define_decoder'):
        # instantiate attention mechanism
        attention_mechanism = BahdanauAttention(num_units=DECODER_SIZE,
                                                memory=encoder_outputs,
                                                normalize=True)

        # wrap LSTM cell with attention mechanism
        attention_cell = AttentionWrapper(cell=self._create_lstm_cell(cell_size=DECODER_SIZE),
                                          attention_mechanism=attention_mechanism,
                                          # output_attention=False,  # doesn't seem to affect alignments
                                          alignment_history=True,
                                          attention_layer_size=DECODER_SIZE)  # arbitrarily chosen

        # create initial attention state of zeros everywhere
        decoder_initial_state = attention_cell.zero_state(batch_size=tf.flags.FLAGS.batch_size, dtype=tf.float32).clone(cell_state=encoder_final_states[0])

        # TODO: switch this out at inference time
        training_helper = TrainingHelper(inputs=self.y,  # feed in ground truth
                                         sequence_length=self.y_lengths)  # feed in sequence lengths

        decoder = BasicDecoder(cell=attention_cell,
                               helper=training_helper,
                               initial_state=decoder_initial_state
                               )

        # run decoder over input sequence
        decoder_outputs, decoder_final_states, decoder_final_sequence_lengths = dynamic_decode(
            decoder=decoder,
            maximum_iterations=41,
            impute_finished=True)

        decoder_outputs = decoder_outputs[0]
        decoder_final_states = (decoder_final_states,)

    return decoder_outputs, decoder_final_states
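With alignment_history=True, the stacked history is time-major, (target steps, batch, source steps), which matches the (41, minibatch size, 500) shape described above. One example's attention map is easier to inspect batch-major; this pure-Python sketch performs the same reordering that tf.transpose(history, perm=[1, 0, 2]) would, on hypothetical toy sizes:

```python
def time_major_to_batch_major(history):
    """Reorder a nested list indexed [t][b][s] into [b][t][s]."""
    num_steps = len(history)
    batch_size = len(history[0])
    return [[history[t][b] for t in range(num_steps)] for b in range(batch_size)]

# Toy history: 3 decoder steps, batch of 2, 4 encoder steps (all hypothetical).
history = [
    [[0.7, 0.1, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25]],  # t = 0
    [[0.1, 0.7, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25]],  # t = 1
    [[0.1, 0.1, 0.7, 0.1], [0.25, 0.25, 0.25, 0.25]],  # t = 2
]
batch_major = time_major_to_batch_major(history)

# batch_major[0] is example 0's (target steps x source steps) attention map.
for row in batch_major[0]:
    print(row)
```

In this toy data, example 0 shows a diagonal (focused) pattern while example 1 stays uniform, which is the kind of contrast the post is looking for.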

I use both of these methods, and then project the output of the decoder to the same dimensionality as my labels.

def _add_inference(self):
    """
    Create a Sequence-to-Sequence model using a bidirectional encoder and an
    attention mechanism-wrapped decoder.
    
    The outputs of the decoder need to be projected to a lower dimensional
    space i.e. from DECODER_SIZE to 4.
    """

    with tf.variable_scope('add_inference'):
        encoder_outputs, encoder_final_states = self._define_encoder()
        decoder_outputs, decoder_final_states = self._define_decoder(encoder_outputs, encoder_final_states)

        weights = tf.Variable(tf.truncated_normal(shape=[DECODER_SIZE, 4]))
        bias = tf.Variable(tf.truncated_normal(shape=[4]))
        logits = tf.tensordot(decoder_outputs, weights, axes=[[2], [0]]) + bias  # 2nd dimension of decoder outputs, 0th dimension of weights

    return encoder_final_states, decoder_final_states, logits
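The tf.tensordot call with axes=[[2], [0]] contracts the last axis of decoder_outputs (size DECODER_SIZE) with the first axis of weights, i.e. logits[b, t, k] = sum over d of decoder_outputs[b, t, d] * weights[d, k], plus bias[k]. A pure-Python sketch of that contraction with toy values (not the post's actual sizes):

```python
def project(decoder_outputs, weights, bias):
    """Pure-Python equivalent of tensordot(decoder_outputs, weights, axes=[[2], [0]]) + bias."""
    depth = len(weights)            # DECODER_SIZE
    num_classes = len(weights[0])   # 4 in the post
    logits = []
    for sequence in decoder_outputs:     # batch axis
        seq_logits = []
        for vector in sequence:          # time axis
            assert len(vector) == depth
            seq_logits.append([
                sum(vector[d] * weights[d][k] for d in range(depth)) + bias[k]
                for k in range(num_classes)
            ])
        logits.append(seq_logits)
    return logits

# Toy example: batch 1, 2 time steps, DECODER_SIZE 3, 4 classes.
decoder_outputs = [[[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]]
weights = [[1.0, 0.0, 0.0, 0.0],
           [0.0, 1.0, 0.0, 0.0],
           [0.0, 0.0, 1.0, 1.0]]
bias = [0.0, 0.0, 0.0, 0.5]
print(project(decoder_outputs, weights, bias))
# [[[1.0, 0.0, 2.0, 2.5], [0.0, 1.0, 0.0, 0.5]]]
```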

Most of my code was written before the NMT tutorial was released, so I read the tutorial's code and stepped through it, but I can't find any glaring differences. I do have a few additional questions.

  1. I have two hypotheses: either I'm accessing my model's alignments incorrectly, or I'm getting something wrong in a much more significant way. To rule out the first: is the correct way to access the decoder's alignments to set alignment_history=True in AttentionWrapper and then examine the values in decoder_final_states[0].alignment_history.stack()?

  2. How is the attention mechanism's num_units chosen? Is the attention mechanism's number of units required to match the number of units in the RNN cell as well as the number of units in the AttentionWrapper, or is that not necessary?

  3. I'm confused by the terminology for memory, queries, and keys. Memory and keys are both described in words as "the set of source hidden states", but mathematically they're defined differently: the memory is $$W_2\overline{h}_s$$ for Bahdanau attention, while the keys are $$W_1h_t$$. My guess is that the tutorial means that the query $$h_t$$ is converted into a key using $$W_1$$, and that key is then compared against keys generated from the encoder's hidden states, i.e., $$W\overline{h}_s$$. Is this correct, or am I misunderstanding something?
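For reference, the tutorial's two scoring functions, written with the symbols used above ($$h_t$$ is the current target hidden state, $$\overline{h}_s$$ a source hidden state, and $$v_a$$, $$W$$, $$W_1$$, $$W_2$$ are learned parameters), followed by the softmax that turns scores into alignment weights over the $$S$$ source positions:

```latex
\mathrm{score}(h_t, \overline{h}_s) =
\begin{cases}
  h_t^\top W \overline{h}_s & \text{(Luong's multiplicative style)} \\
  v_a^\top \tanh\left(W_1 h_t + W_2 \overline{h}_s\right) & \text{(Bahdanau's additive style)}
\end{cases}

\alpha_{ts} = \frac{\exp\left(\mathrm{score}(h_t, \overline{h}_s)\right)}
                   {\sum_{s'=1}^{S} \exp\left(\mathrm{score}(h_t, \overline{h}_{s'})\right)}
```

Written this way, $$h_t$$ (or $$W_1 h_t$$) plays the query role, and the per-position terms built from $$\overline{h}_s$$ play the key role.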

@RylanSchaeffer RylanSchaeffer changed the title Clarifying How to Access Alignment History Clarifying How to Correctly Access Alignment History Jul 29, 2017
@lmthang
Contributor

lmthang commented Jul 31, 2017

Hi Rylan,

It'd be hard for me to debug your code. Two suggestions: keep testing your code (e.g., try examples with shorter input sequences instead of sequences of length 500), or try adapting this tutorial to your problem to see whether the same issue remains.

On your questions:

  1. Take a look at this https://github.com/tensorflow/nmt/blob/master/nmt/attention_model.py#L163
  2. I think the attention num_units should match that of the RNN cell https://github.com/tensorflow/nmt/blob/master/nmt/attention_model.py#L96. You can try setting it differently to see if there's an error (attention_wrapper.py does check if dimensions match).
  3. Yes, it looks like I swapped the descriptions of keys and memories. Would the updated version below sound better to you?
    Instead of having readable & writable memory, the attention mechanism presented in this tutorial is a read-only memory. Specifically, the set of source hidden states is referred to as the "memory". At each time step, we use the current target hidden state (or its transformed version, e.g., $$W_1h_t$$ in Bahdanau's scoring style) as a "query" to decide on which parts of the memory to read. Usually, the query needs to be compared with keys corresponding to individual memory slots. In the above presentation of the attention mechanism, we happen to use the same set of source hidden states (or their transformed versions, e.g., $$W\overline{h}_s$$ in Luong's scoring style or $$W_2\overline{h}_s$$ in Bahdanau's scoring style) as "keys".
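A minimal pure-Python sketch of the additive (Bahdanau-style) score, to make the query and key roles concrete. Here num_units is the width of the space that both $$W_1 h_t$$ and $$W_2\overline{h}_s$$ are projected into, so in this formulation the query and memory depths may differ from num_units; a multiplicative score with no query projection would instead need the keys' depth to equal the query's. All sizes and weights are hypothetical, and this is an illustration of the formula, not the library's implementation:

```python
import math

def matvec(matrix, vector):
    """Multiply a (rows x cols) nested list by a length-cols vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

def bahdanau_alignments(query, memory, W1, W2, v_a):
    """Additive attention: score_s = v_a . tanh(W1 @ query + W2 @ memory[s])."""
    projected_query = matvec(W1, query)          # W_1 h_t, length num_units
    scores = []
    for source_state in memory:                  # one key per memory slot
        key = matvec(W2, source_state)           # W_2 h_s, length num_units
        hidden = [math.tanh(q + k) for q, k in zip(projected_query, key)]
        scores.append(sum(a * h for a, h in zip(v_a, hidden)))
    m = max(scores)                              # stable softmax over source positions
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical sizes: query depth 3, memory depth 2, num_units 4, 3 memory slots.
query = [0.5, -1.0, 0.25]
memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W1 = [[0.1, 0.2, 0.0], [0.0, 0.1, 0.3], [0.2, 0.0, 0.1], [0.1, 0.1, 0.1]]  # num_units x query depth
W2 = [[0.3, 0.0], [0.0, 0.2], [0.1, 0.1], [0.2, 0.3]]                      # num_units x memory depth
v_a = [1.0, -1.0, 0.5, 0.5]
print(bahdanau_alignments(query, memory, W1, W2, v_a))
```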

@RylanSchaeffer
Author

RylanSchaeffer commented Aug 1, 2017

@lmthang I spent yesterday running different tests, and I found that an encoder/decoder, each with LSTM cells with num_units = 2 and no attention mechanism, still converges to 100% accuracy (just more slowly).

Is it possible that by using tf.contrib.seq2seq.TrainingHelper while training, I'm somehow allowing my model to circumvent learning?

Edit: I think I'm misusing TrainingHelper. What precisely are the inputs to TrainingHelper? I'm not familiar with tf.contrib.data, so I don't know what exactly is produced by target_input = iterator.target_input on line 308 of model.py.

Edit 2: The purpose of the TrainingHelper is to pass the correct output of the previous time step to the decoder, yes? I thought that TrainingHelper automatically passes the start symbol into the decoder at the first step. Now I'm starting to think that this is incorrect.
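TrainingHelper reads the decoder input for step t straight from the sequence it is given; it does not prepend a start symbol by itself. The NMT tutorial's data pipeline therefore builds two shifted copies of each target: target_input, which starts with the start-of-sequence id, and target_output, which ends with the end-of-sequence id. A minimal pure-Python sketch of that shift; the token ids and helper name are hypothetical:

```python
SOS_ID = 0  # hypothetical start-of-sequence id
EOS_ID = 1  # hypothetical end-of-sequence id

def make_decoder_pair(target_ids, sos_id=SOS_ID, eos_id=EOS_ID):
    """Return (decoder_inputs, decoder_targets), shifted by one step.

    At step t the decoder reads decoder_inputs[t] and is trained to
    predict decoder_targets[t], which is the *next* token.
    """
    decoder_inputs = [sos_id] + list(target_ids)
    decoder_targets = list(target_ids) + [eos_id]
    return decoder_inputs, decoder_targets

inputs, targets = make_decoder_pair([5, 6, 7])
print(inputs)   # [0, 5, 6, 7]
print(targets)  # [5, 6, 7, 1]
```

Both lists have the same length, so one sequence_length covers both: the helper would be fed (a vector encoding of) decoder_inputs, and the loss would compare the logits against decoder_targets.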

@RylanSchaeffer
Author

@lmthang , regarding 3, yes - that is more clear :)

@xushenkun

I have the same alignment problem. Although the loss is very low, the alignment is still random. I have tested both the "scaled_luong" and "normed_bahdanau" attention options, and in both cases the alignment is still incorrect.
Have you solved this alignment problem, @RylanSchaeffer?

@RylanSchaeffer
Author

I did! I wasn't using start and stop tokens properly. What do your sequences look like?

@xushenkun

@RylanSchaeffer, my encoder input is the log filterbank (logfbank) features of an audio waveform, with shape [batch_size, wave_len, feature_dim]; it has no start or stop tokens. My decoder sequence is English words with shape [batch_size, sentence_len, vocab_dim]. The start token is tf.one_hot([-1], vocab_size)=(0,0,...0,0) and the stop token is tf.one_hot([vocab_size], vocab_size)=(0,0,...0,1).
Even if the start or stop tokens are wrong, why does the loss drop to zero while the alignment is still wrong?
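On the token encoding described above: tf.one_hot returns an all-zeros vector for any index outside [0, depth), so with depth vocab_size an index of vocab_size yields the same all-zeros vector as index -1, not (0, ..., 0, 1). The helper below is a hypothetical pure-Python stand-in that follows that rule (it is not the TensorFlow op), showing the collision and two ways to make the stop token distinct:

```python
def one_hot(index, depth):
    """Mimics tf.one_hot for a single index: out-of-range indices give all zeros."""
    return [1.0 if i == index else 0.0 for i in range(depth)]

vocab_size = 5  # hypothetical

start_token = one_hot(-1, vocab_size)
stop_as_written = one_hot(vocab_size, vocab_size)
print(start_token == stop_as_written)  # True: both are all zeros

# Option 1: reserve the last existing index (only if no real word uses it).
stop_last_index = one_hot(vocab_size - 1, vocab_size)
print(stop_last_index)  # [0.0, 0.0, 0.0, 0.0, 1.0]

# Option 2: widen the depth by one so that index vocab_size becomes valid.
stop_extra_slot = one_hot(vocab_size, vocab_size + 1)
print(stop_extra_slot)  # [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
```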

@xushenkun

@RylanSchaeffer, in addition, I think the loss can be reduced to zero with just one training sample if I train for a large number of epochs, but the alignment can't be learned from only one training sample. Is this correct? Thanks

@RylanSchaeffer
Author

But the alignment can't be learned using only one training sample

I can't say for certain, but yes, this is very likely correct.

I think the loss can be reduced to zero for large epoch number with just one training sample

I'm not sure what you mean here. An epoch usually refers to one pass through the entire training dataset. If you mean that your training dataset consists of just one sample and you train on it for many epochs, then yes, your training loss will approach zero. However, your validation loss will likely skyrocket because the model is overfitting.

why the loss is reduced to zero with the wrong alignment?

I can't say for certain. Have you looked at the outputs of your decoder to check whether they match the target sequence? If they don't, there's likely a problem with your loss function. One possibility is that you are staggering the decoder's inputs and targets incorrectly. For instance, if the target sequence is [start, 1, 2, 3], you might be training your network to produce [start, 1, 2, 3] instead of [1, 2, 3, end]. This would reach a very low loss with meaningless alignments.
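The mis-staggered case can be made concrete: if the decoder is fed [start, 1, 2, 3] under teacher forcing and trained to output that same sequence, a "model" that simply echoes its current input scores perfectly without ever reading the encoder, so attention has nothing to learn. A pure-Python sketch with hypothetical tokens and helper names:

```python
def echo_model(decoder_input_token):
    """A 'model' that ignores the encoder and returns its current input."""
    return decoder_input_token

def accuracy(decoder_inputs, decoder_targets, model):
    """Fraction of steps where the model's prediction equals the target."""
    predictions = [model(tok) for tok in decoder_inputs]
    correct = sum(p == t for p, t in zip(predictions, decoder_targets))
    return correct / len(decoder_targets)

START, END = "start", "end"  # hypothetical special tokens
target = [1, 2, 3]

# Mis-staggered: inputs and targets are the same sequence.
print(accuracy([START] + target, [START] + target, echo_model))  # 1.0

# Correctly staggered: targets run one step ahead of the inputs.
print(accuracy([START] + target, target + [END], echo_model))    # 0.0
```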

@sdenton4

A couple of probable issues here, which I ran into myself:

a) The TrainingHelper uses 100% teacher forcing: the decoder is given the ground truth from the previous step instead of what it actually predicted. This makes the training loss very difficult to interpret; you'll want a separate run on test data to determine how well the model is doing.

b) Keep in mind that the decoder is fully capable of learning a language model. So if you're producing characters, it can learn to guess the next character in a sequence without knowing the input data. Again, this makes it very hard to judge model quality based on training loss.

You'll probably do better with the ScheduledOutputTrainingHelper. It lets you reduce the amount of teacher forcing, which encourages the model to pay more attention to the actual inputs and less to the supplied ground truth. The 'Listen, Attend and Spell' paper, for example, uses 80% teacher forcing during training.
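The idea behind reducing teacher forcing can be sketched without TensorFlow: at each step the next decoder input is the model's own previous prediction with some probability, and the ground truth otherwise. ScheduledOutputTrainingHelper exposes that probability as its sampling_probability argument, so "80% teacher forcing" would correspond to roughly 0.2. In this sketch, tokens stand in for output vectors, and predict_next is a hypothetical, deliberately imperfect one-step decoder:

```python
import random

def decode_with_scheduled_sampling(ground_truth, predict, start_token,
                                   sampling_probability, rng):
    """Unroll a decoder, mixing teacher forcing with the model's own outputs.

    With probability sampling_probability the next input is the model's own
    previous prediction; otherwise it is the ground-truth token.
    """
    next_input = start_token
    predictions = []
    for true_token in ground_truth:
        prediction = predict(next_input)
        predictions.append(prediction)
        if rng.random() < sampling_probability:
            next_input = prediction   # feed back the model's own output
        else:
            next_input = true_token   # teacher forcing
    return predictions

def predict_next(token):
    """Hypothetical one-step decoder that is always off by one."""
    return token + 2

truth = [1, 2, 3, 4]
rng = random.Random(0)

# sampling_probability=0.0: pure teacher forcing, inputs are 0, 1, 2, 3.
print(decode_with_scheduled_sampling(truth, predict_next, 0, 0.0, rng))  # [2, 3, 4, 5]

# sampling_probability=1.0: free running, errors compound (inputs 0, 2, 4, 6).
print(decode_with_scheduled_sampling(truth, predict_next, 0, 1.0, rng))  # [2, 4, 6, 8]

# In between, the result depends on the random draws.
mixed = decode_with_scheduled_sampling(truth, predict_next, 0, 0.2, rng)
```

The gap between the two extremes is the point: under pure teacher forcing, the model's mistakes never feed back into later steps, so training accuracy can look far better than free-running behavior.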

One trick I've been using to make the training loss /slightly/ more interpretable is to train a model with all of the input sequences zeroed out, to get a sense of how well it's possible to do with just the teacher forcing you're providing.

4 participants