
In [0]:
%matplotlib inline

import os
import re
import urllib.request
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import itertools

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In this notebook you will work with a deep generative language model that generates words from a discrete (bit-vector-valued) latent space. We will work with Spanish text data at the character level, using PyTorch.

The first section concerns data manipulation and data loading classes necessary for our implementation. You do not need to modify anything in this part of the code.

Let's first download the SIGMORPHON dataset that we will be using for this notebook: these are inflected Spanish words together with some morphosyntactic descriptors. For this notebook we will ignore the morphosyntactic descriptors.

In [0]:
url = "https://raw.githubusercontent.com/ryancotterell/sigmorphon2016/master/data/"
train_file = "spanish-task1-train"
val_file = "spanish-task1-dev"
test_file = "spanish-task1-test"

print("Downloading data files...")
if not os.path.isfile(train_file):
    urllib.request.urlretrieve(url + train_file, filename=train_file)
if not os.path.isfile(val_file):
    urllib.request.urlretrieve(url + val_file, filename=val_file)
if not os.path.isfile(test_file):
    urllib.request.urlretrieve(url + test_file, filename=test_file)
print("Download complete.")

Downloading data files...
Download complete.


# Data

In order to work with text data, we need to transform the text into something that our algorithms can work with. The first step of this process is converting each word into a sequence of character ids. We do this by constructing a vocabulary (here, a character set) from the data, assigning a new id to each new character we encounter.

In [0]:
UNK_TOKEN = "?"
PAD_TOKEN = "_"
SOW_TOKEN = ">"
EOW_TOKEN = "."

def extract_inflected_word(s):
    """
    Extracts the inflected words in the SIGMORPHON dataset.
    """
    return s.split()[-1]

class Vocabulary:
    
    def __init__(self):
        self.idx_to_char = {0: UNK_TOKEN, 1: PAD_TOKEN, 2: SOW_TOKEN, 3: EOW_TOKEN}
        self.char_to_idx = {UNK_TOKEN: 0, PAD_TOKEN: 1, SOW_TOKEN: 2, EOW_TOKEN: 3}
        self.word_freqs = {}
    
    def __getitem__(self, key):
        return self.char_to_idx[key] if key in self.char_to_idx else self.char_to_idx[UNK_TOKEN]
    
    def word(self, idx):
        return self.idx_to_char[idx]
    
    def size(self):
        return len(self.char_to_idx)
    
    @staticmethod
    def from_data(filenames):
        """
            Creates a vocabulary from a list of data files. It assumes that the data files have been
            tokenized and pre-processed beforehand.
        """
        vocab = Vocabulary()
        for filename in filenames:
            with open(filename) as f:
                for line in f:
                    
                    # Strip whitespace and the newline symbol.
                    word = extract_inflected_word(line.strip())
                    
                    # Split the words into characters and assign ids to each
                    # new character it encounters.
                    for char in list(word):
                        if char not in vocab.char_to_idx:
                            idx = len(vocab.char_to_idx)
                            vocab.char_to_idx[char] = idx
                            vocab.idx_to_char[idx] = char
                            
        return vocab

In [0]:
# Construct a vocabulary from the training and validation data.
print("Constructing vocabulary...")
vocab = Vocabulary.from_data([train_file, val_file])
print("Constructed a vocabulary of %d types" % vocab.size())

Constructing vocabulary...
Constructed a vocabulary of 37 types


In [0]:
# some examples
print('e', vocab['e'])
print('é', vocab['é'])
print('ș', vocab['ș'])  # something UNKNOWN

e 8
é 24
ș 0


We also need to load the data files into memory. We create a simple class `TextDataset` that stores the data as a list of words:

In [0]:
class TextDataset(Dataset):
    """
        A simple class that loads a list of words into memory from a text file,
        split by newlines. This does not do any memory optimisation, 
        so if your dataset is very large, you might want to use an alternative 
        class.
    """
    
    def __init__(self, text_file, max_len=30):
        self.data = []
        with open(text_file) as f:
            for line in f:
                word = extract_inflected_word(line.strip())
                if len(list(word)) <= max_len:
                    self.data.append(word)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

In [0]:
# Load the training, validation, and test datasets into memory.
train_dataset = TextDataset(train_file)
val_dataset = TextDataset(val_file)
test_dataset = TextDataset(test_file)

# Print some samples from the data:
print("Sample from training data: \"%s\"" % train_dataset[np.random.choice(len(train_dataset))])
print("Sample from validation data: \"%s\"" % val_dataset[np.random.choice(len(val_dataset))])
print("Sample from test data: \"%s\"" % test_dataset[np.random.choice(len(test_dataset))])

Sample from training data: "compiláramos"
Sample from validation data: "debutara"
Sample from test data: "paginabas"


Now it's time to write a function that converts words into lists of character ids using the vocabulary we created before. This function is `create_batch` in the code cell below. It creates a batch from a list of words and makes sure that each word starts with a start-of-word symbol and ends with an end-of-word symbol. Because the words in a batch are not all of equal length, shorter words are padded with padding symbols up to the length of the longest word in the batch. The function returns an input batch, an output batch, a mask that is `True` at character positions and `False` at padding positions, and the sequence length of each word in the batch. The output batch is shifted by one character to reflect the predictions that the model is expected to make. For example, for the word
\begin{align}
    \text{e s p e s e m o s}
\end{align}
the input sequence is
\begin{align}
    \text{SOW e s p e s e m o s}
\end{align}
and the output sequence is
\begin{align}
    \text{e s p e s e m o s EOW}
\end{align}

The output is shifted with respect to the input because we compute a distribution over the next character given its prefix: the target at each position is the character that follows the input at that position.
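As a quick illustration of the shift, here is a minimal pure-Python sketch that builds the padded input and output sequences for a single word. The helper `shifted_pair` is hypothetical and works on characters rather than ids; it reuses the notebook's special symbols (`>` for SOW, `.` for EOW, `_` for padding).

```python
SOW, EOW, PAD = ">", ".", "_"  # the notebook's start, end and padding symbols

def shifted_pair(word, max_len):
    """Hypothetical helper: return the (input, output) character lists used
    for next-character prediction, each padded with PAD up to max_len."""
    tokens = [SOW] + list(word) + [EOW]
    inp = tokens[:-1]   # SOW followed by the characters
    out = tokens[1:]    # the characters followed by EOW
    inp = inp + [PAD] * (max_len - len(inp))
    out = out + [PAD] * (max_len - len(out))
    return inp, out

inp, out = shifted_pair("espesemos", max_len=12)
print("".join(inp))  # >espesemos__
print("".join(out))  # espesemos.__
```

At every non-padding position `t`, `out[t]` is the character the model should predict after reading `inp[:t+1]`.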


Lastly, we create an inverse function `batch_to_words` that recovers the list of words from a padded batch of character ids to use during test time.

In [0]:
def create_batch(words, vocab, device, word_dropout=0.):
    """
    Converts a list of words to a padded batch of character ids. Returns
    an input batch, an output batch shifted by one, a sequence mask over
    the input batch, and a tensor containing the sequence length of each
    batch element.
    :param words: a list of words, each a string of characters
    :param vocab: a Vocabulary object for this dataset
    :param device: the torch.device on which to place the returned tensors
    :param word_dropout: rate at which we replace characters of the context (input) with UNK
    :returns: a batch of padded inputs, a batch of padded outputs, mask, lengths
    """
    # A plain list (not np.array): words have different lengths, and recent
    # NumPy versions refuse to build arrays from ragged nested lists.
    tok = [[SOW_TOKEN] + list(w) + [EOW_TOKEN] for w in words]
    seq_lengths = [len(w)-1 for w in tok]
    max_len = max(seq_lengths)
    pad_id = vocab[PAD_TOKEN]
    pad_id_input = [
        [vocab[w[t]] if t < seq_lengths[idx] else pad_id for t in range(max_len)]
            for idx, w in enumerate(tok)]
    
    # Replace characters of the input with <unk> with p = word_dropout.
    # The result overwrites pad_id_input so that the dropout actually takes effect.
    if word_dropout > 0.:
        unk_id = vocab[UNK_TOKEN]
        pad_id_input = [
            [unk_id if (np.random.random() < word_dropout and t < seq_lengths[idx]) else word_ids[t] for t in range(max_len)] 
                for idx, word_ids in enumerate(pad_id_input)]
    
    # The output batch is shifted by 1.
    pad_id_output = [
        [vocab[w[t+1]] if t < seq_lengths[idx] else pad_id for t in range(max_len)]
            for idx, w in enumerate(tok)]
    
    # Convert everything to PyTorch tensors.
    batch_input = torch.tensor(pad_id_input)
    batch_output = torch.tensor(pad_id_output)
    seq_mask = (batch_input != vocab[PAD_TOKEN])
    seq_length = torch.tensor(seq_lengths)
    
    # Move all tensors to the given device.
    batch_input = batch_input.to(device)
    batch_output = batch_output.to(device)
    seq_mask = seq_mask.to(device)
    seq_length = seq_length.to(device)
    
    return batch_input, batch_output, seq_mask, seq_length


def batch_to_words(tensors, vocab: Vocabulary):
    """
    Converts a batch of word ids back to words.
    :param tensors: [B, T] word ids
    :param vocab: a Vocabulary object for this dataset
    :returns: an array of strings (each a word).
    """
    words = []
    batch_size = tensors.size(0)
    for idx in range(batch_size):
        word = [vocab.word(t.item()) for t in tensors[idx,:]]
        
        # Filter out the start-of-word and padding tokens.
        word = list(filter(lambda t: t != PAD_TOKEN and t != SOW_TOKEN, word))
        
        # Remove the end-of-word token and all tokens following it.
        if EOW_TOKEN in word:
            word = word[:word.index(EOW_TOKEN)]
            
        words.append("".join(word))
    return np.array(words)

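PyTorch's `pack_padded_sequence`, which we use to feed variable-length words to the RNN, by default (`enforce_sorted=True`) expects the sequences in a batch to be sorted from longest to shortest. Therefore we create a simple wrapper class for the DataLoader class that sorts words from long to short: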

In [0]:
class SortingTextDataLoader:
    """
    A wrapper for the DataLoader class that sorts a list of words by their
    lengths in descending order.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.it = iter(dataloader)
    
    def __iter__(self):
        return self
    
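    def __next__(self):
        # Fetch the next batch of words from the underlying DataLoader.
        words = next(self.it, None)

        if words is None:
            # Reset the iterator so the loader can be reused for the next
            # epoch, then signal the end of this one.
            self.it = iter(self.dataloader)
            raise StopIteration
        
        # Sort the words by length, longest first.
        words = np.array(words)
        sort_keys = sorted(range(len(words)), 
                           key=lambda idx: len(words[idx]), 
                           reverse=True)
        sorted_words = words[sort_keys]
        return sorted_words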
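A sketch of the same sorting rule on a handful of words, plus one thing the wrapper above does not do: recovering the original order with the inverse permutation, which matters if per-word results must be matched back to the unsorted data. Alternatively, recent PyTorch versions let `pack_padded_sequence` take `enforce_sorted=False` and handle the sorting internally.

```python
import numpy as np

words = np.array(["paginabas", "debutara", "compiláramos", "es"])

# Same ordering rule as SortingTextDataLoader: indices sorted by word length, longest first.
order = sorted(range(len(words)), key=lambda idx: len(words[idx]), reverse=True)
sorted_words = words[order]

# The inverse permutation maps sorted positions back to the original positions.
inverse = np.argsort(order)
restored = sorted_words[inverse]

print(list(sorted_words))  # ['compiláramos', 'paginabas', 'debutara', 'es']
```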

# Model

## Deterministic language model

In language modelling, we model a word $x = \langle x_1, \ldots, x_n \rangle$  of length $n = |x|$ as a sequence of categorical draws:

\begin{align}
X_i|x_{<i} & \sim \text{Cat}(f(x_{<i}; \theta)) 
& i = 1, \ldots, n \\
\end{align}

where we use $x_{<i}$ to denote a (possibly empty) prefix string, and thus the model makes no Markov assumption. We map from the conditioning context, the prefix $x_{<i}$, to the categorical parameters (a $v$-dimensional probability vector, where $v$ denotes the size of the vocabulary, in this case, the size of the character set) using a fixed neural network architecture whose parameters we collectively denote by $\theta$.

This assigns the following likelihood to the word
\begin{align}
    P(x|\theta) &= \prod_{i=1}^n P(x_i|x_{<i}, \theta) \\
    &= \prod_{i=1}^n \text{Cat}(x_i|f(x_{<i}; \theta))  
\end{align}
where the categorical pmf is $\text{Cat}(k|\pi) = \prod_{j=1}^v \pi_j^{[k=j]} = \pi_k$. 
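To make the pmf concrete: because $\text{Cat}(k|\pi) = \pi_k$, the log-likelihood of a word is just the sum of the log-probabilities assigned to the observed characters. A minimal NumPy sketch, assuming the per-position probability vectors $f(x_{<i};\theta)$ are already given as rows of an array (in the notebook they would come from the network; here they are made up):

```python
import numpy as np

def word_log_likelihood(probs, char_ids):
    """log P(x|theta) = sum_i log Cat(x_i | pi_i), where row i of `probs`
    is the probability vector pi_i = f(x_{<i}; theta) for position i."""
    probs = np.asarray(probs, dtype=float)
    positions = np.arange(len(char_ids))
    # Cat(k | pi) = pi_k, so we just pick the probability of the observed symbol.
    return np.log(probs[positions, char_ids]).sum()

# Toy example: 3 symbols and a 2-symbol word x = <0, 2>.
probs = [[0.5, 0.25, 0.25],   # pi_1 = f(empty prefix)
         [0.1, 0.1, 0.8]]     # pi_2 = f(x_1)
ll = word_log_likelihood(probs, [0, 2])
print(ll)  # log(0.5) + log(0.8)
```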


Suppose we have a dataset $\mathcal D = \{x^{(1)}, \ldots, x^{(N)}\}$ containing $N$ i.i.d. observations. Then we can use the log-likelihood function 
\begin{align}
\mathcal L(\theta|\mathcal D) &= \sum_{k=1}^{N} \log P(x^{(k)}| \theta) \\
&= \sum_{k=1}^{N} \sum_{i=1}^{|x^{(k)}|} \log \text{Cat}(x^{(k)}_i|f(x^{(k)}_{<i}; \theta))
\end{align}
 to estimate $\theta$ by maximisation:
 \begin{align}
 \theta^\star = \arg\max_{\theta \in \Theta} \mathcal L(\theta|\mathcal D) ~ .
 \end{align}
 

We can use stochastic gradient ascent to find a local optimum of $\mathcal L(\theta|\mathcal D)$, which only requires a gradient estimate:

\begin{align}
\nabla_\theta \mathcal L(\theta|\mathcal D) &= \sum_{k=1}^{N} \nabla_\theta  \log P(x^{(k)}|\theta) \\ 
&= \sum_{k=1}^{N} \frac{1}{N} N \nabla_\theta  \log P(x^{(k)}| \theta)  \\
&= \mathbb E_{K \sim \mathcal U(1/N)} \left[ N \nabla_\theta  \log P(x^{(K)}| \theta) \right]  \\
&\overset{\text{MC}}{\approx} \frac{N}{M} \sum_{m=1}^M \nabla_\theta  \log P(x^{(k_m)}|\theta) \\
&\text{where } k_m \sim \mathcal U(1/N)
\end{align}

This is a Monte Carlo (MC) estimate of the gradient computed on $M$ data points selected uniformly at random from $\mathcal D$.
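The claim that the scaled minibatch sum is an unbiased estimate of the full sum can be checked numerically. In this NumPy sketch, made-up scalars stand in for the per-datapoint gradients $\nabla_\theta \log P(x^{(k)}|\theta)$; averaging many $M$-sample estimates recovers the full sum.

```python
import numpy as np

rng = np.random.default_rng(0)
per_example = np.arange(10, dtype=float)   # hypothetical per-datapoint gradient values
N, M, reps = len(per_example), 3, 200_000

full_sum = per_example.sum()                          # exact sum over the whole dataset
idx = rng.integers(0, N, size=(reps, M))              # k_m ~ U(1/N), drawn with replacement
estimates = (N / M) * per_example[idx].sum(axis=1)    # one MC estimate per repetition

print(full_sum, estimates.mean())
```

Any single estimate can be far from `full_sum`; it is their expectation that matches, which is all stochastic gradient ascent needs.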

As long as $f$ is differentiable with respect to its inputs and parameters, we can rely on automatic differentiation to obtain gradient estimates.


An example design for $f$ is:
\begin{align}
\mathbf x_i &= \text{emb}(x_i; \theta_{\text{emb}}) \\
\mathbf h_0 &= \mathbf 0 \\
\mathbf h_i &= \text{rnn}(\mathbf h_{i-1}, \mathbf x_{i-1}; \theta_{\text{rnn}}) \\
f(x_{<i}; \theta) &= \text{softmax}(\text{dense}_v(\mathbf h_{i};  \theta_{\text{out}}))
\end{align}
where 
* $\text{emb}$ is a fixed embedding layer with parameters $\theta_{\text{emb}}$;
* $\text{rnn}$ is a recurrent architecture with parameters $\theta_{\text{rnn}}$, e.g. an LSTM or GRU, and the initial state $\mathbf h_0$ is fixed to a vector of zeros, so it is part of the architecture rather than a trained parameter;
* $\text{dense}_v$ is a dense layer with $v$ outputs (vocabulary size) and parameters $\theta_{\text{out}}$.
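A minimal NumPy sketch of this design, assuming a plain tanh (Elman) recurrence in place of an LSTM or GRU, randomly initialised weights, and that $x_0$ is the start-of-word symbol so that $\mathbf h_1$ summarises the empty prefix. All sizes and parameter names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
v, d, h = 5, 4, 8  # hypothetical vocabulary, embedding and hidden sizes

# Hypothetical parameters: theta_emb, theta_rnn (tanh cell) and theta_out.
E = rng.normal(size=(v, d))
W_h, W_x, b = rng.normal(size=(h, h)), rng.normal(size=(h, d)), np.zeros(h)
W_o, b_o = rng.normal(size=(v, h)), np.zeros(v)

def softmax(a):
    e = np.exp(a - a.max())  # subtract the max for numerical stability
    return e / e.sum()

def next_char_distributions(inputs, h0):
    """inputs = [x_0, ..., x_{n-1}], with x_0 the start-of-word id.
    Row i-1 of the result is f(x_{<i}; theta), a distribution over v symbols."""
    hidden, rows = h0, []
    for x_prev in inputs:
        hidden = np.tanh(W_h @ hidden + W_x @ E[x_prev] + b)  # h_i = rnn(h_{i-1}, x_{i-1})
        rows.append(softmax(W_o @ hidden + b_o))              # softmax(dense_v(h_i))
    return np.array(rows)

probs = next_char_distributions([2, 4, 1], h0=np.zeros(h))    # h_0 = 0
print(probs.shape)  # (3, 5)
```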



In what follows we show how to extend this model with a discrete latent variable: a $K$-dimensional bit vector that acts as a latent representation of the word.

## Deep generative language model

We want to model a word $x$ as a draw from the marginal of a deep generative model $P(z, x|\theta, \alpha) = P(z|\alpha)P(x|z, \theta)$. 


### Generative model

The generative story is:
\begin{align}
    Z_k & \sim \text{Bernoulli}(\alpha_k) & k=1,\ldots, K \\
    X_i | z, x_{<i} &\sim \text{Cat}(f(z, x_{<i}; \theta)) & i=1, \ldots, n
\end{align}
where $z \in \{0,1\}^K$ and we impose a prior that is a product of independent Bernoulli distributions. Other choices of prior can induce interesting properties in latent space (for example, the Bernoullis could be correlated); in this notebook, however, we use independent distributions. 


**About the prior parameter** The parameter $\alpha_k$ of the $k$th Bernoulli distribution is the probability that the $k$th bit of $z$ is set to $1$. If we have reasons to believe some bits are on more often than others (for example, because we expect some bits to capture verb attributes and others to capture noun attributes, and we know nouns are more frequent than verbs), we may be able to make a good guess at $\alpha_k$ for different $k$. Otherwise, we may simply say that bits are about as likely to be on as off a priori, and set $\alpha_k = 0.5$ for every $k$. In this lab, we treat the prior parameter $\alpha$ as *fixed*.

**Architecture** It is easy to design $f$ by a simple modification of the deterministic design shown before:
\begin{align}
\mathbf x_i &= \text{emb}(x_i; \theta_{\text{emb}}) \\
\mathbf h_0 &= \tanh(\text{dense}(z; \theta_{\text{init}})) \\
\mathbf h_i &= \text{rnn}(\mathbf h_{i-1}, \mathbf x_{i-1}; \theta_{\text{rnn}}) \\
f(z, x_{<i}; \theta) &= \text{softmax}(\text{dense}_v(\mathbf h_{i};  \theta_{\text{out}}))
\end{align}
where we simply initialise the recurrent cell using $z$. Note that we could also use $z$ in other places, for example, as an additional input to every update of the recurrent cell, $\mathbf h_i = \text{rnn}(\mathbf h_{i-1}, [\mathbf x_{i-1}, z])$. This is an architecture choice which, like many others, can only be judged empirically or on the basis of practical convenience.
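The only change to the decoder is where $\mathbf h_0$ comes from. A NumPy sketch of the first steps of the generative story, assuming hypothetical sizes and a randomly initialised $\theta_{\text{init}}$: draw $z$ from the fixed prior, score it under the prior, and map it to an initial state that would then drive the recurrence.

```python
import numpy as np

rng = np.random.default_rng(0)
K, H = 6, 8                      # hypothetical latent size and hidden size
alpha = np.full(K, 0.5)          # fixed prior parameters alpha_k

# z_k ~ Bernoulli(alpha_k), independently for every k
z = (rng.random(K) < alpha).astype(float)

# log P(z|alpha) = sum_k [z_k log alpha_k + (1 - z_k) log(1 - alpha_k)]
log_prior = np.sum(z * np.log(alpha) + (1 - z) * np.log(1 - alpha))

# h_0 = tanh(dense(z; theta_init)), with hypothetical parameters W_init, b_init
W_init, b_init = rng.normal(size=(H, K)), np.zeros(H)
h0 = np.tanh(W_init @ z + b_init)
```

With $\alpha_k = 0.5$ every bit vector has the same prior probability $2^{-K}$, which is why `log_prior` does not depend on the sampled `z`.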



### Parameter estimation

The marginal likelihood, necessary for parameter estimation, is now no longer tractable:
\begin{align}
P(x|\theta, \alpha) &= \sum_{z \in \{0,1\}^K} P(z|\alpha)P(x|z, \theta) \\
&= \sum_{z \in \{0,1\}^K} \prod_{k=1}^K \text{Bernoulli}(z_k|\alpha_k)\prod_{i=1}^n \text{Cat}(x_i|f(z,x_{<i}; \theta) ) 
\end{align}
The intractability is clear: the sum ranges over an exponential number of assignments of $z$, namely $2^K$.
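For very small $K$ the sum can still be computed by brute force, which makes the exponential cost explicit. A sketch, assuming a made-up log-likelihood in place of $\log P(x|z,\theta)$ (a real decoder would have to be run once per assignment, so $K = 32$ would already mean over four billion decoder evaluations per word):

```python
import itertools
import numpy as np

K = 3
alpha = np.full(K, 0.5)

def log_prior(z):
    """log P(z|alpha) for a product of independent Bernoullis."""
    z = np.asarray(z, dtype=float)
    return np.sum(z * np.log(alpha) + (1 - z) * np.log(1 - alpha))

def toy_log_likelihood(z):
    """Hypothetical stand-in for log P(x|z, theta): every bit that is off costs 1 nat."""
    return -float(K - sum(z))

assignments = list(itertools.product([0, 1], repeat=K))   # all 2^K bit vectors
log_joint = np.array([log_prior(z) + toy_log_likelihood(z) for z in assignments])
log_marginal = np.logaddexp.reduce(log_joint)              # log sum_z P(z|alpha) P(x|z, theta)
print(len(assignments), log_marginal)
```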

We turn to variational inference and derive a lower bound $\mathcal E(\theta, \lambda|\mathcal D)$ on the log-likelihood function

\begin{align}
    \mathcal E(\theta, \lambda|\mathcal D) &= \sum_{s=1}^{|\mathcal D|} \mathcal E(\theta, \lambda|x^{(s)}) 
\end{align}

which for a single datapoint $x$ is
\begin{align}
    \mathcal E(\theta, \lambda|x) &= \mathbb{E}_{Q(z|x, \lambda)}\left[\log P(x|z, \theta)\right] - \text{KL}\left(Q(z|x, \lambda)||P(z|\alpha)\right)\\
\end{align}
where we have introduced an independently parameterised auxiliary distribution $Q(z|x, \lambda)$. The distribution $Q$ which maximises this *evidence lower bound* (ELBO) is also the distribution that minimises 
\begin{align}
\text{KL}(Q(z|x, \lambda)||P(z|x, \theta, \alpha)) = \mathbb E_{Q(z|x, \lambda)}\left[\log  \frac{Q(z|x, \lambda)}{P(z|x, \theta, \alpha)}\right]
\end{align}
 where $P(z|x, \theta, \alpha) = \frac{P(x, z|\theta, \alpha)}{P(x|\theta, \alpha)}$ is our intractable true posterior. For that reason, we think of $Q(z|x, \lambda)$ as an *approximate posterior*. 
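In a toy setting with small $K$ and a made-up likelihood, the identity $\log P(x|\theta,\alpha) - \mathcal E(\theta,\lambda|x) = \text{KL}(Q(z|x,\lambda)||P(z|x,\theta,\alpha))$, which is why maximising the ELBO minimises the KL to the posterior, can be verified by enumeration. The NumPy sketch below also shows that the bound becomes tight when $Q$ equals the true posterior, which in this particular toy model happens to factorise over bits.

```python
import itertools
import numpy as np

K = 3
alpha = np.full(K, 0.5)
assignments = np.array(list(itertools.product([0, 1], repeat=K)), dtype=float)

def log_bernoulli(z, p):
    """Log pmf of a product of Bernoullis with parameters p, for each row of z."""
    return np.sum(z * np.log(p) + (1 - z) * np.log(1 - p), axis=-1)

log_lik = -(K - assignments.sum(axis=1))        # hypothetical log P(x|z, theta)
log_prior = log_bernoulli(assignments, alpha)
log_marginal = np.logaddexp.reduce(log_prior + log_lik)
log_posterior = log_prior + log_lik - log_marginal

def elbo(beta):
    """E_Q[log P(x|z)] - KL(Q || prior) for a mean field Q with parameters beta."""
    log_q = log_bernoulli(assignments, beta)
    q = np.exp(log_q)
    return np.sum(q * log_lik) - np.sum(q * (log_q - log_prior))

def kl_to_posterior(beta):
    log_q = log_bernoulli(assignments, beta)
    return np.sum(np.exp(log_q) * (log_q - log_posterior))

beta = np.array([0.2, 0.5, 0.9])                # an arbitrary mean field Q
posterior_beta = np.full(K, 1 / (1 + np.exp(-1.0)))  # exact per-bit posterior in this toy model
```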
 
The approximate posterior is an independent model of the latent variable given the data; for that reason we also call it an *inference model*.
In this notebook, our inference model will be a product of independent Bernoulli distributions, which assigns nonzero probability to every bit vector and thus covers the sample space of our latent variable. Modelling correlations between bits (thus achieving *structured* inference, rather than mean field inference) is left as an optional exercise at the end of the notebook. This mean field (MF) approximation uses $K$ Bernoulli variational factors whose parameters we predict with a neural network:
 
\begin{align}
    Q(z|x, \lambda) &= \prod_{k=1}^K \text{Bernoulli}(z_k|\beta_k(x; \lambda))
\end{align}
 
Note we compute a *fixed* number, namely, $K$, of Bernoulli parameters. This can be done with a neural network that outputs $K$ values and employs a sigmoid activation for the outputs.
 
 
For this choice, the KL term in the ELBO is tractable:

\begin{align}
\text{KL}\left(Q(z|x, \lambda)||P(z|\alpha)\right) &= \sum_{k=1}^K \text{KL}\left(Q(z_k|x, \lambda)||P(z_k|\alpha_k)\right) \\
&= \sum_{k=1}^K \text{KL}\left(\text{Bernoulli}(\beta_k(x;\lambda))|| \text{Bernoulli}(\alpha_k)\right) \\
&= \sum_{k=1}^K \beta_k(x;\lambda) \log \frac{\beta_k(x;\lambda)}{\alpha_k} + (1-\beta_k(x;\lambda)) \log \frac{1-\beta_k(x;\lambda)}{1-\alpha_k}
\end{align}
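A quick numerical check of the closed form: for a small $K$ the KL can also be computed by summing $Q(z)\log\frac{Q(z)}{P(z)}$ over all $2^K$ bit vectors, and the two must agree. A NumPy sketch with arbitrary made-up parameters:

```python
import itertools
import numpy as np

def kl_product_of_bernoullis(beta, alpha):
    """Closed form: sum_k beta_k log(beta_k/alpha_k) + (1-beta_k) log((1-beta_k)/(1-alpha_k))."""
    return np.sum(beta * np.log(beta / alpha) + (1 - beta) * np.log((1 - beta) / (1 - alpha)))

def kl_by_enumeration(beta, alpha):
    """Brute force over all 2^K bit vectors; only feasible for small K."""
    total = 0.0
    for z in itertools.product([0, 1], repeat=len(beta)):
        z = np.array(z)
        q = np.prod(np.where(z == 1, beta, 1 - beta))
        p = np.prod(np.where(z == 1, alpha, 1 - alpha))
        total += q * np.log(q / p)
    return total

beta = np.array([0.1, 0.6, 0.95])
alpha = np.full(3, 0.5)
closed = kl_product_of_bernoullis(beta, alpha)
brute = kl_by_enumeration(beta, alpha)
print(closed, brute)
```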


 
Here's an example design for our inference model:

\begin{align}
\mathbf x_i &= \text{emb}(x_i; \lambda_{\text{emb}}) \\
\mathbf f_i &= \text{rnn}(\mathbf f_{i-1}, \mathbf x_{i}; \lambda_{\text{fwd}}) \\
\mathbf b_i &= \text{rnn}(\mathbf b_{i+1}, \mathbf x_{i}; \lambda_{\text{bwd}}) \\
\mathbf h &= \text{dense}([\mathbf f_{n}, \mathbf b_1]; \lambda_{\text{hid}}) \\
\beta(x; \lambda) &= \text{sigmoid}(\text{dense}_K(\mathbf h; \lambda_{\text{out}}))
\end{align}

where we use the $\text{sigmoid}$ activation to make sure our probabilities are independently set between $0$ and $1$. 
 
Because a neural network computes the parameters of the Bernoulli variational factors for every input $x$, we call this *amortised* mean field inference.



### Gradient estimation

We have to obtain gradients of the ELBO with respect to $\theta$ (generative model) and $\lambda$ (inference model). Recall we will leave $\alpha$ fixed.

For the **generative model**

\begin{align}
\nabla_\theta \mathcal E(\theta, \lambda|x)  &=\nabla_\theta\sum_{z} Q(z|x, \lambda)\log P(x|z,\theta) - \underbrace{\nabla_\theta \sum_{k=1}^K \text{KL}(Q(z_k|x, \lambda) || P(z_k|\alpha_k))}_{\color{blue}{0}}  \\
&=\sum_{z} Q(z|x, \lambda)\nabla_\theta\log P(x|z,\theta) \\
&= \mathbb E_{Q(z|x, \lambda)}\left[\nabla_\theta\log P(x|z,\theta) \right] \\
&\overset{\text{MC}}{\approx} \frac{1}{S} \sum_{s=1}^S \nabla_\theta \log P(x|z^{(s)}, \theta) 
\end{align}
where $z^{(s)} \sim Q(z|x,\lambda)$.
Note that there is no difficulty in obtaining these gradient estimates, precisely because the samples come from the inference model, which does not depend on $\theta$, and therefore do not interfere with backpropagation for updates to $\theta$.

For the **inference model** the story is less straightforward, and we have to use the *score function estimator* (a.k.a. REINFORCE):

\begin{align}
\nabla_\lambda \mathcal E(\theta, \lambda|x)  &=\nabla_\lambda\sum_{z} Q(z|x, \lambda)\log P(x|z,\theta) - \nabla_\lambda \underbrace{\sum_{k=1}^K \text{KL}(Q(z_k|x, \lambda) || P(z_k|\alpha_k))}_{ \color{blue}{\text{tractable} }}  \\
&=\sum_{z} \nabla_\lambda Q(z|x, \lambda)\log P(x|z,\theta) - \sum_{k=1}^K \nabla_\lambda \text{KL}(Q(z_k|x, \lambda) || P(z_k|\alpha_k))   \\
&=\sum_{z}  \underbrace{Q(z|x, \lambda) \nabla_\lambda \log Q(z|x, \lambda)}_{\nabla_\lambda Q(z|x, \lambda)} \log P(x|z,\theta) - \sum_{k=1}^K \nabla_\lambda \text{KL}(Q(z_k|x, \lambda) || P(z_k|\alpha_k))   \\
&= \mathbb E_{Q(z|x, \lambda)}\left[ \log P(x|z,\theta) \nabla_\lambda \log Q(z|x, \lambda) \right] - \sum_{k=1}^K \nabla_\lambda \text{KL}(Q(z_k|x, \lambda) || P(z_k|\alpha_k))   \\
&\overset{\text{MC}}{\approx} \left(\frac{1}{S} \sum_{s=1}^S  \log P(x|z^{(s)}, \theta) \nabla_\lambda \log Q(z^{(s)}|x, \lambda)  \right) - \sum_{k=1}^K \nabla_\lambda \text{KL}(Q(z_k|x, \lambda) || P(z_k|\alpha_k))  
\end{align}

where $z^{(s)} \sim Q(z|x,\lambda)$.
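The score function estimator can be sanity-checked against an exact gradient for small $K$. In this NumPy sketch the logits play the role of $\lambda$ and a made-up linear function replaces $\log P(x|z,\theta)$; for a Bernoulli parameterised by a logit $\ell$, $\nabla_\ell \log \text{Bernoulli}(z|\text{sigmoid}(\ell)) = z - \text{sigmoid}(\ell)$. Because the stand-in is linear with weights $w$, the exact gradient is also available in closed form as $w_k \beta_k (1 - \beta_k)$.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([-1.0, 0.5, 2.0])        # hypothetical variational logits (the role of lambda)
probs = 1.0 / (1.0 + np.exp(-logits))      # beta_k = sigmoid(logit_k)
weights = np.array([1.0, -2.0, 0.5])

def reward(z):
    """Hypothetical stand-in for log P(x|z, theta); linear in z for easy checking."""
    return z @ weights - 0.3

def score(z):
    """Gradient of log Q(z) wrt the logits: for Bernoulli(sigmoid(l)) it is z - sigmoid(l)."""
    return z - probs

# Exact gradient of E_Q[reward(z)] by enumerating all 2^K bit vectors.
Z = np.array(list(itertools.product([0, 1], repeat=len(logits))), dtype=float)
q = np.prod(np.where(Z == 1, probs, 1 - probs), axis=1)
exact = np.sum((q * reward(Z))[:, None] * score(Z), axis=0)

# Score function (REINFORCE) Monte Carlo estimate from S samples z ~ Q.
S = 200_000
samples = (rng.random((S, len(logits))) < probs).astype(float)
estimate = np.mean(reward(samples)[:, None] * score(samples), axis=0)
print(exact, estimate)
```

Only values of the reward enter the estimator, never its gradient with respect to $z$, which is why it works even though sampling $z \sim Q$ is not differentiable; the price is variance, which shrinks only as $1/S$.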



## Implementation

Let's implement the model and the loss (the negative ELBO). We work with the notion of a *surrogate loss*, that is, a computation node whose gradients with respect to the parameters are equal to the gradients we need.

For a given sample $z \sim Q(z|x, \lambda)$, the following is a single-sample surrogate loss:

\begin{align}
\mathcal S(\theta, \lambda|x) = \log P(x|z, \theta) + \color{red}{\text{detach}(\log P(x|z, \theta) )}\log Q(z|x, \lambda) - \sum_{k=1}^K \text{KL}(Q(z_k|x, \lambda) || P(z_k|\alpha_k))
\end{align}

Check the documentation of pytorch's `detach` method.

Show that its gradients with respect to $\theta$ and $\lambda$ are exactly what we need:


\begin{align}
\nabla_\theta \mathcal S(\theta, \lambda|x) = \color{red}{?}
\end{align}

\begin{align}
\nabla_\lambda \mathcal S(\theta, \lambda|x) = \color{red}{?}
\end{align}

**Solution**

\begin{align}
\nabla_\theta \mathcal S(\theta, \lambda|x) = \nabla_\theta \log P(x|z, \theta) + 0
\end{align}

\begin{align}
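\nabla_\lambda \mathcal S(\theta, \lambda|x) &= \underbrace{\nabla_\lambda \log P(x|z, \theta)}_{0} + \underbrace{\log Q(z|x, \lambda)\nabla_\lambda \text{detach}(\log P(x|z, \theta))  + \log P(x|z, \theta) \nabla_\lambda \log Q(z|x, \lambda)}_{\text{product rule}} - \sum_{k=1}^K \nabla_\lambda \text{KL}(Q(z_k|x, \lambda) || P(z_k|\alpha_k)) \\ 
&= 0 + 0 + \log P(x|z, \theta) \nabla_\lambda \log Q(z|x, \lambda) - \sum_{k=1}^K \nabla_\lambda \text{KL}(Q(z_k|x, \lambda) || P(z_k|\alpha_k))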
\end{align}
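The result is the single-sample ($S=1$) score function estimator derived earlier, and it can be checked with autograd. In the PyTorch sketch below, the logits `lam` play the role of $\lambda$, `theta` parameterises a made-up stand-in for $\log P(x|z,\theta)$, and $\alpha_k = 0.5$. It confirms that the gradient wrt `lam` is the score function term (the KL gradient vanishes here because `lam = 0` makes $Q$ equal to the prior, where the KL is minimal) and that the gradient wrt `theta` is the ordinary one. All names are hypothetical.

```python
import torch
from torch.distributions import Bernoulli, kl_divergence

torch.manual_seed(0)
K = 4
lam = torch.zeros(K, requires_grad=True)      # logits of Q(z|x, lambda); hypothetical
theta = torch.randn(K, requires_grad=True)    # hypothetical generative parameters
prior = Bernoulli(probs=torch.full((K,), 0.5))

q = Bernoulli(logits=lam)
z = q.sample()                                       # sampling is not differentiable
log_px = -((z - torch.sigmoid(theta)) ** 2).sum()   # toy stand-in for log P(x|z, theta)
log_qz = q.log_prob(z).sum()                         # log Q(z|x, lambda)
kl = kl_divergence(q, prior).sum()                   # sum_k KL(Q(z_k|x) || P(z_k|alpha_k))

# Single-sample surrogate; detach() blocks gradients through the "reward".
surrogate = log_px + log_px.detach() * log_qz - kl
(-surrogate).backward()                              # we minimise the negative surrogate
```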

Let's now turn to the actual PyTorch implementation of the inference model as well as the generative model. 

Here and there we will provide helper code for you.

In [0]:
def bernoulli_log_probs_from_logits(logits):
    """
    Let p be the Bernoulli parameter and q = 1 - p.
    This function is a numerically stable computation of log p and log q from logit = log(p/q).
    :param logits: log(p/q)
    :return: log_p, log_q
    """
    return - F.softplus(-logits), - F.softplus(logits)
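Why the softplus form: computing $\log \text{sigmoid}(\ell)$ naively fails for large negative logits, because $e^{-\ell}$ overflows and the log of the resulting zero is $-\infty$, while $-\text{softplus}(-\ell)$ stays finite. A NumPy sketch of the same two formulas, using `np.logaddexp(0, a)` $= \log(1 + e^a)$ as the softplus:

```python
import numpy as np

def log_probs_from_logits(logits):
    # NumPy analogue of bernoulli_log_probs_from_logits:
    # log p = -softplus(-logit) and log q = -softplus(logit), softplus(a) = log(1 + e^a)
    logits = np.asarray(logits, dtype=float)
    return -np.logaddexp(0.0, -logits), -np.logaddexp(0.0, logits)

logits = np.array([-800.0, 0.0, 800.0])
log_p, log_q = log_probs_from_logits(logits)

with np.errstate(over="ignore", divide="ignore"):
    naive_log_p = np.log(1.0 / (1.0 + np.exp(-logits)))  # exp(800) overflows to inf

print(log_p, naive_log_p)
```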

We start with the implementation of a product of Bernoulli distributions where the parameters are *given* at construction time. That is, for some vector $b_1, \ldots, b_K$ we have
\begin{equation}
    Z_k \sim \text{Bernoulli}(b_k)
\end{equation}
and thus the joint probability of $z_1, \ldots, z_K$ is given by $\prod_{k=1}^K \text{Bernoulli}(z_k|b_k)$.

In [0]:
class ProductOfBernoullis:
    """
    This class models a product of independent Bernoulli distributions.
    
    Each product of Bernoullis is defined by a D-dimensional vector of logits,
    one for each independent Bernoulli variable.
    """
    
    def __init__(self, logits):
        """
        :param logits: a tensor of D Bernoulli parameters (logits) for each batch element. [B, D]
        """
        self.logits = logits
        self.probs = torch.sigmoid(self.logits)
        self.log_probs_1, self.log_probs_0 = bernoulli_log_probs_from_logits(logits)
        
    def mean(self):
        """For Bernoulli variables this is the probability of each Bernoulli being 1."""
        return self.probs
    
    def std(self):
        """For Bernoulli variables this is sqrt(p*(1-p)) where p
        is the probability of the Bernoulli being 1"""
        return torch.sqrt(self.probs * (1.0 - self.probs))
    
    def sample(self):
        """
        Returns a sample with the shape of the Bernoulli parameter. # [B, D]
        """
        u = torch.rand_like(self.probs) # uniform random draws
        sample = (u < self.probs).byte() # interpret as 0s and 1s
        return sample.float() # we use float for consistency with how pytorch implementations usually work
        
    def log_prob(self, x):
        return torch.where(x == 1, self.log_probs_1, self.log_probs_0).sum(dim=1)
        
    def log_pmf(self, x):
        """
        Assess the log probability mass of x.
        
        :param x: a tensor of Bernoulli samples (same shape as the Bernoulli parameter) [B, D]
        :returns: a tensor of log probability masses [B]
        """
        return torch.log(torch.where(x == 1, self.probs, 1 - self.probs)).sum(dim=1)
    
    def unstable_kl(self, other: 'ProductOfBernoullis'):
        """
        The straightforward implementation of the KL between two products of Bernoullis.
        This implementation is numerically unstable; a stable implementation is provided in
        ProductOfBernoullis.kl(self, other).
        
        :returns: a tensor of KL values, one per batch element [B].
        """
        
        t1 = self.probs * (torch.log(self.probs) - torch.log(other.probs))
        t2 = (1 - self.probs) * (torch.log(1.0 - self.probs) - torch.log(1.0 - other.probs))
        return (t1 + t2).sum(dim=1)

    def kl(self, other: 'ProductOfBernoullis'):
        """
        A stable implementation of the KL divergence between two products of Bernoullis.
        
        :returns: a tensor of KL values, one per batch element [B].
        """
        t1 = self.probs * (self.log_probs_1 - other.log_probs_1)
        t2 = (1-self.probs) * (self.log_probs_0 - other.log_probs_0)
        return (t1 + t2).sum(dim=1)


Then we should implement the inference model $Q(z | x, \lambda)$, that is, a module that uses a neural network to map from a data point $x$ to the parameters of a product of Bernoullis.

You might want to consult the documentation of 
* `torch.nn.Embedding`
* `torch.nn.LSTM`
* `torch.nn.Linear`
* and of our own `ProductOfBernoullis` distribution (see above).

In [0]:
class InferenceModel(nn.Module):

    def __init__(self, vocab_size, embedder, hidden_size,
                 latent_size, pad_idx, bidirectional=False):
        """
        Implement the layers in the inference model.
        
        :param vocab_size: size of the vocabulary of the language
        :param embedder: embedding layer
        :param hidden_size: size of recurrent cell
        :param latent_size: size K of the latent variable
        :param pad_idx: id of the -PAD- token
        :param bidirectional: whether we condition on x via a bidirectional or 
          unidirectional encoder          
        """
        super().__init__()  # pytorch modules should always start with this
        pass
        # Construct your NN blocks here
        #  and make sure every block is an attribute of self
        #  or they won't get initialised properly
        #  for example, self.my_linear_layer = torch.nn.Linear(...)

    def forward(self, x, seq_mask, seq_len) -> ProductOfBernoullis:
        """
        Return an inference product of Bernoullis per instance in the mini-batch
        :param x: words [B, T] as token ids
        :param seq_mask: indicates valid positions vs padding positions [B, T]
        :param seq_len: the length of the sequences [B]
        :return: a collection of B ProductOfBernoullis approximate posterior, 
            each a distribution over K-dimensional bit vectors
        """
        pass

In [0]:
# SOLUTION
class InferenceModel(nn.Module):

    def __init__(self, vocab_size, embedder, hidden_size,
                 latent_size, pad_idx, bidirectional=False):
        """
        :param vocab_size: size of the vocabulary of the language
        :param embedder: embedding layer
        :param hidden_size: size of recurrent cell
        :param latent_size: size K of the latent variable
        :param pad_idx: id of the -PAD- token
        :param bidirectional: whether we condition on x via a bidirectional or 
          unidirectional encoder          
        """
        super().__init__()
        self.bidirectional = bidirectional
        
        # We borrow the embedder from the generative model, but we don't
        # want to backpropagate through it for the inference model. So we
        # need to make sure to call detach on the embeddings later.
        self.embedder = embedder
        emb_size = embedder.embedding_dim
        
        # Create a (bidirectional) LSTM to encode x.
        self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True, 
                            bidirectional=bidirectional)
        
        # The output of the LSTM doubles if we use a bidirectional encoder.
        encoding_size = hidden_size * 2 if bidirectional else hidden_size
        
        # We can let features interact once more
        self.combination_layer = nn.Linear(encoding_size, 2 * latent_size)
        
        # Create an affine layer to project the combined encoder state to
        # the logits of the independent Bernoullis that
        # we are predicting.
        self.logits_layer = nn.Linear(2 * latent_size, latent_size)

    def forward(self, x, seq_mask, seq_len) -> ProductOfBernoullis:
        
        # Compute character embeddings and detach them so that no gradients
        # from the inference model flow through. That's done because
        # this embedding layer was borrowed from the generative model,
        # thus its parameters are part of the set \theta.
        x_embed = self.embedder(x).detach()
        # Alternatively, we could have constructed an independent embedding layer
        # for the inference net; then its parameters would be part of the set
        # \lambda and we would allow updates.
        
        # Encode the sequence using the LSTM. pack_padded_sequence expects
        # the lengths on the CPU, so we move them there.
        hidden = None        
        packed_seq = pack_padded_sequence(x_embed, seq_len.cpu(), batch_first=True)
        _, final = self.lstm(packed_seq, hidden)        

        # final = (h_n, c_n), where h_n has shape [num_directions, B, H] for a
        # single-layer LSTM. Take the final hidden state and, in the
        # bidirectional case, concatenate the forward and backward directions.
        h_T = final[0]
        if self.bidirectional:
            h_T = torch.cat([h_T[0], h_T[1]], dim=-1)  # [B, 2H]
        else:
            h_T = h_T[0]  # [B, H]
        
        # We make one more transformation:
        #  this allows a few more interactions between features,
        #  and if we have bidirectional features then the two 
        #  directions also interact
        h_T = torch.tanh(self.combination_layer(h_T))

        # Compute the logits of the K independent Bernoulli variables;
        # ProductOfBernoullis maps them to probabilities with a sigmoid.
        logits = self.logits_layer(h_T)
        
        # Return the inferred product of Bernoullis Q(z|x).
        qz = ProductOfBernoullis(logits)
        return qz

In [0]:
# tests for inference model
pad_idx = vocab.char_to_idx[PAD_TOKEN]

dummy_inference_model = InferenceModel(
    vocab_size=vocab.size(),
    embedder=nn.Embedding(vocab.size(), 64, padding_idx=pad_idx),
    hidden_size=128, latent_size=16, pad_idx=pad_idx, bidirectional=True
).to(device=device)
dummy_batch_size = 32
dummy_dataloader = SortingTextDataLoader(DataLoader(train_dataset, batch_size=dummy_batch_size))
dummy_words = next(dummy_dataloader)

x_in, _, seq_mask, seq_len = create_batch(dummy_words, vocab, device)

q_z_given_x = dummy_inference_model.forward(x_in, seq_mask, seq_len)
# TODO make some assertions

Next, we implement the generative latent factor model. The decoder is a sequence of correlated Categorical draws conditioned on a latent factor assignment: each character's distribution depends on $z$ and on all previous characters.

We will be parameterising categorical distributions, so you might want to check the documentation of `torch.distributions.categorical.Categorical`.
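As a quick reference for the shapes involved, a `Categorical` built from logits of shape [B, T, V] has batch shape [B, T], and its `log_prob` takes [B, T] token ids. The sketch below (the toy sizes and the padding id 0 are assumptions) masks padded positions and sums over time, mirroring the reconstruction term computed in the solution's `loss`, and checks the result against a manual log-softmax:

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
B, T, V = 2, 4, 5   # hypothetical batch size, length and vocabulary size
pad_idx = 0         # hypothetical padding id

logits = torch.randn(B, T, V)
px = Categorical(logits=logits)
print(px.batch_shape, px.event_shape)  # torch.Size([2, 4]) torch.Size([])

# Observed token ids, with padding at the end of the second word.
x_out = torch.tensor([[1, 2, 3, 4],
                      [2, 3, pad_idx, pad_idx]])
log_prob = px.log_prob(x_out)                      # [B, T]
mask = x_out != pad_idx
log_prob = torch.where(mask, log_prob, torch.zeros_like(log_prob))
log_px_given_z = log_prob.sum(-1)                  # [B]

# The same quantity computed by hand, as a sanity check.
manual = torch.log_softmax(logits, dim=-1).gather(-1, x_out.unsqueeze(-1)).squeeze(-1)
manual = (manual * mask).sum(-1)
print(torch.allclose(log_px_given_z, manual))
```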


In [0]:
from torch.distributions import Categorical

class LatentFactorModel(nn.Module):
    
    def __init__(self, vocab_size, emb_size, hidden_size, latent_size,
                 pad_idx, dropout=0.):
        """
        :param vocab_size: size of the vocabulary of the language
        :param emb_size: dimensionality of embeddings
        :param hidden_size: dimensionality of recurrent cell
        :param latent_size: this is D the dimensionality of the latent variable z
        :param pad_idx: the id reserved to the -PAD- token
        :param dropout: a dropout rate (you can ignore this for now)
        """
        super().__init__()
        # Construct your NN blocks here, 
        #  remember to assign them to attributes of self
        pass

    def init_hidden(self, z):
        """
        Returns the hidden state of the LSTM initialized with a projection of a given z.
        :param z: [B, K]
        :returns: [num_layers, B, H] hidden state, [num_layers, B, H] cell state            
        """
        pass
    
    def step(self, prev_x, z, hidden):
        """
        Performs a single LSTM step for a given previous word and hidden state.
        Returns the unnormalized log probabilities (logits) over the vocabulary 
        for this time step. 
        
        :param prev_x: [B, 1] id of the previous token
        :param z: [B, K] latent variable
        :param hidden:  hidden ([num_layers, B, H] state, [num_layers, B, H] cell)
        :returns: [B, 1, V] logits, ([num_layers, B, H] updated state, [num_layers, B, H] updated cell)
        """
        pass
        
    def forward(self, x, z) -> Categorical:
        """
        Performs an entire forward pass given a sequence of words x and a z.
        This returns a collection of [B, T] categorical distributions, each 
            with support over V events.

        :param x: [B, T] token ids 
        :param z: [B, K] a latent sample
        :returns: Categorical object with shape [B,T,V]
        """
        hidden = self.init_hidden(z)
        outputs = []
        for t in range(x.size(1)):
            # [B, 1]
            prev_x = x[:, t].unsqueeze(-1)
            # logits: [B, 1, V]
            logits, hidden = self.step(prev_x, z, hidden)
            outputs.append(logits)
        outputs = torch.cat(outputs, dim=1)
        return Categorical(logits=outputs)
        
    def loss(self, output_distributions, observations, pz, qz, z, free_nats=0., evaluation=False):
        """
        Computes the terms in the loss (negative ELBO) given the output
        Categorical distributions, the observations, the prior distribution p(z),
        the approximate posterior distribution q(z|x), and the sample z ~ q(z|x)
        that was used to compute the output distributions.
        
        If free_nats is nonzero, the KL divergence between the posterior and the
        prior is clamped to be at least that value, which blocks gradient
        propagation through the KL while it is below that value.
        
        If evaluation is set to True, the loss will be summed instead
        of averaged over the batch. 
        
        Returns the (surrogate) loss, the ELBO, and the KL. 
        
        :returns: 
            surrogate loss (scalar),
            ELBO (scalar), 
            KL (scalar)
        """
        pass

In [0]:
# SOLUTION
class LatentFactorModel(nn.Module):
    
    def __init__(self, vocab_size, emb_size, hidden_size, latent_size,
                 pad_idx, dropout=0.):
        """
        :param vocab_size: size of the vocabulary of the language
        :param emb_size: dimensionality of embeddings
        :param hidden_size: dimensionality of recurrent cell
        :param latent_size: this is D the dimensionality of the latent variable z
        :param pad_idx: the id reserved to the -PAD- token
        :param dropout: a dropout rate
        """
        super().__init__()
        self.pad_idx = pad_idx
        self.embedder = nn.Embedding(vocab_size, emb_size,
                                     padding_idx=pad_idx)
        self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True)
        self.bridge = nn.Linear(latent_size, hidden_size)
        self.projection = nn.Linear(hidden_size, vocab_size, bias=False)
        self.dropout_layer = nn.Dropout(p=dropout)
    
    def init_hidden(self, z):
        """
        Returns the hidden state of the LSTM initialized with a projection of a given z.
        :param z: [B, K]
        :returns: [num_layers, B, H] hidden state, [num_layers, B, H] cell state
            
        """
        h = self.bridge(z).unsqueeze(0)
        c = self.bridge(z).unsqueeze(0)
        return (h, c)
    
    def step(self, prev_x, z, hidden):
        """
        Performs a single LSTM step for a given previous word and hidden state.
        Returns the unnormalized log probabilities (logits) over the vocabulary for this time step. 
        :param prev_x: [B, 1] id of the previous token
        :param z: [B, K] latent variable
        :param hidden:  hidden ([num_layers, B, H] state, [num_layers, B, H] cell)

        :returns: [B, 1, V] logits, ([num_layers, B, H] updated state, [num_layers, B, H] updated cell)
        """
        # [B, 1, E]
        x_embed = self.dropout_layer(self.embedder(prev_x))
        # output: [B, 1, H]
        output, hidden = self.lstm(x_embed, hidden)
        # [B, 1, V]
        logits = self.projection(self.dropout_layer(output))
        return logits, hidden
        
    def forward(self, x, z) -> Categorical:
        """
        Performs an entire forward pass given a sequence of words x and a z.
        This returns a collection of [B, T] categorical distributions, each 
            with support over V events.

        :param x: [B, T] token ids 
        :param z: [B, K] a latent sample
        :returns: Categorical object with shape [B,T,V]
        """
        hidden = self.init_hidden(z)
        outputs = []
        for t in range(x.size(1)):
            # [B, 1]
            prev_x = x[:, t].unsqueeze(-1)
            # logits: [B, 1, V]
            logits, hidden = self.step(prev_x, z, hidden)
            outputs.append(logits)
        # [B, T, V]
        outputs = torch.cat(outputs, dim=1)
        return Categorical(logits=outputs)
        
    def loss(self, output_distributions, observations, pz, qz, z, free_nats=0., evaluation=False):
        """
        Computes the terms in the loss (negative ELBO) given the output
        Categorical distributions, the observations, the prior distribution p(z),
        the approximate posterior distribution q(z|x), and the sample z ~ q(z|x)
        that was used to compute the output distributions.
        
        If free_nats is nonzero, the KL divergence between the posterior and the
        prior is clamped to be at least that value, which blocks gradient
        propagation through the KL while it is below that value.
        
        If evaluation is set to True, the loss will be summed instead
        of averaged over the batch. 
        
        Returns the (surrogate) loss, the ELBO, and the KL. 
        
        :returns: 
            surrogate loss (scalar),
            ELBO (scalar), 
            KL (scalar)
        """
        
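        # [B, T]
        log_prob = output_distributions.log_prob(observations)
        mask = (observations != self.pad_idx)
        log_prob = torch.where(mask, log_prob, torch.zeros_like(log_prob))
        # [B]
        log_prob = log_prob.sum(-1)

        # Compute the KL divergence and clamp it to at least the given amount of free nats.
        # [B]
        KL = qz.kl(pz)
        KL = torch.clamp(KL, min=free_nats)

        # Score-function (REINFORCE) surrogate: z is binary, so we cannot
        # reparameterise. The gradient of this term w.r.t. the inference model
        # parameters is log P(x|z) * grad log q(z|x).
        # [B]
        surrogate = log_prob.detach() * qz.log_prob(z)
        # Negative surrogate ELBO. The KL term must be included so that it is
        # also optimised (and so that free_nats has an effect).
        loss = -(log_prob + surrogate - KL)  # TODO baselines

        # Compute an ELBO estimate
        ELBO = (log_prob - KL)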
        
        # For evaluation return the sum of individual components, for
        # training return the mean of those components.
        if evaluation:
            return (loss.sum(), ELBO.sum(), KL.sum())
        else:
            return (loss.mean(), ELBO.mean(), KL.mean())
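
Because $z$ is binary, the reparameterisation trick does not apply, and the solution's `loss` relies on the score-function (REINFORCE) surrogate `log_prob.detach() * qz.log_prob(z)`. The toy sketch below checks that the gradient of such a surrogate matches the exact gradient of $\mathbb{E}_{q}[f(z)]$, computed by enumerating all $2^K$ binary vectors. The function `f` is a hypothetical stand-in for $\log P(x|z)$, and `torch.distributions.Bernoulli` stands in for the notebook's `ProductOfBernoullis`:

```python
import itertools
import torch
from torch.distributions import Bernoulli

torch.manual_seed(0)
K = 3
logits = torch.tensor([0.3, -1.0, 0.5], requires_grad=True)

def f(z):
    # Hypothetical stand-in for log P(x|z): any function of a binary vector.
    return (z * torch.tensor([1.0, -2.0, 0.5])).sum(-1) + z[..., 0] * z[..., 2]

# Exact gradient of E_q[f(z)] by enumerating all 2^K binary vectors.
all_z = torch.tensor(list(itertools.product([0., 1.], repeat=K)))   # [2^K, K]
q = Bernoulli(logits=logits)
exact = (q.log_prob(all_z).sum(-1).exp() * f(all_z)).sum()
(exact_grad,) = torch.autograd.grad(exact, logits)

# Score-function estimate: f(z) is treated as a constant (detached) and
# multiplies log q(z), as in the surrogate term of the loss.
S = 200_000
q = Bernoulli(logits=logits)
z = q.sample((S,))                      # [S, K]; no gradient through sampling
surrogate = (f(z).detach() * q.log_prob(z).sum(-1)).mean()
(sf_grad,) = torch.autograd.grad(surrogate, logits)

print(exact_grad, sf_grad)
```

The estimate is unbiased but noisy, which is why many samples are needed here and why the `TODO baselines` note in the loss matters: subtracting a baseline from `log_prob` is a standard way to reduce this variance.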

The code below is used to assess the model and to investigate what it has learned. We implemented it for you so that you can focus on the VAE part, but it is still worth studying: it computes perplexity and samples novel words!

# Evaluation metrics

During training we'd like to track some evaluation metrics on the validation data, both to monitor how our model is doing and to perform early stopping. One simple metric is the ELBO on all the validation or test data, estimated using a single sample from the approximate posterior $Q(z|x, \lambda)$:
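The ELBO estimate combines a single-sample reconstruction term with the KL term returned by `qz.kl(pz)`. For a product of independent Bernoullis the KL has a closed form: the sum over the $K$ dimensions of per-bit Bernoulli KLs. The sketch below writes it out in terms of logits and checks it against `torch.distributions.kl_divergence`. The helper name `kl_product_of_bernoullis` is hypothetical, and the notebook's own `ProductOfBernoullis.kl`, defined earlier, may be implemented differently:

```python
import math
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli, kl_divergence

def kl_product_of_bernoullis(q_logits, p_logits):
    """KL(q || p) for factorised Bernoullis, summed over the last (latent) dim."""
    q_probs = torch.sigmoid(q_logits)
    # Per-bit KL: q log(q/p) + (1-q) log((1-q)/(1-p)), using
    # log sigmoid(x) and log(1 - sigmoid(x)) = log sigmoid(-x) for stability.
    kl = (q_probs * (F.logsigmoid(q_logits) - F.logsigmoid(p_logits))
          + (1 - q_probs) * (F.logsigmoid(-q_logits) - F.logsigmoid(-p_logits)))
    return kl.sum(-1)

torch.manual_seed(0)
q_logits = torch.randn(8, 16)   # hypothetical posterior logits, B = 8, K = 16
p_logits = torch.zeros(8, 16)   # uniform prior: every bit is 1 with probability 0.5

kl = kl_product_of_bernoullis(q_logits, p_logits)        # [B]
reference = kl_divergence(Bernoulli(logits=q_logits),
                          Bernoulli(logits=p_logits)).sum(-1)
print(torch.allclose(kl, reference, atol=1e-5))
# With a uniform prior, KL = K log 2 - H(q), so it lies in [0, K log 2].
print(bool((kl >= 0).all()), bool((kl <= 16 * math.log(2) + 1e-5).all()))
```

With $K = 16$ the upper bound $K \log 2$ is about 11.09 nats, reached only when $q$ puts all its mass on a single binary code.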

In [0]:
def eval_elbo(model, inference_model, eval_dataset, vocab, device, batch_size=128):
    """
    Computes a single sample estimate of the ELBO on a given dataset.
    This returns both the average ELBO and the average KL (for inspection).
    """
    dl = DataLoader(eval_dataset, batch_size=batch_size)
    sorted_dl = SortingTextDataLoader(dl)
    
    # Make sure the model is in evaluation mode (i.e. disable dropout).
    model.eval()
            
    total_ELBO = 0.
    total_KL = 0.
    num_words = 0
        
    # We don't need to compute gradients for this.
    with torch.no_grad():
        for words in sorted_dl:    
            x_in, x_out, seq_mask, seq_len = create_batch(words, vocab, device)
            
            # Infer the approximate posterior and construct the prior.
            qz = inference_model(x_in, seq_mask, seq_len)
            pz = ProductOfBernoullis(torch.ones_like(qz.probs) * 0.5)
            
            # Draw a single sample from the approximate posterior.
            z = qz.sample()
            # Compute distributions X_i|z, x_{<i}
            px_z = model(x_in, z)
            
            # Compute the ELBO and the KL divergence, summed over the batch.
            loss, ELBO, KL = model.loss(px_z, x_out, pz, qz, z,
                                        free_nats=0.,
                                        evaluation=True)
            total_ELBO += ELBO
            total_KL += KL
            num_words += x_in.size(0)

    # Return the average ELBO and KL per word.
    avg_ELBO = total_ELBO / num_words
    avg_KL = total_KL / num_words
    return avg_ELBO, avg_KL

In [0]:
dummy_lm = LatentFactorModel(
    vocab.size(), emb_size=64, hidden_size=128, 
    latent_size=16, pad_idx=pad_idx).to(device=device)

!head -n 128 {val_file} > ./dummy_dataset
dummy_data = TextDataset('./dummy_dataset')
dummy_ELBO, dummy_kl = eval_elbo(dummy_lm, dummy_inference_model,
                                     dummy_data, vocab, device)
print(dummy_ELBO, dummy_kl)
assert dummy_kl.item() > 0

tensor(-37.6747, device='cuda:0') tensor(0.5302, device='cuda:0')



A common metric to evaluate language models is perplexity. Normalised by the total number of tokens (here, characters) in the dataset, the perplexity of a dataset is defined as:

\begin{align}
    \text{ppl}(\mathcal{D}|\theta) = \exp\left(-\frac{1}{\sum_{k=1}^{|\mathcal D|} n^{(k)}} \sum_{k=1}^{|\mathcal{D}|} \log P(x^{(k)}|\theta)\right) 
\end{align}

where $n^{(k)} = |x^{(k)}|$ is the number of tokens in the $k$-th word and $P(x^{(k)}|\theta)$ is the probability that our model assigns to the datapoint $x^{(k)}$. In order to compute $\log P(x|\theta)$ for our model we need to evaluate the marginal:

\begin{align}
    P(x|\theta) = \sum_{z \in \{0, 1\}^K} P(x|z,\theta) P(z|\alpha)
\end{align}

Because this sum has $2^K$ terms, it quickly becomes too expensive to compute as $K$ grows. We therefore have two options: we can use the lower bound on the log-likelihood derived earlier, which gives us an upper bound on the perplexity, or we can make an importance sampling estimate using our approximate posterior distribution. The importance sampling (IS) estimate is:

\begin{align}
\hat P(x|\theta, \lambda) &\overset{\text{IS}}{\approx} \frac{1}{S} \sum_{s=1}^{S} \frac{P(z^{(s)}|\alpha)P(x|z^{(s)}, \theta)}{Q(z^{(s)}|x)} & \text{where }z^{(s)} \sim Q(z|x)
\end{align}

where $S$ is the number of samples.
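For a small number of latent bits the marginal can still be enumerated, which makes it possible to sanity-check the importance sampling estimator. In the toy sketch below, $K = 4$, the linear log-likelihood `log_px_given_z` and the proposal logits are made up for illustration, and `torch.distributions.Bernoulli` plays the role of both $P(z|\alpha)$ and $Q(z|x)$. The exact log-marginal, its closed form, and the IS estimate should agree closely:

```python
import itertools
import math
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli

torch.manual_seed(0)
K = 4

# Hypothetical toy log-likelihood log P(x|z) for one fixed x.
w = torch.tensor([1.5, -0.5, 2.0, -1.0])
def log_px_given_z(z):
    return (z * w).sum(-1) - 3.0

prior = Bernoulli(probs=torch.full((K,), 0.5))                      # P(z|alpha)
posterior = Bernoulli(logits=torch.tensor([1.0, -0.5, 1.5, -1.0]))  # toy Q(z|x)

# Exact log-marginal: log-sum-exp over all 2^K binary vectors.
all_z = torch.tensor(list(itertools.product([0., 1.], repeat=K)))
exact = torch.logsumexp(log_px_given_z(all_z) + prior.log_prob(all_z).sum(-1), dim=0)
# For this toy likelihood the sum factorises over the bits.
closed_form = -3.0 - K * math.log(2) + F.softplus(w).sum()

# IS estimate in log space: log-mean-exp of the log importance weights.
S = 10_000
z = posterior.sample((S,))
log_w = log_px_given_z(z) + prior.log_prob(z).sum(-1) - posterior.log_prob(z).sum(-1)
is_estimate = torch.logsumexp(log_w, dim=0) - math.log(S)

print(exact.item(), closed_form.item(), is_estimate.item())
```

Working in log space with `logsumexp(...) - log S`, as `eval_perplexity` does, avoids underflow when $P(x|z)$ is a product of many small character probabilities.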

Then our perplexity estimate becomes:
\begin{align}
    &\exp\left(-\frac{1}{\sum_{k=1}^{|\mathcal D|} n^{(k)}}  \sum_{k=1}^{|\mathcal D|} \log P(x^{(k)}|\theta)\right) \\
    &\approx \exp\left(-\frac{1}{\sum_{k=1}^{|\mathcal D|} n^{(k)}}  \sum_{k=1}^{|\mathcal D|} \log \frac{1}{S} \sum_{s=1}^{S} \frac{P(z^{(s)}|\alpha)P(x^{(k)}|z^{(s)}, \theta)}{Q(z^{(s)}|x^{(k)})}\right)
\end{align}
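To make the normalisation concrete: perplexity divides the total log-probability by the total number of tokens, whereas the NLL returned by `eval_perplexity` divides by the number of words. A small arithmetic sketch with made-up log-probabilities (the helper names are hypothetical):

```python
import math

def perplexity(log_probs, lengths):
    """Per-token perplexity from per-word log-probabilities and token counts."""
    return math.exp(-sum(log_probs) / sum(lengths))

def nll_per_word(log_probs):
    """Average negative log-likelihood per word."""
    return -sum(log_probs) / len(log_probs)

# Two hypothetical words: log P(x) = -20 with 10 tokens, -10 with 5 tokens.
print(perplexity([-20.0, -10.0], [10, 5]))   # exp(30 / 15) = exp(2) ≈ 7.389
print(nll_per_word([-20.0, -10.0]))          # 15.0
```

The two numbers are linked: the log-perplexity times the average number of tokens per word equals the per-word NLL (here 2 × 7.5 = 15). Plugging ELBO values in place of $\log P(x)$ can only make the perplexity larger, which is the upper bound mentioned above.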

We define the function `eval_perplexity` below that implements this importance sampling estimate:

In [0]:
def eval_perplexity(model, inference_model, eval_dataset, vocab, device, 
                    n_samples, batch_size=128):
    """
    Estimates the per-token (character) perplexity using importance sampling
    with the given number of samples. Also returns the average NLL per word.
    """
    
    dl = DataLoader(eval_dataset, batch_size=batch_size)
    sorted_dl = SortingTextDataLoader(dl)
    
    # Make sure the model is in evaluation mode (i.e. disable dropout).
    model.eval()
    
    log_px = 0.
    num_predictions = 0
    num_words = 0
     
    # We don't need to compute gradients for this.
    with torch.no_grad():
        for words in sorted_dl:
            x_in, x_out, seq_mask, seq_len = create_batch(words, vocab, device)
            
            # Infer the approximate posterior and construct the prior.
            qz = inference_model(x_in, seq_mask, seq_len)
            pz = ProductOfBernoullis(torch.ones_like(qz.probs) * 0.5) # TODO different prior

            # Create an array to hold all samples for this batch.
            batch_size = x_in.size(0)
            log_px_samples = torch.zeros(n_samples, batch_size)
            
            # Sample log P(x) n_samples times.
            for s in range(n_samples):
                
                # Sample a z^s from the posterior.
                z = qz.sample()
                
                # Compute log P(x^k|z^s)
                px_z = model(x_in, z)
                # [B, T]
                cond_log_prob = px_z.log_prob(x_out)                
                cond_log_prob = torch.where(seq_mask, cond_log_prob, torch.zeros_like(cond_log_prob))
                # [B]
                cond_log_prob = cond_log_prob.sum(-1)
                
                # Compute log p(z^s) and log q(z^s|x^k)
                prior_log_prob = pz.log_pmf(z) # B
                posterior_log_prob = qz.log_pmf(z) # B
                
                # Store the sample for log P(x^k) importance weighted with p(z^s)/q(z^s|x^k).
                log_px_sample = cond_log_prob + prior_log_prob - posterior_log_prob
                log_px_samples[s] = log_px_sample
                
            # Average over the number of samples and count the number of predictions made this batch.
            log_px_batch = torch.logsumexp(log_px_samples, dim=0) - \
                    torch.log(torch.Tensor([n_samples]))
            log_px += log_px_batch.sum()
            num_predictions += seq_len.sum()
            num_words += seq_len.size(0)

    # Compute and return the perplexity per word.
    perplexity = torch.exp(-log_px / num_predictions)
    NLL = -log_px / num_words
    return perplexity, NLL

Lastly, we want to occasionally inspect the model qualitatively during training by letting it reconstruct a given word from the latent space. This gives us an idea of whether the model uses the latent space to encode information about the data. For this we use a deterministic greedy decoding algorithm that chooses the character with maximum probability at every time step and feeds that character into the next time step.
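
The decoding loop can be illustrated without a trained model. The sketch below replaces `model.step` with a hypothetical bigram table over a four-symbol toy vocabulary, so the logits depend only on the previous id; the argmax and feed-back structure is the same as in `greedy_decode`:

```python
import torch

# Hypothetical toy vocabulary: 0 = start-of-word, 1 = 'a', 2 = 'b', 3 = end-of-word.
SOW, EOW = 0, 3
# Toy "decoder": row i holds the next-token logits after token i.
bigram_logits = torch.tensor([
    [-9.0, 2.0, 1.0, 0.0],   # after SOW: 'a' is most likely
    [-9.0, 0.5, 3.0, 0.0],   # after 'a': 'b' is most likely
    [-9.0, 1.0, 0.0, 4.0],   # after 'b': EOW is most likely
    [-9.0, 0.0, 0.0, 5.0],   # after EOW: stay at EOW
])

def toy_greedy_decode(batch_size, max_len=10):
    prev_x = torch.full((batch_size, 1), SOW, dtype=torch.long)
    predictions = []
    for _ in range(max_len):
        logits = bigram_logits[prev_x]              # [B, 1, V]
        prediction = torch.argmax(logits, dim=-1)   # [B, 1]
        predictions.append(prediction)
        prev_x = prediction                         # feed the prediction back in
    return torch.cat(predictions, dim=1)            # [B, max_len]

print(toy_greedy_decode(batch_size=2, max_len=5))
```

Like `greedy_decode`, this always runs for `max_len` steps; everything after the first end-of-word id has to be discarded when converting ids back to characters.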

In [0]:
def greedy_decode(model, z, vocab, max_len=50):
    """
    Greedily decodes a word from a given z by picking the character with
    maximum probability at each time step.
    """
    
    # Disable dropout.
    model.eval()
    
    # Don't compute gradients.
    with torch.no_grad():
        batch_size = z.size(0)
        
        # We feed the model the start-of-word symbol at the first time step.
        prev_x = torch.full((batch_size, 1), vocab[SOW_TOKEN],
                            dtype=torch.long, device=z.device)
        
        # Initialize the hidden state from z.
        hidden = model.init_hidden(z)

        predictions = []    
        for t in range(max_len):
            logits, hidden = model.step(prev_x, z, hidden)
            
            # Choose the argmax of the unnormalized log probabilities as the
            # prediction for this time step.
            prediction = torch.argmax(logits, dim=-1)
            predictions.append(prediction)
            
            prev_x = prediction.view(batch_size, 1)
            
        return torch.cat(predictions, dim=1)

# Training

Now it's time to train the model. We use early stopping on the validation perplexity for model selection.
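Note that the training loop below writes only `model.state_dict()` to `best_model`, so the inference model used later for test evaluation is the one from the final epoch rather than the best one. A minimal sketch of keeping both networks in sync, assuming an in-memory snapshot is acceptable (tiny `nn.Linear` layers stand in for the two models, and the validation perplexities are made up):

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for the generative and inference networks.
model = nn.Linear(4, 4)
inference_model = nn.Linear(4, 2)
init_weight = model.weight.detach().clone()

best_val_ppl, best_epoch, best_state = float("inf"), 0, None
fake_val_ppls = [9.1, 7.9, 8.3]   # made-up numbers, for illustration only
for epoch, val_ppl in enumerate(fake_val_ppls, start=1):
    with torch.no_grad():          # stand-in for one epoch of training
        model.weight.add_(1.0)
        inference_model.weight.add_(1.0)
    if val_ppl < best_val_ppl:
        best_val_ppl, best_epoch = val_ppl, epoch
        # state_dict() shares storage with the live parameters, so deep-copy
        # it to stop later updates from overwriting the snapshot.
        best_state = copy.deepcopy({"model": model.state_dict(),
                                    "inference_model": inference_model.state_dict()})

# Restore both networks from the best epoch.
model.load_state_dict(best_state["model"])
inference_model.load_state_dict(best_state["inference_model"])
print(best_epoch)  # 2
```

Passing the same dictionary to `torch.save` would persist both networks to disk in one file instead.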

In [0]:
# Define the model hyperparameters.
emb_size = 256
hidden_size = 256 
latent_size = 16
bidirectional_encoder = True
free_nats = 0 # 5.
annealing_steps = 0 # 11400
dropout = 0.6
word_dropout = 0 # 0.75
batch_size = 64
learning_rate = 0.001
num_epochs = 20
n_importance_samples = 3 # 50

# Create the training data loader.
dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
sorted_dl = SortingTextDataLoader(dl)

# Create the generative model.
model = LatentFactorModel(vocab_size=vocab.size(), 
                 emb_size=emb_size, 
                 hidden_size=hidden_size, 
                 latent_size=latent_size, 
                 pad_idx=vocab[PAD_TOKEN],
                 dropout=dropout)
model = model.to(device)

# Create the inference model.
inference_model = InferenceModel(vocab_size=vocab.size(),
                                 embedder=model.embedder,
                                 hidden_size=hidden_size,
                                 latent_size=latent_size,
                                 pad_idx=vocab[PAD_TOKEN],
                                 bidirectional=bidirectional_encoder)
inference_model = inference_model.to(device)

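# Create the optimizer. The inference model reuses the generative model's
# embedding layer; collecting the parameters through an nn.ModuleList makes
# parameters() yield that shared layer only once (PyTorch warns about
# duplicate parameters within a parameter group).
optimizer = optim.Adam(nn.ModuleList([model, inference_model]).parameters(),
                       lr=learning_rate)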

# Save the best model (early stopping).
best_model = "./best_model.pt"
best_val_ppl = float("inf")
best_epoch = 0

# Keep track of some statistics to plot later.
train_ELBOs = []
train_KLs = []
val_ELBOs = []
val_KLs = []
val_perplexities = []
val_NLLs = []

step = 0
training_ELBO = 0.
training_KL = 0.
num_batches = 0
for epoch_num in range(1, num_epochs+1):    
    for words in sorted_dl:

        # Make sure the model is in training mode (for dropout).
        model.train()

        # Transform the words to input, output, seq_len, seq_mask batches.
        x_in, x_out, seq_mask, seq_len = create_batch(words, vocab, device,
                                                      word_dropout=word_dropout)

        # Compute the multiplier for the KL term if we do annealing. Note that
        # KL_weight is only reported below; it is not passed to model.loss.
        if annealing_steps > 0:
            KL_weight = min(1., (1.0 / annealing_steps) * step)
        else:
            KL_weight = 1.
        
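        # Do a forward pass through the model and compute the training loss. We use
        # a single (non-reparameterised) sample of the binary z from the approximate
        # posterior; the inference model is trained through the score-function
        # surrogate term in the loss.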
        qz = inference_model(x_in, seq_mask, seq_len)
        pz = ProductOfBernoullis(torch.ones_like(qz.probs) * 0.5)
        z = qz.sample()
        px_z = model(x_in, z)
        loss, ELBO, KL = model.loss(px_z, x_out, pz, qz, z, free_nats=free_nats)        

        # Backpropagate and update the model weights.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Update some statistics to track for the training loss.
        training_ELBO += ELBO
        training_KL += KL
        num_batches += 1
        
        # Every 100 steps we evaluate the model and report progress.
        if step % 100 == 0:
            val_ELBO, val_KL = eval_elbo(model, inference_model, val_dataset, vocab, device)            
            print("(%d) step %d: training ELBO (KL) = %.2f (%.2f) --"
                  " KL weight = %.2f --"
                  " validation ELBO (KL) = %.2f (%.2f)" % 
                  (epoch_num, step, training_ELBO/num_batches, 
                   training_KL/num_batches, KL_weight, val_ELBO, val_KL))
            
            # Update some statistics for plotting later.
            train_ELBOs.append((step, (training_ELBO/num_batches).item()))
            train_KLs.append((step, (training_KL/num_batches).item()))
            val_ELBOs.append((step, val_ELBO.item()))
            val_KLs.append((step, val_KL.item()))
            
            # Reset the training statistics.
            training_ELBO = 0.
            training_KL = 0.
            num_batches = 0
            
        step += 1

    # After an epoch we'll compute validation perplexity and save the model
    # for early stopping if it's better than previous models.
    print("Finished epoch %d" % (epoch_num))
    val_perplexity, val_NLL = eval_perplexity(model, inference_model, val_dataset, vocab, device, 
                                              n_importance_samples)
    val_ELBO, val_KL = eval_elbo(model, inference_model, val_dataset, vocab, device)    
    
    # Keep track of the validation perplexities / NLL.
    val_perplexities.append((epoch_num, val_perplexity.item()))
    val_NLLs.append((epoch_num, val_NLL.item()))
    
    # If validation perplexity is better, store this model for early stopping.
    if val_perplexity < best_val_ppl:
        best_val_ppl = val_perplexity
        best_epoch = epoch_num
        torch.save(model.state_dict(), best_model)
        
    # Print epoch statistics.
    print("Evaluation epoch %d:\n"
          " - validation perplexity: %.2f\n"
          " - validation NLL: %.2f\n"
          " - validation ELBO (KL) = %.2f (%.2f)"
          % (epoch_num, val_perplexity, val_NLL, val_ELBO, val_KL))

    # Also show some qualitative results by reconstructing a word from the
    # validation data. Use the mean of the approximate posterior and greedy
    # decoding.
    random_word = val_dataset[np.random.choice(len(val_dataset))]
    x_in, _, seq_mask, seq_len = create_batch([random_word], vocab, device)
    qz = inference_model(x_in, seq_mask, seq_len)
    z = qz.mean()
    reconstruction = greedy_decode(model, z, vocab)
    reconstruction = batch_to_words(reconstruction, vocab)[0]
    print("-- Original word: \"%s\"" % random_word)
    print("-- Model reconstruction: \"%s\"" % reconstruction)

(1) step 0: training ELBO (KL) = -39.02 (0.43) -- KL weight = 1.00 -- validation ELBO (KL) = -38.29 (0.43)
(1) step 100: training ELBO (KL) = -27.68 (1.20) -- KL weight = 1.00 -- validation ELBO (KL) = -23.76 (1.28)
Finished epoch 1
Evaluation epoch 1:
 - validation perplexity: 7.88
 - validation NLL: 21.97
 - validation ELBO (KL) = -22.52 (1.25)
-- Original word: "interpretarían"
-- Model reconstruction: "acontaren"
(2) step 200: training ELBO (KL) = -24.03 (1.33) -- KL weight = 1.00 -- validation ELBO (KL) = -22.47 (1.23)
(2) step 300: training ELBO (KL) = -23.19 (1.33) -- KL weight = 1.00 -- validation ELBO (KL) = -22.19 (1.47)
Finished epoch 2
Evaluation epoch 2:
 - validation perplexity: 7.41
 - validation NLL: 21.32
 - validation ELBO (KL) = -21.99 (1.57)
-- Original word: "subtítulos"
-- Model reconstruction: "acarrarían"
(3) step 400: training ELBO (KL) = -23.07 (1.66) -- KL weight = 1.00 -- validation ELBO (KL) = -22.02 (1.65)
(3) step 500: training ELBO (KL) = -23.00 (1.85) -

Let's plot the training and validation statistics:

In [0]:
steps, training_ELBO = list(zip(*train_ELBOs))
_, training_KL = list(zip(*train_KLs))
_, val_ELBO = list(zip(*val_ELBOs))
_, val_KL = list(zip(*val_KLs))
epochs, val_ppl = list(zip(*val_perplexities))
_, val_NLL = list(zip(*val_NLLs))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5))

# Plot training ELBO and KL
ax1.set_title("Training ELBO")
ax1.plot(steps, training_ELBO, "-o")
ax2.set_title("Training KL")
ax2.plot(steps, training_KL, "-o")
plt.show()

# Plot validation ELBO and KL
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5))
ax1.set_title("Validation ELBO")
ax1.plot(steps, val_ELBO, "-o", color="orange")
ax2.set_title("Validation KL")
ax2.plot(steps, val_KL, "-o",  color="orange")
plt.show()

# Plot validation perplexities.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5))
ax1.set_title("Validation perplexity")
ax1.plot(epochs, val_ppl, "-o", color="orange")
ax2.set_title("Validation NLL")
ax2.plot(epochs, val_NLL, "-o",  color="orange")
plt.show()
print()

Let's load the best model according to validation perplexity and compute its perplexity on the test data:

In [0]:
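# Load the best generative model from disk. Note that only the generative model
# was checkpointed, so the inference model used below is the one from the final epoch.
model = LatentFactorModel(vocab_size=vocab.size(), 
                 emb_size=emb_size, 
                 hidden_size=hidden_size, 
                 latent_size=latent_size, 
                 pad_idx=vocab[PAD_TOKEN],
                 dropout=dropout)
model.load_state_dict(torch.load(best_model, map_location=device))
model = model.to(device)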

# Compute test perplexity and ELBO.
test_perplexity, test_NLL = eval_perplexity(model, inference_model, test_dataset, vocab, 
                                            device, n_importance_samples)
test_ELBO, test_KL = eval_elbo(model, inference_model, test_dataset, vocab, device)
print("test ELBO (KL) = %.2f (%.2f) -- test perplexity = %.2f -- test NLL = %.2f" % 
      (test_ELBO, test_KL, test_perplexity, test_NLL))

test ELBO (KL) = -25.34 (5.46) -- test perplexity = 9.56 -- test NLL = 24.05


# Qualitative analysis

Let's have a look at how our trained model interacts with the learned latent space. First, let's greedily decode some samples from the prior to assess the diversity of the model:
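Sampling from the prior is straightforward because all $K$ bits are independent and uniform, so every binary code has probability $2^{-K}$. The sketch uses `torch.distributions.Bernoulli` as a stand-in for the notebook's `ProductOfBernoullis` (an assumption; the two classes may differ in their interface):

```python
import math
import torch
from torch.distributions import Bernoulli

torch.manual_seed(0)
num_prior_samples, latent_size = 10, 16

# Uniform prior over {0, 1}^K: every bit is 1 with probability 0.5.
prior = Bernoulli(probs=torch.full((num_prior_samples, latent_size), 0.5))
z = prior.sample()
print(z.shape)       # torch.Size([10, 16])
print(z.unique())    # only 0s and 1s

# The same prior parameterised by logits: sigmoid(0) = 0.5.
same_prior = Bernoulli(logits=torch.zeros(num_prior_samples, latent_size))
print(torch.allclose(prior.probs, same_prior.probs))

# Each sampled code has log-probability -K log 2 under this prior.
print(prior.log_prob(z).sum(-1)[0].item(), -latent_size * math.log(2))
```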

In [0]:
# Generate 10 samples from the uniform product-of-Bernoullis prior.
num_prior_samples = 10
pz = ProductOfBernoullis(torch.ones(num_prior_samples, latent_size) * 0.5)
z = pz.sample()
z = z.to(device)

# Use the greedy decoding algorithm to generate words.
predictions = greedy_decode(model, z, vocab)
predictions = batch_to_words(predictions, vocab)
for num, prediction in enumerate(predictions):
    print("%d: %s" % (num+1, prediction))

Let's now look at how well the model reconstructs words from the test dataset, using the approximate posterior mean and a few posterior samples:
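Note that the posterior mean of a product of Bernoullis is the vector of probabilities $\sigma(\text{logits})$, which is generally not a binary code, whereas the decoder only sees binary samples during training. The sketch below, again using `torch.distributions.Bernoulli` as a stand-in (there `mean` is a property, while the notebook's `ProductOfBernoullis` exposes a `mean()` method), contrasts the mean, the most probable binary code, and a few samples for hypothetical logits:

```python
import torch
from torch.distributions import Bernoulli

torch.manual_seed(0)
logits = torch.tensor([[2.0, -1.0, 0.0, 3.0]])   # hypothetical q(z|x) logits, K = 4
qz = Bernoulli(logits=logits)

mean = qz.mean                    # sigmoid(logits): values strictly between 0 and 1
mode = (logits > 0).float()       # most probable binary code (ties broken towards 0)
samples = qz.sample((3,))         # [3, 1, K] binary codes

print(mean)
print(mode)                       # tensor([[1., 0., 0., 1.]])
print(samples.squeeze(1))
```

Decoding from the thresholded code is one possible alternative to decoding from the mean; whether it gives better reconstructions for this model has not been tested here.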

In [0]:
# Pick a random test word.
test_word = test_dataset[np.random.choice(len(test_dataset))]

# Infer q(z|x).
x_in, _, seq_mask, seq_len = create_batch([test_word], vocab, device)
qz = inference_model(x_in, seq_mask, seq_len)

# Decode using the mean.
z_mean = qz.mean()
mean_reconstruction = greedy_decode(model, z_mean, vocab)
mean_reconstruction = batch_to_words(mean_reconstruction, vocab)[0]

print("Original: \"%s\"" % test_word)
print("Posterior mean reconstruction: \"%s\"" % mean_reconstruction)

# Decode a couple of samples from the approximate posterior.
for s in range(3):
    z = qz.sample()
    sample_reconstruction = greedy_decode(model, z, vocab)
    sample_reconstruction = batch_to_words(sample_reconstruction, vocab)[0]
    print("Posterior sample reconstruction (%d): \"%s\"" % (s+1, sample_reconstruction))

We can also qualitatively assess the smoothness of the learned latent space by interpolating between two words in the test set:
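The interpolation is a convex combination of the two posterior means, so every intermediate point is again a vector of values in $[0, 1]$, and the endpoints recover `qz_1` and `qz_2`. A minimal sketch with made-up means for $K = 4$:

```python
import numpy as np
import torch

# Two hypothetical posterior means, standing in for qz_1 and qz_2.
qz_1 = torch.tensor([[0.9, 0.1, 0.8, 0.2]])
qz_2 = torch.tensor([[0.2, 0.7, 0.1, 0.9]])

alphas = np.linspace(start=0., stop=1., num=5)
path = [(1 - float(alpha)) * qz_1 + float(alpha) * qz_2 for alpha in alphas]
for alpha, z in zip(alphas, path):
    print("alpha = %.2f:" % alpha, z)
```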

In [0]:
# Pick a random test word.
test_word_1 = test_dataset[np.random.choice(len(test_dataset))]

# Infer q(z|x).
x_in, _, seq_mask, seq_len = create_batch([test_word_1], vocab, device)
qz = inference_model(x_in, seq_mask, seq_len)
qz_1 = qz.mean()

# Pick a random second test word.
test_word_2 = test_dataset[np.random.choice(len(test_dataset))]

# Infer q(z|x) again.
x_in, _, seq_mask, seq_len = create_batch([test_word_2], vocab, device)
qz = inference_model(x_in, seq_mask, seq_len)
qz_2 = qz.mean()

# Now interpolate between the two means and generate words between those.
num_words = 5
print("Word 1: \"%s\"" % test_word_1)
for alpha in np.linspace(start=0., stop=1., num=num_words):
    z = (1-alpha) * qz_1 + alpha * qz_2
    reconstruction = greedy_decode(model, z, vocab)
    reconstruction = batch_to_words(reconstruction, vocab)[0]
    print("(1-%.2f) * qz1.mean + %.2f * qz2.mean: \"%s\"" % (alpha, alpha, reconstruction))
print("Word 2: \"%s\"" % test_word_2)