Commit: more info about seq2seq

pdasigi committed Oct 7, 2020
1 parent 04e8d2b commit 4c12df0
Showing 1 changed file with 277 additions and 20 deletions: chapters/part3/semantic-parsing-seq2seq.md

The simplest way to look at the semantic parsing problem is as a
[translation](https://www.aclweb.org/anthology/P16-1002/)
[problem](https://www.aclweb.org/anthology/P16-1004/). Instead of translating from one natural
language to another, we will translate from a natural language to a programming language. In 2020,
that means using a `Seq2seq` model to generate a program conditioned on some input utterance.

We don't have a chapter on `Seq2seq` models yet, but
[here](https://nlp.stanford.edu/~johnhew/public/14-seq2seq.pdf) is a good overview of the concepts
involved. We will encode the input utterance using some encoder, then decode a sequence of tokens
in the target (programming) language.

</exercise>

<exercise id="4" title="Implementing a seq2seq model">

For this, we are literally just taking AllenNLP's existing `Seq2seq` model and using it as-is for
semantic parsing. We'll highlight a few relevant points here, but we will defer most details to the
chapter on general `Seq2seq` models (which isn't currently written).

## Dataset Reader

The code example below shows a simplified version of a `DatasetReader` for `Seq2seq` data. We just
have two `TextFields` in our `Instance`, one for the input tokens and one for the output tokens.
There are two important things to notice:

1. We are using separate vocabulary namespaces for the source and target tokens, which also keeps their embeddings separate. There are ways to
get around this side effect and share your embeddings, but they are more complicated;
ask a question in our [Discourse forum](https://discourse.allennlp.org) or open an issue on the
[guide repo](https://github.com/allenai/allennlp-guide) if you would like to see some more detail on
how to do this.
2. We're adding special tokens to our output sequence. In a `Seq2seq` model, you typically give a
special input to the model to tell it to start decoding, and it needs to have a way of signaling
that it is finished decoding. So, we modify the program tokens to include these signaling tokens.

For cases that aren't handled nicely in this way, see the [`Vocabulary` section](/reading-data) of this guide.
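
To make these two points concrete, here is a minimal sketch of roughly what the reader's `text_to_instance` logic looks like. The namespace names, field names, and the example data below are illustrative assumptions rather than the chapter's exact code:

```python
from typing import Dict, List

from allennlp.common.util import END_SYMBOL, START_SYMBOL
from allennlp.data import Instance
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token

# Separate namespaces keep the source and target vocabularies (and therefore
# their embeddings) apart.
source_indexers: Dict[str, TokenIndexer] = {"tokens": SingleIdTokenIndexer(namespace="source_tokens")}
target_indexers: Dict[str, TokenIndexer] = {"tokens": SingleIdTokenIndexer(namespace="target_tokens")}


def text_to_instance(source: List[str], target: List[str]) -> Instance:
    source_field = TextField([Token(t) for t in source], source_indexers)
    # Wrap the program in start/end symbols so the decoder knows when to start
    # and when to stop generating.
    target_field = TextField(
        [Token(START_SYMBOL)] + [Token(t) for t in target] + [Token(END_SYMBOL)],
        target_indexers,
    )
    return Instance({"source_tokens": source_field, "target_tokens": target_field})


# Hypothetical example pair; the real reader parses whatever format the data files use.
instance = text_to_instance("seven plus three".split(), "( add 7 3 )".split())
```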

## Model

Let us now look at the key parts of the `Seq2seq` model. The model typically involves two sequential modules,
like recurrent neural networks with LSTM cells: one that processes the input sequence of source tokens one token at a time, and another
that outputs a predicted output sequence, one token at a time. These two modules are usually called the encoder and the decoder, respectively.

We will use the [`ComposedSeq2Seq`](https://github.com/allenai/allennlp-models/blob/master/allennlp_models/generation/models/composed_seq2seq.py)
model implemented in AllenNLP. As the encoder, we will use a simple `LSTM`, and as the decoder, we will use an
[`AutoRegressiveSeqDecoder`](https://github.com/allenai/allennlp-models/blob/master/allennlp_models/generation/modules/seq_decoders/auto_regressive.py)
also with `LSTM` cells.

The following is a high-level gist of how a `Seq2seq` model is trained, with code snippets where relevant. The snippets may not be fully
clear out of context, but they should give you a general sense of what is happening.

1. We embed the source tokens and encode the sequence using a sequence encoder.

```python
def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
    embedded_input = self._source_text_embedder(source_tokens)
    # shape: (batch_size, max_input_sequence_length)
    source_mask = util.get_text_field_mask(source_tokens)
    # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
    encoder_outputs = self._encoder(embedded_input, source_mask)
    return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}
```

2. We initialize a sequence decoder with the final state of the encoder, and run the following steps once per output position, up to the maximum output sequence length we allow:
- 2a. Embed the target token from the previous step.
- 2b. Decode (or more precisely, encode with the decoder) the embedded output.

```python
def _prepare_output_projections(
    self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Decode the current state and last prediction to produce projections
    into the target space, which can then be used to get probabilities of
    each target token for the next step.
    """
    # shape: (group_size, max_input_sequence_length, encoder_output_dim)
    encoder_outputs = state["encoder_outputs"]

    # shape: (group_size, max_input_sequence_length)
    source_mask = state["source_mask"]

    # shape: (group_size, steps_count, decoder_output_dim)
    previous_steps_predictions = state.get("previous_steps_predictions")

    # shape: (batch_size, 1, target_embedding_dim)
    last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze(1)

    if previous_steps_predictions is None or previous_steps_predictions.shape[-1] == 0:
        # There are no previous steps yet, except for the start vectors in `last_predictions`.
        # shape: (group_size, 1, target_embedding_dim)
        previous_steps_predictions = last_predictions_embeddings
    else:
        # shape: (group_size, steps_count, target_embedding_dim)
        previous_steps_predictions = torch.cat(
            [previous_steps_predictions, last_predictions_embeddings], 1
        )

    decoder_state, decoder_output = self._decoder_net(
        previous_state=state,
        encoder_outputs=encoder_outputs,
        source_mask=source_mask,
        previous_steps_predictions=previous_steps_predictions,
    )
    state["previous_steps_predictions"] = previous_steps_predictions

    # Update state with new decoder state, override previous state
    state.update(decoder_state)

    if self._decoder_net.decodes_parallel:
        decoder_output = decoder_output[:, -1, :]

    # shape: (group_size, num_classes)
    output_projections = self._output_projection_layer(decoder_output)

    return output_projections, state
```

- 2c. Classify the decoded output to predict an index into the target vocabulary; this index is the decoder's prediction
at the current step.

```python
def take_step(
    self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Take a decoding step. This is called by the beam search class.

    # Parameters

    last_predictions : `torch.Tensor`
        A tensor of shape `(group_size,)`, which gives the indices of the predictions
        during the last time step.
    state : `Dict[str, torch.Tensor]`
        A dictionary of tensors that contain the current state information
        needed to predict the next step, which includes the encoder outputs,
        the source mask, and the decoder hidden state and context. Each of these
        tensors has shape `(group_size, *)`, where `*` can be any other number
        of dimensions.
    step : `int`
        The time step in beam search decoding.

    # Returns

    Tuple[torch.Tensor, Dict[str, torch.Tensor]]
        A tuple of `(log_probabilities, updated_state)`, where `log_probabilities`
        is a tensor of shape `(group_size, num_classes)` containing the predicted
        log probability of each class for the next step, for each item in the group,
        while `updated_state` is a dictionary of tensors containing the encoder outputs,
        source mask, and updated decoder hidden state and context.
    """
    # shape: (group_size, num_classes)
    output_projections, state = self._prepare_output_projections(last_predictions, state)

    # shape: (group_size, num_classes)
    class_log_probabilities = F.log_softmax(output_projections, dim=-1)

    return class_log_probabilities, state
```

3. We compute a loss by comparing the predicted sequence of tokens against the target tokens.

```python
def _get_loss(
    self, logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.BoolTensor
) -> torch.Tensor:
    """
    Compute loss.

    Takes logits (unnormalized outputs from the decoder) of size (batch_size,
    num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
    and corresponding masks of size (batch_size, num_decoding_steps+1) and computes cross
    entropy loss while taking the mask into account.

    The length of `targets` is expected to be greater than that of `logits` because the
    decoder does not need to compute the output corresponding to the last timestep of
    `targets`. This method aligns the inputs appropriately to compute the loss.

    During training, we want the logit corresponding to timestep i to be similar to the target
    token from timestep i + 1. That is, the targets should be shifted by one timestep for
    appropriate comparison. Consider a single example where the target has 3 words, and
    padding is to 7 tokens.

        The complete sequence would correspond to  <S> w1 w2 w3 <E> <P> <P>
        and the mask would be                       1   1  1  1  1   0   0
        and let the logits be                       l1  l2 l3 l4 l5  l6

    We actually need to compare:

        the sequence            w1 w2 w3 <E> <P> <P>
        with masks              1  1  1   1   0   0
        against                 l1 l2 l3  l4  l5  l6
        (where the input was)   <S> w1 w2 w3 <E> <P>
    """
    # shape: (batch_size, num_decoding_steps)
    relevant_targets = targets[:, 1:].contiguous()

    # shape: (batch_size, num_decoding_steps)
    relevant_mask = target_mask[:, 1:].contiguous()

    return util.sequence_cross_entropy_with_logits(
        logits, relevant_targets, relevant_mask, label_smoothing=self._label_smoothing_ratio
    )
```
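
That completes the training-time picture: encode the source, decode step by step, and compute a loss against the targets. At prediction time there are no target tokens, so the decoder is instead driven by a beam search that repeatedly calls the `take_step` method shown above. The sketch below is a rough paraphrase of how that hand-off might look with AllenNLP's `BeamSearch` utility; treat the exact argument names as approximate, and note that `take_step` stands in for the decoder's bound method:

```python
from typing import Dict

import torch
from allennlp.nn.beam_search import BeamSearch


def beam_search_decode(state: Dict[str, torch.Tensor], start_index: int, end_index: int) -> torch.Tensor:
    # `state` is the dictionary produced by `_encode` (plus the initial decoder
    # hidden state and context); `start_index` and `end_index` are the vocabulary
    # ids of the start and end symbols in the target namespace.
    batch_size = state["source_mask"].size(0)
    beam_search = BeamSearch(end_index, max_steps=50, beam_size=10)

    # shape: (batch_size,), filled with the start symbol id.
    start_predictions = state["source_mask"].new_full((batch_size,), start_index, dtype=torch.long)

    # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
    # shape (log_probabilities): (batch_size, beam_size)
    all_top_k_predictions, log_probabilities = beam_search.search(
        start_predictions, state, take_step
    )

    # Keep only the highest-scoring sequence for each instance in the batch.
    return all_top_k_predictions[:, 0, :]
```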

However, to make this training process work well in practice, we need two additional techniques.

### Attention

In the process described above, the only information the decoder gets about the input sequence is the final state
of the encoder after it has processed the entire input. Decoding the whole output sequence from just this information can be very difficult.
To make it easier, `Seq2seq` models use a so-called
[attention](https://www.semanticscholar.org/paper/Sequence-to-Sequence-Learning-with-Neural-Networks-Sutskever-Vinyals/cea967b59209c6be22829699f05b8b1ac4dc092d)
mechanism that lets the decoder access a summary of the outputs of the encoder at each input token. The summary itself is computed
based on the current state of the decoder.

When using attention, Step 2b above is modified to instead decode a concatenation of the embedded output and a summary of the encoded source sequence.

The following is the `forward` method in the
[`LstmCellDecoderNet`](https://github.com/allenai/allennlp-models/blob/master/allennlp_models/generation/modules/decoder_nets/lstm_cell.py) class
(that we're using as our decoder cell), with the option of using attention.

```python
def forward(
    self,
    previous_state: Dict[str, torch.Tensor],
    encoder_outputs: torch.Tensor,
    source_mask: torch.BoolTensor,
    previous_steps_predictions: torch.Tensor,
    previous_steps_mask: Optional[torch.BoolTensor] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:

    decoder_hidden = previous_state["decoder_hidden"]
    decoder_context = previous_state["decoder_context"]

    # shape: (group_size, output_dim)
    last_predictions_embedding = previous_steps_predictions[:, -1]

    if self._attention:
        # shape: (group_size, encoder_output_dim)
        attended_input = self._prepare_attended_input(
            decoder_hidden, encoder_outputs, source_mask
        )

        # shape: (group_size, decoder_output_dim + target_embedding_dim)
        decoder_input = torch.cat((attended_input, last_predictions_embedding), -1)
    else:
        # shape: (group_size, target_embedding_dim)
        decoder_input = last_predictions_embedding

    # shape (decoder_hidden): (batch_size, decoder_output_dim)
    # shape (decoder_context): (batch_size, decoder_output_dim)
    decoder_hidden, decoder_context = self._decoder_cell(
        decoder_input.float(), (decoder_hidden.float(), decoder_context.float())
    )

    return (
        {"decoder_hidden": decoder_hidden, "decoder_context": decoder_context},
        decoder_hidden,
    )
```
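
The attended input above comes from `_prepare_attended_input`, which scores every encoder output against the current decoder hidden state and returns their weighted sum. Here is a minimal, generic sketch of that computation, assuming plain dot-product attention and matching encoder and decoder dimensions (the real class delegates the scoring to a configurable `Attention` module):

```python
import torch


def prepare_attended_input(
    decoder_hidden: torch.Tensor,   # shape: (group_size, decoder_output_dim)
    encoder_outputs: torch.Tensor,  # shape: (group_size, max_input_sequence_length, encoder_output_dim)
    source_mask: torch.BoolTensor,  # shape: (group_size, max_input_sequence_length)
) -> torch.Tensor:
    # Dot-product score between the decoder state and every encoder output
    # (this assumes decoder_output_dim == encoder_output_dim).
    # shape: (group_size, max_input_sequence_length)
    scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(-1)).squeeze(-1)

    # Mask out padding positions before normalizing.
    weights = torch.softmax(scores.masked_fill(~source_mask, float("-inf")), dim=-1)

    # Weighted sum of the encoder outputs: a fixed-size summary of the source
    # tailored to the current decoder state.
    # shape: (group_size, encoder_output_dim)
    return torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
```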

### Scheduled Sampling


The process described above is how the `Seq2seq` model is *trained*. At test time, the model gets only the source tokens and is expected to predict
the output tokens. This leads to an important change in the inputs to the decoder: whereas at training time we embed and decode the target tokens, at
test time we do not have access to the target tokens, so we can only embed and decode the predicted tokens (obtained from Step 2c).

This difference is important because it can cause the model to generalize poorly: if the decoder makes a mistake early in decoding, the rest of
the predicted sequence will be adversely affected, because the decoder is fed inputs it is not robust enough to handle. Forcing the model to only ever see the
targets at training time (often referred to as **teacher forcing**) is hence suboptimal. On the other hand, if we only used the predicted tokens as inputs at
training time, to match the test-time setup, the model might not learn anything at all due to the lack of sufficient supervision.

A solution is to use
[Scheduled Sampling](https://www.semanticscholar.org/paper/Scheduled-Sampling-for-Sequence-Prediction-with-Bengio-Vinyals/df137487e20ba7c6e1e2b9a1e749f2a578b5ad99),
where, at training time, we randomly use the predicted tokens as inputs with some probability and the target tokens the rest of the time.
This trick works well in practice. The scheduled sampling ratio (i.e., the probability of using predicted tokens as inputs) is a hyperparameter that needs to be tuned
for the task being modeled.
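
Distilled down, the per-step choice looks roughly like the sketch below (the ratio value is a made-up placeholder; the real implementation, shown next, also has to keep the embedded prediction history in sync):

```python
import torch


def choose_next_decoder_input(
    last_predictions: torch.Tensor,  # (batch_size,) the model's predictions from the previous step
    gold_targets: torch.Tensor,      # (batch_size,) the gold target tokens for the previous step
    scheduled_sampling_ratio: float = 0.25,  # hypothetical value; tune it for your task
) -> torch.Tensor:
    """Pick the next decoder input during training."""
    if torch.rand(1).item() < scheduled_sampling_ratio:
        # Feed the decoder its own prediction, so it learns to recover from mistakes.
        return last_predictions
    # Otherwise use the gold token, i.e., ordinary teacher forcing.
    return gold_targets
```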

The following is a snippet from our `AutoRegressiveSeqDecoder` class that shows how probabilities of output tokens are computed with and without scheduled sampling.

```python
for timestep in range(num_decoding_steps):
    if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
        # During training, with probability `_scheduled_sampling_ratio`, feed the
        # decoder its own previous predictions rather than the gold target tokens.
        # shape: (batch_size, steps, target_embedding_dim)
        state["previous_steps_predictions"] = steps_embeddings

        # shape: (batch_size, )
        effective_last_prediction = last_predictions
    else:
        # shape: (batch_size, )
        effective_last_prediction = targets[:, timestep]

        if timestep == 0:
            state["previous_steps_predictions"] = torch.Tensor([])
        else:
            # shape: (batch_size, steps, target_embedding_dim)
            state["previous_steps_predictions"] = target_embedding[:, :timestep]

    # shape: (batch_size, num_classes)
    output_projections, state = self._prepare_output_projections(
        effective_last_prediction, state
    )

    # list of tensors, shape: (batch_size, 1, num_classes)
    step_logits.append(output_projections.unsqueeze(1))

    # shape (predicted_classes): (batch_size,)
    _, predicted_classes = torch.max(output_projections, 1)

    # shape (predicted_classes): (batch_size,)
    last_predictions = predicted_classes

    # shape: (batch_size, 1, target_embedding_dim)
    last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze(1)

    # This step is required, since we want to keep two different prediction histories: gold and predicted.
    if steps_embeddings.shape[-1] == 0:
        # There are no previous steps yet, except for the start vectors in `last_predictions`.
        # shape: (group_size, 1, target_embedding_dim)
        steps_embeddings = last_predictions_embeddings
    else:
        # shape: (group_size, steps_count, target_embedding_dim)
        steps_embeddings = torch.cat([steps_embeddings, last_predictions_embeddings], 1)

# shape: (batch_size, num_decoding_steps, num_classes)
logits = torch.cat(step_logits, 1)
```

</exercise>

Expand All @@ -231,7 +487,7 @@ attributes of the model, dataset reader, and the trainer and locations of the tr
validation datasets. See the [chapter on configuration files](/using-config-files) of this guide
for more details. This is the configuration we'll use:

```json
{
    "dataset_reader": {
        "type": "seq2seq",
```

Note that at every step while producing the output sequence, the `Seq2Seq` model chooses between tokens like `(`, `add`, `7`, etc. Many of these options
are illegal according to our rules of Natural Language Arithmetic. For example, after producing a `(`, the model does not even need to consider the
option of producing a number, since any valid expression will only have an operator in that position. More generally, we can set explicit constraints on the model
to disallow illegal outputs. We will explore this direction in the next chapter.
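
One simple way to impose such constraints (an illustration, not what this chapter's model does) is to mask out the logits of illegal tokens before normalizing them, so the decoder can never select them:

```python
import torch
import torch.nn.functional as F


def constrained_log_probs(logits: torch.Tensor, allowed_mask: torch.BoolTensor) -> torch.Tensor:
    """
    logits: (batch_size, num_classes), unnormalized scores from the decoder.
    allowed_mask: (batch_size, num_classes), True for tokens that are legal at
    this step (e.g., only operators such as `add` may follow an opening "(").
    """
    masked_logits = logits.masked_fill(~allowed_mask, float("-inf"))
    return F.log_softmax(masked_logits, dim=-1)
```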

</exercise>


<exercise id="7" title="Further reading">

In [section 3](#3) we gave links to a series of papers that used standard translation techniques at
the time to approach semantic parsing problems. There are a lot of variations on standard `Seq2seq`
approaches for semantic parsing, including [recursive or two-stage
generation](https://www.aclweb.org/anthology/P18-1068/) and [grammar-based
decoding](/semantic-parsing-grammar). AllenNLP has strong support for grammar-based decoding, which we will explore in the next chapter.