## TODO:
- Add GloVe word vectors
- Add Xavier initialisation for all the weights of the networks
- Look up how to structure the auxiliary function
- Update the training function to include all the parts of the model
- Find out what KL annealing is (gradually increasing the weight of the KL term during training) and whether to use it

In [11]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import os

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [17]:
# Maximum sequence length. It is the default `maxLength` of AttentionDecoder
# (the number of attention weights it produces) and of train() (the number of
# rows in its encoder/backward output buffers).
max_length = 100

In [18]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """
    VAE loss: binary cross-entropy reconstruction term plus the KL divergence
    between q(z|x) = N(mu, exp(logvar)) and a standard normal prior.

    Both terms are summed over all elements and the batch.

    Args:
        recon_x: reconstructed probabilities in [0, 1].
        x: target with the same number of elements as ``recon_x``; it is
            reshaped to ``recon_x``'s shape (e.g. ``(bs, 1, 28, 28)`` images
            against ``(bs, 784)`` reconstructions).
        mu: mean of the approximate posterior.
        logvar: log-variance of the approximate posterior.

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    # reduction="sum" keeps the reconstruction term on the same (summed) scale
    # as the KL term; the default "mean" would divide it by the element count.
    BCE = F.binary_cross_entropy(recon_x, x.reshape(recon_x.shape), reduction="sum")
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [6]:
class Encoder(nn.Module):
    """
    This'll be a bi-directional GRU.
    Utilises equation (1) in the paper.

    Each call to ``forward`` consumes one token index: it is embedded,
    reshaped to a (seq_len=1, batch=1, hiddenSize) input and run through one
    GRU step.

    NOTE(review): because the GRU is bidirectional, each step's output is
    2 * hiddenSize wide and the hidden state has one slice per direction.
    Callers that store outputs in a hiddenSize-wide buffer, or that hand the
    hidden state to a unidirectional decoder, must merge the two directions.

    The hidden size is 512 as per the paper.
    """
    def __init__(self, inputSize, hiddenSize=512):
        super(Encoder, self).__init__()
        self.hiddenSize = hiddenSize
        # Lookup table from each of the `inputSize` token indices to a learned
        # `hiddenSize`-dimensional embedding; the input is a tensor of
        # indices and the output is the corresponding embeddings.
        self.embedding = nn.Embedding(inputSize, hiddenSize)
        self.gru = nn.GRU(hiddenSize, hiddenSize, bidirectional=True)

    def forward(self, x, hidden):
        """
        Run one GRU step on token ``x``.

        Returns:
            (output, hidden): ``output`` is (1, 1, 2 * hiddenSize), with the
            forward and backward directions concatenated; ``hidden`` has the
            same shape as the one returned by ``initHidden``.
        """
        # load the input into the embedding before doing GRU computation.
        output = self.embedding(x).view(1, 1, -1)
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        """Return a zero initial state of shape (num_layers * num_directions, 1, hiddenSize)."""
        # nn.GRU needs one initial state per layer and direction; a
        # (1, 1, hiddenSize) tensor is rejected for a bidirectional GRU.
        numStates = self.gru.num_layers * (2 if self.gru.bidirectional else 1)
        # Allocate on the same device as the parameters, so the state follows
        # the model after `.to(...)`.
        return torch.zeros(numStates, 1, self.hiddenSize,
                           device=self.embedding.weight.device)

- Calculate a set of attention weights.
- Multiply the attention weights by the encoder output vectors (a batched matrix product), giving a weighted sum of the encoder outputs.
- This result should contain information about the part of the input sequence being attended to, which helps the decoder choose the right words. We store it in a variable called `attentionApplied`.

In [20]:
class AttentionDecoder(nn.Module):
    """
    Single-step GRU decoder attending over a fixed-size buffer of encoder outputs.

    One call consumes one target token. The token embedding and the current
    hidden state are concatenated and mapped to ``maxLength`` attention
    scores. Their softmax weights the rows of ``encoderOutputs``, and that
    weighted sum is merged with the embedding. The result goes through ReLU
    into the GRU, whose output is projected to log-probabilities over
    ``outputSize`` tokens.

    TODO: Add layer normalisation?
    https://arxiv.org/abs/1607.06450
    """
    def __init__(self, hiddenSize, outputSize, maxLength = max_length):
        """
        Args:
            hiddenSize: width of the embeddings, the GRU state and each
                encoder output row.
            outputSize: number of target tokens (embedding rows and output
                classes).
            maxLength: number of encoder positions attended over.

        Dropout is omitted.
        """
        super(AttentionDecoder, self).__init__()
        self.hiddenSize, self.outputSize, self.maxLength = hiddenSize, outputSize, maxLength

        # Keep this construction order: it fixes the order of random
        # initialisation and of parameters().
        self.embedding = nn.Embedding(outputSize, hiddenSize)
        # scores the maxLength encoder positions from [embedding; hidden]
        self.attention = nn.Linear(2 * hiddenSize, maxLength)
        # merges [embedding; attended encoder output] back to hiddenSize
        self.attentionCombined = nn.Linear(2 * hiddenSize, hiddenSize)
        self.gru = nn.GRU(hiddenSize, hiddenSize)
        self.out = nn.Linear(hiddenSize, outputSize)

    def forward(self, input, hidden, previousHidden, encoderOutputs):
        """
        Decode one step.

        Args:
            input: index of the current target token.
            hidden: GRU hidden state; ``hidden[0]`` is concatenated with the
                embedding to score the encoder positions.
            previousHidden: accepted but not used.
            encoderOutputs: (maxLength, hiddenSize) buffer of encoder outputs.

        Returns:
            (log-probabilities of shape (1, outputSize), new hidden state,
            attention weights of shape (1, maxLength)).
        """
        step = self.embedding(input).view(1, 1, -1)[0]

        scores = self.attention(torch.cat((step, hidden[0]), 1))
        attentionWeights = F.softmax(scores, dim=1)

        attended = torch.bmm(attentionWeights.unsqueeze(0),
                             encoderOutputs.unsqueeze(0))[0]

        merged = self.attentionCombined(torch.cat((step, attended), 1))
        gruInput = F.relu(merged.unsqueeze(0))
        gruOutput, hidden = self.gru(gruInput, hidden)

        logProbs = F.log_softmax(self.out(gruOutput[0]), dim=1)
        return logProbs, hidden, attentionWeights

    def initHidden(self):
        """Return an all-zero (1, 1, hiddenSize) initial hidden state on `device`."""
        shape = (1, 1, self.hiddenSize)
        return torch.zeros(shape, device=device)

In [21]:
class Inference(nn.Module):
    """
    Inference network: a Gaussian over the latent variable z conditioned on
    the input features, the conditioning vector ``c`` and, optionally, a
    backward hidden state ``h_backward``.

    ``encode`` maps the inputs to a mean and log-variance. ``reparametrize``
    samples z from them, returning the mean in eval mode. ``decode`` maps
    [z; c] back to ``feature_size`` values in (0, 1).

    The latent size is 400 as per the paper.

    Args:
        feature_size: width of ``x`` / ``x_forward`` and of the decoded output.
        class_size: width of ``c``.
        latent_size: width of z.
        backward_size: width of the optional ``h_backward`` passed to
            ``encode``/``forward``. The default 0 gives the layer sizes of a
            network without a backward input.
    """
    def __init__(self, feature_size, class_size, latent_size=400, backward_size=0):
        super(Inference, self).__init__()

        self.feature_size = feature_size
        self.class_size = class_size
        self.latent_size = latent_size
        self.backward_size = backward_size

        # encode
        self.fc1 = nn.Linear(feature_size + class_size + backward_size, 400)
        self.mean = nn.Linear(400, latent_size)
        self.var = nn.Linear(400, latent_size)  # produces the log-variance

        # decode
        self.fc3 = nn.Linear(latent_size + class_size, 400)
        self.fc4 = nn.Linear(400, feature_size)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x_forward, c, h_backward=None): # Q(z|x, c)
        '''
        x_forward: (bs, feature_size)
        c: (bs, class_size)
        h_backward: optional (bs, backward_size)

        Returns (mu, logvar), each (bs, latent_size).
        '''
        parts = [x_forward, c]
        if h_backward is not None:
            parts.append(h_backward)
        inputs = torch.cat(parts, 1)
        h1 = self.relu(self.fc1(inputs))
        z_mu = self.mean(h1)
        z_logvar = self.var(h1)
        return z_mu, z_logvar

    def reparametrize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) while training; return mu in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like replaces the deprecated, never-imported `Variable`
        # wrapper and matches std's shape, dtype and device.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c): # P(x|z, c)
        '''
        z: (bs, latent_size)
        c: (bs, class_size)
        '''
        inputs = torch.cat([z, c], 1) # (bs, latent_size+class_size)
        h3 = self.relu(self.fc3(inputs))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x, c, h_backward=None):
        """Encode, sample z and decode; returns (reconstruction, mu, logvar)."""
        # Flatten to feature_size instead of a hard-coded 28*28.
        mu, logvar = self.encode(x.reshape(-1, self.feature_size), c, h_backward)
        z = self.reparametrize(mu, logvar)
        return self.decode(z, c), mu, logvar

In [22]:
class Prior(nn.Module):
    """
    Prior network: a Gaussian over the latent variable z conditioned on ``h``
    and the conditioning vector ``c``.

    ``encode`` maps [h; c] to a mean and log-variance. ``reparametrize``
    samples z from them, returning the mean in eval mode. ``decode`` maps
    [z; c] back to ``feature_size`` values in (0, 1).

    Args:
        feature_size: width of ``h`` / ``x`` and of the decoded output.
        class_size: width of ``c``.
        latent_size: width of z. Previously this name was undefined here, so
            construction raised NameError. It defaults to 400, the same
            default as Inference.
    """
    def __init__(self, feature_size, class_size, latent_size=400):
        super(Prior, self).__init__()

        self.feature_size = feature_size
        self.class_size = class_size
        self.latent_size = latent_size

        # encode
        self.fc1 = nn.Linear(feature_size + class_size, 400)
        self.mean = nn.Linear(400, latent_size)
        self.var = nn.Linear(400, latent_size)  # produces the log-variance

        # decode
        self.fc3 = nn.Linear(latent_size + class_size, 400)
        self.fc4 = nn.Linear(400, feature_size)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, h, c): # p(z|h, c)
        '''
        h: (bs, feature_size)
        c: (bs, class_size)

        Returns (mu, logvar), each (bs, latent_size).
        '''
        inputs = torch.cat([h, c], 1) # (bs, feature_size+class_size)
        h1 = self.relu(self.fc1(inputs))
        z_mu = self.mean(h1)
        z_logvar = self.var(h1)
        return z_mu, z_logvar

    def reparametrize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) while training; return mu in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like replaces the deprecated, never-imported `Variable`
        # wrapper and matches std's shape, dtype and device.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c): # P(x|z, c)
        '''
        z: (bs, latent_size)
        c: (bs, class_size)
        '''
        inputs = torch.cat([z, c], 1) # (bs, latent_size+class_size)
        h3 = self.relu(self.fc3(inputs))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x, c):
        """Encode, sample z and decode; returns (reconstruction, mu, logvar)."""
        # Flatten to feature_size instead of a hard-coded 28*28.
        mu, logvar = self.encode(x.reshape(-1, self.feature_size), c)
        z = self.reparametrize(mu, logvar)
        return self.decode(z, c), mu, logvar

In [None]:
class Auxillary(nn.Module):
    """
    Auxiliary network for the Sequential Bag-of-Words (SBOW) objective.

    Work in progress: the layers are defined but ``forward`` is not yet
    implemented.
    """
    def __init__(self, latent_size):
        # nn.Module.__init__ must run before any sub-module is assigned;
        # without it `self.fc1 = ...` raises AttributeError.
        super(Auxillary, self).__init__()
        self.fc1 = nn.Linear(latent_size, 400)
        self.relu = nn.ReLU()
        # Explicit dim: nn.Softmax() without one is deprecated and chooses the
        # dimension implicitly. Assumes the last axis holds the word scores
        # (TODO confirm once forward is written).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z):
        """
        The motive here is to produce an auxiliary loss for our
        training objective.

        We do this by Sequential Bag of Words (SBOW) as the
        auxiliary objective for the proposed VAD model.

        We want to predict the bag of succeeding words
        in the response using the latent variable z at each
        time step.

        Raises:
            NotImplementedError: always, until the SBOW head is written.
                This replaces silently returning None.
        """
        raise NotImplementedError(
            "Auxillary.forward (SBOW auxiliary objective) is not implemented yet")

In [19]:
# Probability that a given call to train() uses teacher forcing.
teacherForcingRatio = 0.5

def train(x,
          y,
          encoder,
          decoder,
          backwards,
          inference,
          prior,
          encoderOptimiser,
          decoderOptimiser,
          backwardsOptimiser,
          inferenceOptimiser,
          priorOptimiser,
          maxLength = max_length):
    """
    Run one optimisation step on a single (x, y) pair.

    Work in progress: ``inference`` and ``prior`` are not yet part of the
    objective (there is no KL term), so the loss is only the decoder's
    token-level negative log-likelihood. All five optimisers are zeroed and
    stepped, so each component is updated once it contributes gradients.

    Args:
        x: source token indices, consumed one element at a time along dim 0.
        y: target token indices, consumed one element at a time along dim 0.
        encoder, decoder, backwards, inference, prior: model components.
        encoderOptimiser, decoderOptimiser, backwardsOptimiser,
        inferenceOptimiser, priorOptimiser: one optimiser per component.
        maxLength: rows in the encoder/backward output buffers; longer
            sequences overflow them (IndexError).

    Returns:
        float: summed loss divided by ``len(y)``; 0.0 for an empty target.
    """
    optimisers = (encoderOptimiser, decoderOptimiser, backwardsOptimiser,
                  inferenceOptimiser, priorOptimiser)

    # initialise hidden variables
    encoderHidden = encoder.initHidden()
    backwardsHidden = backwards.initHidden()

    # initialise gradients (IMPORTANT!)
    for optimiser in optimisers:
        optimiser.zero_grad()

    inputLength = x.size(0)
    targetLength = y.size(0)
    if targetLength == 0:
        # Nothing to decode: no loss tensor is built, so nothing to backpropagate.
        return 0.0

    # NOTE(review): these buffers are encoder.hiddenSize wide. A bidirectional
    # encoder/backwards RNN emits 2 * hiddenSize features per step and a
    # (num_directions, 1, H) hidden state (which is handed to the decoder
    # below), so the directions must be merged before this runs end to end.
    encoderOutputs = torch.zeros(maxLength, encoder.hiddenSize, device=device)
    backwardOutputs = torch.zeros(maxLength, encoder.hiddenSize, device=device)

    for ei in range(inputLength):
        encoderOutput, encoderHidden = encoder(x[ei], encoderHidden)
        encoderOutputs[ei] = encoderOutput[0, 0]

    # Backwards RNN over y, run from the end of the sequence towards the start,
    # to condition the latent variable. backwardOutputs[t] is the output after
    # consuming y[targetLength - 1], ..., y[t + 1], i.e. the tokens that
    # succeed position t. The loop starts at targetLength - 2 so that y[t + 1]
    # exists, and steps by -1. The last row stays zero.
    # backwardOutputs is not consumed yet (inference is not wired in).
    for t in range(targetLength - 2, -1, -1):
        backwardOutput, backwardsHidden = backwards(y[t + 1], backwardsHidden)
        backwardOutputs[t] = backwardOutput[0, 0]

    # set up decoder variables
    decoderInput = torch.tensor([[SOS_token]], device=device)
    decoderHidden = encoderHidden

    useTeacherForcing = random.random() < teacherForcingRatio

    loss = 0
    for di in range(targetLength):
        # AttentionDecoder.forward takes (input, hidden, previousHidden,
        # encoderOutputs); the current hidden state fills the previousHidden
        # slot, which that forward does not read.
        decoderOutput, decoderHidden, _attention = decoder(
            decoderInput, decoderHidden, decoderHidden, encoderOutputs)
        # AttentionDecoder returns log_softmax outputs, so the matching
        # reconstruction loss is the negative log-likelihood.
        loss = loss + F.nll_loss(decoderOutput, y[di])

        if useTeacherForcing:
            # teacher forcing: feed the target as the next input.
            decoderInput = y[di]
        else:
            # no teacher forcing: feed back the prediction, detached from history.
            _, topIndex = decoderOutput.topk(1)
            decoderInput = topIndex.squeeze().detach()
            # if we found `<EOS>` at this iteration, then stop decoding.
            if decoderInput.item() == EOS_token:
                break

    loss.backward()

    for optimiser in optimisers:
        optimiser.step()

    return loss.item() / targetLength

In [None]:
def trainIters(encoder, decoder, iterations, printEvery=1000, plotEvery=100, learningRate=0.01,
               backwards=None, inference=None, prior=None):
    """
    Train for ``iterations`` randomly sampled pairs with one SGD optimiser per
    model component, printing and plotting running average losses.

    Args:
        encoder, decoder: model components.
        iterations: number of training pairs (and train() calls).
        printEvery: print the average loss every this many iterations.
        plotEvery: record the average loss for plotting every this many
            iterations.
        learningRate: SGD learning rate for every component.
        backwards, inference, prior: remaining model components required by
            train(). They are keyword arguments so existing positional calls
            keep working.

    Raises:
        TypeError: if backwards, inference or prior is not supplied.
    """
    # `time` is not among the notebook's top-level imports.
    import time

    if backwards is None or inference is None or prior is None:
        raise TypeError(
            "trainIters requires backwards, inference and prior modules: "
            "train() needs all five model components and their optimisers")

    # store statistics so we can use them to show progress.
    start = time.time()
    plotLosses = []
    printLossTotal = 0
    plotLossTotal = 0

    # setup optimisers (one per component, as train() expects)
    def makeOptimiser(module):
        return optim.SGD(module.parameters(), lr=learningRate)

    encoderOptimiser = makeOptimiser(encoder)
    decoderOptimiser = makeOptimiser(decoder)
    backwardsOptimiser = makeOptimiser(backwards)
    inferenceOptimiser = makeOptimiser(inference)
    priorOptimiser = makeOptimiser(prior)

    trainingPairs = [tensorsFromPair(random.choice(pairs)) for _ in range(iterations)]

    for i, trainingPair in enumerate(trainingPairs, start=1):
        # set up variables needed for training.
        x, y = trainingPair[0], trainingPair[1]
        # calculate loss; arguments follow train()'s positional order.
        loss = train(x, y, encoder, decoder, backwards, inference, prior,
                     encoderOptimiser, decoderOptimiser, backwardsOptimiser,
                     inferenceOptimiser, priorOptimiser)
        # increment our print and plot.
        printLossTotal += loss
        plotLossTotal += loss

        # print mechanism
        if i % printEvery == 0:
            printLossAvg = printLossTotal / printEvery
            # reset the print loss.
            printLossTotal = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, i / iterations),
                                         i, i / iterations * 100, printLossAvg))
        # plot mechanism
        if i % plotEvery == 0:
            plotLossAvg = plotLossTotal / plotEvery
            plotLosses.append(plotLossAvg)
            plotLossTotal = 0

    showPlot(plotLosses)