https://bastings.github.io/annotated_encoder_decoder/?utm_campaign=NLP%20News&utm_medium=email&utm_source=Revue%20newsletter

In [2]:
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
import matplotlib.pyplot as plt
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from IPython.core.debugger import set_trace

In [3]:
# Detect GPU support once and derive the device from it, so that the notebook
# also runs on CPU-only machines instead of failing on the first `.to(DEVICE)`.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda:0' if USE_CUDA else 'cpu')
print("CUDA:", USE_CUDA)
print(DEVICE)

# Seed every random number generator in use so that runs are reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
# Safe to call without a GPU: PyTorch silently ignores it when CUDA is unavailable.
torch.cuda.manual_seed(seed)

CUDA: True
cuda:0


In [6]:
class EncoderDecoder(nn.Module):
    """
    Container that wires an encoder and a decoder together.

    It stores the encoder, the decoder, the source/target embedding layers
    and the generator. The generator is only stored here: neither `forward`
    nor `decode` applies it to the decoder output.
    """

    def __init__(self, encoder, decoder, src_embed, trg_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.generator = generator

    def forward(self, src, trg, src_mask, trg_mask, src_lengths, trg_lengths):
        """Encode `src`, then decode `trg` conditioned on the encoder result.

        `trg_lengths` is accepted for interface compatibility but is not used.
        """
        enc_states, enc_last = self.encode(src, src_lengths)
        return self.decode(enc_states, enc_last, src_mask, trg, trg_mask)

    def encode(self, src, src_lengths):
        """Embed `src` and run the encoder on it; returns the encoder's result."""
        src_embedded = self.src_embed(src)
        return self.encoder(src_embedded, src_lengths)

    def decode(self, encoder_hidden, encoder_final, src_mask, trg, trg_mask,
               decoder_hidden=None):
        """Embed `trg` and run the decoder on it.

        `decoder_hidden` is forwarded to the decoder as its `hidden` keyword
        argument.
        """
        trg_embedded = self.trg_embed(trg)
        return self.decoder(
            trg_embedded,
            encoder_hidden,
            encoder_final,
            src_mask,
            trg_mask,
            hidden=decoder_hidden,
        )
    
class Generator(nn.Module):
    """Project hidden states onto the vocabulary and return log-probabilities."""

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        # bias-free linear map from hidden_size features to vocab_size scores
        self.proj = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        """Return the log-softmax, over the last dimension, of the projected `x`."""
        scores = self.proj(x)
        return F.log_softmax(scores, dim=-1)

## Encoder

We use a bi-directional GRU as the encoder.

For efficiency, we need to support mini-batches. Sentences in a mini-batch may have different lengths, so the RNN has to be unrolled for a different number of steps for each sentence.

pack_padded_sequence - packs a Tensor containing padded sequences of variable length, so that the RNN skips the padding positions
pad_packed_sequence - undoes pack_padded_sequence: turns a packed sequence back into a padded Tensor

In [97]:
class Encoder(nn.Module):
    """Encodes a sequence of word embeddings with a bi-directional GRU."""

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.):
        super().__init__()
        self.num_layers = num_layers
        self.rnn = nn.GRU(input_size, hidden_size, num_layers,
                          batch_first=True, bidirectional=True, dropout=dropout)

    def forward(self, x, lengths):
        """
        Applies a bi-directional GRU to a sequence of embeddings x.

        x should have dimensions [batch, seq_length, input_size].
        lengths holds the true (unpadded) length of every sequence in x.
        The mini-batch does NOT need to be sorted by length.

        Returns a tuple (hidden, final):
          hidden: per-step outputs of the GRU, padded to the longest
                  sequence in the batch,
                  shape (batch, max(lengths), 2 * hidden_size)
          final:  the final forward and backward states of each layer,
                  concatenated, shape (num_layers, batch, 2 * hidden_size)
        """
        # enforce_sorted=False lets PyTorch sort the batch internally and
        # restore the original order afterwards; for an already sorted batch
        # the result is the same as with the default enforce_sorted=True.
        packed = pack_padded_sequence(x, lengths, batch_first=True,
                                      enforce_sorted=False)
        # shape of final: (num_layers * num_directions, batch, hidden_size)
        # the first dimension is a multiple of 2 because the GRU is
        # bi-directional; for example, with num_layers = 2 it is 4:
        # the 0th and 2nd rows are the final forward states and
        # the 1st and 3rd rows are the final backward states
        hidden, final = self.rnn(packed)
        # pad_packed_sequence also returns the lengths, which we do not need
        # shape of hidden: (batch, seq_length, num_directions * hidden_size)
        hidden, _ = pad_packed_sequence(hidden, batch_first=True)

        # even rows: final forward states of every layer
        fwd_final = final[0::2]
        # odd rows: final backward states of every layer
        bwd_final = final[1::2]
        # shape of final: (num_layers, batch, 2 * hidden_size)
        final = torch.cat([fwd_final, bwd_final], dim=2)

        return hidden, final

In [98]:
# toy encoder: 1-dimensional inputs, 5 hidden units per direction
encoder = Encoder(input_size=1, hidden_size=5)

In [99]:
# toy input batch, built directly as a float32 tensor
t = torch.tensor([[[4], [2]], [[3], [1]], [[2], [5]]], dtype=torch.float32)

In [100]:
# t is a batch of 3 sentences of 2 words each; every word is a 1-dimensional "embedding"

print(t.shape)

torch.Size([3, 2, 1])


In [101]:
# every sentence in the toy batch has its full length of 2 words
output, final = encoder(t, torch.tensor([2, 2, 2]))

In [102]:
# (batch, seq_length, 2 * hidden_size): forward and backward GRU outputs are concatenated
output.shape

torch.Size([3, 2, 10])

## Decoder

We will always use teacher forcing

The encoder's final state (`encoder_final`) is its last hidden state; it is used to initialize the first hidden state of the decoder.

In [76]:
class Decoder(nn.Module):
    """A conditional RNN decoder with attention."""
    
    def __init__(self, emb_size, hidden_size, attention, num_layers=1, dropout=0.5,
                 bridge=True):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.attention = attention
        self.dropout = dropout
        
        # will concat 
        self.rnn = nn.GRU(emb_size + 2*hidden_size)