https://bastings.github.io/annotated_encoder_decoder/?utm_campaign=NLP%20News&utm_medium=email&utm_source=Revue%20newsletter

In [1]:
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from IPython.core.debugger import set_trace

In [2]:
USE_CUDA = torch.cuda.is_available()
DEVICE=torch.device('cuda:0') # or set to 'cpu'
print("CUDA:", USE_CUDA)
print(DEVICE)

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

CUDA: True
cuda:0


In [3]:
class EncoderDecoder(nn.Module):
    """
    Standard Encoder-Decoder
    """
    def __init__(self, encoder, decoder, src_embed, trg_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.generator = generator
        
    def forward(self, src, trg, src_mask, trg_mask, src_lengths, trg_lengths):
        encoder_hidden, encoder_final = self.encode(src, src_lengths)
        return self.decode(encoder_hidden, encoder_final, src_mask, trg, trg_mask)
    
    def encode(self, src, src_lengths):
        return self.encoder(self.src_embed(src), src_lengths)
    
    def decode(self, encoder_hidden, encoder_final, src_mask, trg, trg_mask,
              decoder_hidden=None):
        return self.decoder(self.trg_embed(trg), encoder_hidden, encoder_final,
                           src_mask, trg_mask, hidden=decoder_hidden)
    
class Generator(nn.Module):
    """Define standard linear + softmax generation step"""
    def __init__(self, hidden_size, vocab_size):
        super(Generator, self).__init__()
        self.proj = nn.Linear(hidden_size, vocab_size, bias=False)
        
    def forward(self, x):
        return F.log_softmax(self.proj(x), -1)

## Encoder

Use bi-direction GRU

For efficiency, we need to support mini-batches. Sentences may have different lengths, so unroll differently. 

pack_padded_sequence - Packs a Tensor containing padded sequences of variable length
pad_packed_sequence - undoes the pack_padded_sequence

In [4]:
class Encoder(nn.Module):
    """encodes a sequence of word embeddings"""
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.rnn = nn.GRU(input_size, hidden_size, num_layers,
                         batch_first=True, bidirectional=True, dropout=dropout)
        
    def forward(self, x, lengths):
        """
        Applies a bi-directional GRU to sequence of embeddings X
        The input mini-batch x needs to be sorted by length
        x should have dimensions [batch, seq_length, input_size]
        """
        packed = pack_padded_sequence(x, lengths, batch_first=True)
        # final shape: (num_layers * num_directions, batch, hidden_size)
        # the first dimension will be a multiple of 2 when bi-directional
        # for example, with num_layers = 2, then first dimension is 4
        # the 0th and 2st rows and the final forwards and
        # the 1st and 3rd rows the final backwards 
        hidden, final = self.rnn(packed)
        # also returns the lengths
        # shape of hidden: (batch, seq_length, num_directions*hidden size)
        hidden, _ = pad_packed_sequence(hidden, batch_first=True)
        
        #get all final forwards
        fwd_final = final[0:final.size(0):2]
        #get all final backwards
        bwd_final = final[1:final.size(0):2]
        # shape of final: (num layers, batch, 2*hidden_size)
        final = torch.cat([fwd_final, bwd_final], dim=2)
        
        return hidden, final

In [5]:
encoder = Encoder(1, 5)

In [6]:
t = torch.from_numpy(np.array([[[4], [2]], [[3], [1]], [[2], [5]]])).type(torch.FloatTensor)

In [7]:
# 3 sentences of 2 words each

print(t.shape)

torch.Size([3, 2, 1])


In [8]:
output, final = encoder(t, torch.from_numpy(np.array([2,2,2])))

In [9]:
output.shape

torch.Size([3, 2, 10])

In [10]:
final.shape

torch.Size([1, 3, 10])

In [12]:
final[-1].unsqueeze(1).shape

torch.Size([3, 1, 10])

## Decoder

We will always use teacher forcing

Encoder Final is the last hidden state used to initalize first hidden state of decoder

In [14]:
class Decoder(nn.Module):
    """A conditional RNN decoder with attention."""
    
    def __init__(self, emb_size, hidden_size, attention, num_layers=1, dropout=0.5,
                 bridge=True):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.attention = attention
        self.dropout = dropout
        
        # will concat the previous word embedding with the context
        # where the context is the weighted sum of the encoder hidden states
        # which has a size of 2 * hidden
        self.rnn = nn.GRU(emb_size + 2*hidden_size, hidden_size, num_layers,
                         batch_first=True, dropout=dropout)
        # to initalize from the final encoder state
        # need to project because goes from bi-directional size to one direction
        self.bridge = nn.Linear(2*hidden_size, hidden_size, bias=True) if bridge else None
        self.dropout_layer = nn.Dropout(p=dropout)
        self.pre_output_layer = nn.Linear(hidden_size + 2*hidden_size + emb_size,
                                         hidden_size, bias=False)
        
    def forward_step(self, prev_embed, encoder_hidden, src_mask, proj_key, hidden):
        """Preform a single decoder step - 1 word"""
        
        # this is the hidden state for the decoder
        query = hidden[-1].unsqueeze(1) # takes the last layer. shape: [batch, 1, hidden_dimension]
        context, _ = self.attention(query=query, proj_key=proj_key,
                                            value=encoder_hidden, mask=src_mask)
        
        # process the concat of the prev_embedding and the context (attention)
        rnn_input = torch.cat([prev_embed, context], dim=2)
        output, hidden = self.rnn(rnn_input, hidden)
        
        # use the prev_embed, output from GRU, and context to get final vector of hidden_size
        # which will be used to generated probabilities across our vocabulary
        pre_output = torch.cat([prev_embed, output, context], dim=2)
        pre_output = self.dropout_layer(pre_output)
        pre_output = self.pre_output_layer(pre_output)
        
        return output, hidden, pre_output
    
    def forward(self, trg_embed, encoder_hidden, encoder_final, src_mask, trg_mask,
               hidden=None, max_len=None):
        """Unroll the encoder 1 step at a time"""
        
        if max_len is None:
            max_len = trg_mask.size(-1)
            
        if hidden is None:
            hidden = self.init_hidden(encoder_final)
            
        # pre-compute projected encoder hidden states
        # (the "keys" for the attention mechanism)
        # this is only done for efficiency
        # projects them to hidden_dimension size
        proj_key = self.attention.key_layer(encoder_hidden)
        
        decoder_states = []
        pre_output_vectors = []
        
        for i in range(max_len):
            prev_embed = trg_embed[:, i].unsqueeze(1)
            output, hidden, pre_output = self.forward_step(prev_embed,
                                                          encoder_hidden,
                                                          src_mask, proj_key, hidden)
            decoder_states.append(output)
            pre_output_vectors.append(pre_output)
            
        decoder_states = torch.cat(decoder_states, dim=1)
        pre_output_vectors = torch.cat(pre_output_vectors, dim=1)
        return decoder_states, hidden, pre_output_vectors # [B,N,D]

    def init_hidden(self, encoder_final):
        """Returns the intial decoder state, conditioned on final encoder state"""
        if encoder_final is None:
            return None #start with zeros
        else:
            return torch.tanh(self.bridge(encoder_final))

## Attention

At every time step, decoder has **ALL** the source word hidden states. Attention allows it to learn which are the most relevant. 

The state of the decoder is represented by the hidden state.

We will use an MLP-based, additive attention with tanh activation.

**Decoder state:** is the query
**The encoder states:** the key

We add the query to all of the keys and pass them through a tanh function then project them to a single number (the **energy**).

We then mask out invalid positions and apply softmax to get probability distribution across the words.

We then take a weighted sum of the encoder hidden states, where the weights are the probabilities. This is the **context.**

In [78]:
class BahdanauAttention(nn.Module):
    """Implements Bahdanau (MLP) attention"""
    
    def __init__(self, hidden_size, key_size=None, query_size=None):
        super(BahdanauAttention, self).__init__()
        
        # assume bi-directional encoder 
        key_size = 2 * hidden_size if key_size is None else key_size
        query_size = hidden_size if query_size is None else query_size
        
        self.key_layer = nn.Linear(key_size, hidden_size, bias=False)
        self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
        self.energy_layer = nn.Linear(hidden_size, 1, bias=False)
        
        self.alphas = None
        
    def forward(self, query=None, proj_key=None, value=None, mask=None):
        #assert mask is not None, "mask is required"
        
        # project the query to hidden_size (this already done for keys)
        query = self.query_layer(query)
        
        # Calculate energies
        # pytorch broadcasts the query to all the keys with the addition
        # shape batch x seq_length x 1
        scores = self.energy_layer(torch.tanh(query + proj_key))
        # batch x 1 x seq_length
        scores = scores.squeeze(2).unsqueeze(1)
        
        # Mask out invalid positions.
        # The mask marks valid positions so we invert it using `mask & 0`.
        # the mask is broadcastable b/c it is of size seq_length
        # see: https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
        # the softmax will drop all negative infinity values
        scores.data.masked_fill_(mask == 0, -float('inf'))
        
        # turn scores into probabilities
        alphas = F.softmax(scores, dim=-1)
        self.alphas = alphas
        
        # context is weighted sum of the values (the original encoder hidden states)
        context = torch.bmm(alphas, value)
        
        # context shape: batch x 1 x key size
        # alpha shape: batch x 1 x mask length
        return context, alphas

## to do
need to figure out how masking is working and the value in attention. The should be the same dimensions for bmm to work...


In [79]:
attn = BahdanauAttention(5)
keys = attn.key_layer(output)

In [80]:
query = torch.FloatTensor([1,2,3,4,5])

In [81]:
attn(query, keys)

torch.Size([3, 2, 1])
torch.Size([3, 1, 2])


In [73]:
keys.shape

torch.Size([3, 2, 5])

In [82]:
F.softmax(torch.Tensor([0, float('-inf')]), -1)

tensor([1., 0.])