# Building a VAE in TensorFlow


Consider the following graphical model:

![VAE](img/vae.png)

where 
* $Z \sim \mathcal N(0, I)$ is a random embedding
* $X \sim \mathrm{Cat}(f_\theta(z))$ is a word drawn from a categorical distribution over a language's vocabulary
* $f_\theta(z) = \mathrm{softmax}(g_\theta(z))$, where $g_\theta$ is a FFNN, predicts the parameters of our CPD

Here is our joint distribution:
\begin{align}
p_\theta(x, z) &= p_\theta(z) P_\theta(x|z) \\
 &= \mathcal N(z \mid 0, I) P_\theta(x|z)
\end{align}
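
The generative story above can be sketched as ancestral sampling: first draw $z$ from the prior, then draw $x$ from the categorical distribution that $f_\theta(z)$ parameterises. This is a minimal NumPy sketch, not the notebook's TensorFlow model; the layer sizes, the `tanh` hidden layer, and the random weights standing in for $\theta$ are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not taken from the notebook).
d, hidden, vocab_size = 4, 8, 10

# Randomly initialised stand-ins for the decoder parameters theta.
W1 = rng.normal(scale=0.1, size=(d, hidden))
W2 = rng.normal(scale=0.1, size=(hidden, vocab_size))


def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def f_theta(z):
    # g_theta is modelled here as a one-hidden-layer FFNN; f_theta adds the softmax.
    return softmax(np.tanh(z @ W1) @ W2)


z = rng.standard_normal(d)           # z ~ N(0, I)
probs = f_theta(z)                   # parameters of Cat(f_theta(z))
x = rng.choice(vocab_size, p=probs)  # x ~ Cat(f_theta(z)), a word id
```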

Note that the marginal likelihood is intractable

\begin{align}
    P_\theta(x) &=  \int p_\theta(z,x) \mathrm{d}z \\
    &= \int p_\theta(z)P_\theta(x|z) \mathrm{d}z
\end{align}

because it requires marginalising over all possible random embeddings, each of which passes through the non-linear $f_\theta$. This makes the posterior $p_\theta(z|x) = p_\theta(x, z) / P_\theta(x)$ intractable too.

## Training

We will use variational inference to circumvent the intractable marginalisation, where we propose a variational approximation (a.k.a. *inference network*) $q_\phi(z|x)$ with its own parameters $\phi$.
Since $Z$ is Gaussian-distributed, we choose $q_\phi(z|x) = \mathcal N(\mu_\phi(x), \sigma^2_\phi(x))$, where

* $\mu_\phi(x) = u_\phi(x)$
* $\sigma^2_\phi(x) = \exp(s_\phi(x))$

are FFNNs that locally predict an approximation to the true posterior mean and variance for each observation $x$.

Our variational auto-encoder then boils down to:

* an *inference network*, i.e., a neural network that 
    * reads in words
    * embeds them
    * for each word: 
        * predicts a vector of means $\mu_\phi(x)$
        * predicts a vector of (log) variances $\sigma_\phi^2(x)$
        * samples a random embedding by sampling $\epsilon \sim \mathcal N(0, I)$ and returning $\mu_\phi(x) + \epsilon \sigma_\phi(x)$

* a *generative model*, i.e., a neural network that for each word position
    * takes a sampled embedding $z$
    * predicts the parameters of a categorical distribution over the vocabulary $f_\theta(z)$
    
You will identify all these steps in the code.
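
The steps listed above can be sketched end to end in a few lines of NumPy. This is only an illustration under stated assumptions: single linear layers stand in for $u_\phi$, $s_\phi$ and $g_\theta$, the weights are random, and all sizes and function names (`encode`, `sample_z`, `decode`) are hypothetical. The actual model lives in `vae.py` and may be structured differently.

```python
import numpy as np

rng = np.random.default_rng(1)

# Illustrative sizes (assumptions).
vocab_size, emb_dim, d = 10, 6, 4

E = rng.normal(scale=0.1, size=(vocab_size, emb_dim))  # word embedding table
W_mu = rng.normal(scale=0.1, size=(emb_dim, d))        # stand-in for u_phi
W_s = rng.normal(scale=0.1, size=(emb_dim, d))         # stand-in for s_phi
W_out = rng.normal(scale=0.1, size=(d, vocab_size))    # stand-in for g_theta


def encode(x):
    """Inference network: word ids -> (mean, log-variance), one row per word."""
    h = E[x]             # embed the words, shape [n, emb_dim]
    mu = h @ W_mu        # mu_phi(x)
    log_var = h @ W_s    # s_phi(x) = log sigma^2_phi(x)
    return mu, log_var


def sample_z(mu, log_var, rng):
    """Reparameterised sample: mu + eps * sigma with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    sigma = np.exp(0.5 * log_var)
    return mu + eps * sigma


def decode(z):
    """Generative model: sampled embedding -> categorical parameters f_theta(z)."""
    logits = z @ W_out
    logits = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(logits)
    return p / p.sum(axis=-1, keepdims=True)


x = np.array([3, 1, 7])      # a toy "sentence" of word ids
mu, log_var = encode(x)
z = sample_z(mu, log_var, rng)
probs = decode(z)            # shape [n, vocab_size], each row sums to 1
```

Predicting the log-variance and exponentiating keeps $\sigma^2_\phi(x)$ positive without constraining the network output.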

The model is trained to maximise a lower bound on the log-likelihood of the training data, the ELBO:

\begin{align}
\mathcal E_{\mathcal D}(\theta, \phi) &= \frac{1}{|\mathcal D|} \sum_{x_1^n \in \mathcal D} \underbrace{\sum_{i=1}^{n} \underbrace{\mathcal E(\theta, \phi|x_i)}_{\text{word}}}_{\text{sentence}}
\end{align}

where $\mathcal D$ is a set made of $|\mathcal D|$ sentences, each of which is itself a sequence of words.
The contribution to the ELBO due to each sentence is the sum of contributions from each word:

\begin{align}
\mathcal E(\theta, \phi|x) &= \mathbb E_{q_\phi(Z|x)} \left[ \log P_\theta(x|Z) \right] - \mathrm{KL}(q_\phi(Z|x)||p_\theta(z)) \\
 &= \mathbb E_{\epsilon \sim \mathcal N(0,I)} \left[ \log P_\theta(x|Z=\mu_\phi(x) + \epsilon \sigma_\phi(x)) \right] - \mathrm{KL}(q_\phi(Z|x)||\mathcal N(0, I)) 
\end{align}

which we usually approximate with a single sample $\epsilon \sim \mathcal N(0, I)$ for each word $x$

\begin{align}
\mathcal E(\theta, \phi|x) 
 &\approx \log P_\theta(x|Z=\mu_\phi(x) + \epsilon \sigma_\phi(x)) - \mathrm{KL}(q_\phi(Z|x)||\mathcal N(0, I)) 
\end{align}

and the KL term can be computed analytically

\begin{align}
\mathrm{KL}(q_\phi(Z|x)||\mathcal N(0, I)) &= -\frac{1}{2} \sum_{j=1}^d \left( 1 + \log \sigma^2_{\phi,j}(x) - \mu^2_{\phi,j}(x) - \sigma^2_{\phi,j}(x) \right)
\end{align}

where the summation is defined over the $d$ components of the mean and variance vectors.
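
As a sanity check on the closed-form KL, the sketch below implements the formula in NumPy and compares it with a Monte Carlo estimate of $\mathbb E_{q}[\log q(Z) - \log p(Z)]$. The function names and the toy values of the mean and log-variance are assumptions chosen for illustration; they are not part of the notebook's code.

```python
import numpy as np


def kl_to_standard_normal(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over the last axis."""
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=-1)


def single_sample_elbo(log_px_given_z, mu, log_var):
    """One-sample ELBO estimate: log P(x|z) minus the analytic KL term."""
    return log_px_given_z - kl_to_standard_normal(mu, log_var)


# If q equals the prior (mu = 0, sigma^2 = 1), the KL is zero.
kl_zero = kl_to_standard_normal(np.zeros(3), np.zeros(3))

# Monte Carlo check with toy parameters (assumed values).
rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])
log_var = np.array([0.2, -0.3])
sigma = np.exp(0.5 * log_var)

z = mu + sigma * rng.standard_normal((200_000, 2))  # samples from q
log_q = -0.5 * np.sum(np.log(2 * np.pi) + log_var + ((z - mu) / sigma) ** 2, axis=-1)
log_p = -0.5 * np.sum(np.log(2 * np.pi) + z**2, axis=-1)

kl_mc = np.mean(log_q - log_p)
kl_exact = kl_to_standard_normal(mu, log_var)
```

The two values should agree up to Monte Carlo noise, which is why the analytic form is preferred during training: it removes one source of variance from the gradient estimate.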


## Posterior Inference

Note that in general, because the generative model involves non-linear functions of $Z$

\begin{align}
\mathbb E_{p_\theta(Z|X)}[ f(Z) ]  & \neq f\left(\mathbb E_{p_\theta(Z|X)}[Z] \right)
\end{align}

where $p_\theta(Z|X=x)$ is approximated by our variational distribution $q_\phi(Z|X=x)$.

This means that, for a non-linear decoder $f$, decoding the posterior mean is in general not the same as averaging the decodings of posterior samples.

Nonetheless, we will make a simplifying assumption here and approximate $\mathbb E_{p_\theta(Z|X=x)}[Z]$ by the predicted mean $\mu_\phi(x)$.

A more principled approach would sample a few times from the approximate posterior and use a stochastic decoder (e.g. MBR), but this is beyond the scope of project 3.
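
The gap between the two quantities is easy to see numerically. The toy example below uses a softmax of a linear map as the non-linear $f$ and compares $f(\mathbb E[Z])$ with a sample average of $f(Z)$. The weights, mean and standard deviation are made-up values for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Toy decoder weights: 2-dimensional z, 3-word vocabulary (assumed values).
W = np.array([[2.0, -1.0, 0.5],
              [-1.5, 1.0, 3.0]])

# Toy approximate posterior q(z|x) = N(mu, diag(sigma^2)).
mu = np.array([0.3, -0.2])
sigma = np.array([1.0, 1.5])

z = mu + sigma * rng.standard_normal((50_000, 2))

decode_of_mean = softmax(mu @ W)               # f(E[Z])
mean_of_decodes = softmax(z @ W).mean(axis=0)  # Monte Carlo estimate of E[f(Z)]

gap = np.abs(decode_of_mean - mean_of_decodes).max()
```

Both vectors are valid distributions over the three words, yet they differ noticeably; the difference shrinks as the posterior variance goes to zero.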

In [1]:
# first run a few imports:
%load_ext autoreload
%autoreload 2
  
import tensorflow as tf
import numpy as np
import tempfile
import gzip
import pickle
import random
from collections import Counter, OrderedDict
from aer import read_naacl_alignments, AERSufficientStatistics

### Let's first load some data

We define a reader that returns one sentence at a time, without loading the whole data set into memory.
This is done with a generator, i.e. a function that uses the `yield` statement.
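
To make the idea concrete, here is a minimal sketch of such a lazy reader. The names `read_sentences` and `keep_short` are hypothetical; the notebook's own `smart_reader` and `filter_len` live in `utils.py` and may behave differently (for instance in how they tokenise or handle empty lines).

```python
import gzip


def read_sentences(path):
    """Yield one whitespace-tokenised sentence (a list of strings) per line."""
    # Assumption: files ending in .gz are gzip-compressed UTF-8 text.
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, mode="rt", encoding="utf-8") as f:
        for line in f:
            tokens = line.split()
            if tokens:  # skip empty lines
                yield tokens


def keep_short(sentences, max_length):
    """Lazily drop sentences longer than max_length tokens."""
    for sentence in sentences:
        if len(sentence) <= max_length:
            yield sentence
```

Because both functions are generators, chaining them as `keep_short(read_sentences(path), max_length=10)` reads the file only as fast as the consumer calls `next`.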

In [2]:
from utils import smart_reader, filter_len


def reader_test(path):
  # corpus is now a generator that gives us a list of tokens (a sentence) 
  # every time a function calls "next" on it
  corpus = filter_len(smart_reader(path), max_length=10)

  # to see that it really works, try this:
  print(next(corpus))
  print(next(corpus))
  print(next(corpus))
  
  
# the path to our training data, English side
train_en_path = 'data/training/hansards.36.2.e.gz'

# Let's try it:
reader_test(train_en_path)

['36', 'th', 'Parliament', ',', '2', 'nd', 'Session']
['edited', 'HANSARD', '*', 'NUMBER', '1']
['contents']


### Now, let's create a vocabulary!

We first define a class `Vocabulary` that helps us convert tokens (words) into numbers. This is useful later, because then we can e.g. index a word embedding table using the ID of a word.

In [3]:
from vocabulary import Vocabulary
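
For readers who do not have `vocabulary.py` at hand, the sketch below shows one way such a class could work. `TinyVocabulary` is a hypothetical stand-in, not the notebook's implementation; the special tokens and their order are an assumption, and the method names mirror the calls used in this notebook (`trim`, `get_token_id`, `get_token`).

```python
from collections import Counter


class TinyVocabulary:
    """Hypothetical minimal vocabulary: special tokens first, then words by frequency."""

    SPECIALS = ["<PAD>", "<UNK>", "<S>", "</S>", "<NULL>"]

    def __init__(self, corpus):
        # corpus is an iterable of sentences, each a list of tokens.
        self.counts = Counter(token for sentence in corpus for token in sentence)
        self._build(self.counts.most_common())

    def _build(self, counted_tokens):
        self.i2t = self.SPECIALS + [token for token, _ in counted_tokens]
        self.t2i = {token: i for i, token in enumerate(self.i2t)}

    def trim(self, max_size):
        """Keep only the max_size most frequent words (plus the special tokens)."""
        self._build(self.counts.most_common(max_size))

    def get_token_id(self, token):
        # Unknown tokens fall back to the id of <UNK>.
        return self.t2i.get(token, self.t2i["<UNK>"])

    def get_token(self, token_id):
        return self.i2t[token_id]

    def __len__(self):
        return len(self.i2t)
```

With this layout, trimming to `max_size` words leaves `max_size + 5` entries, because the five special tokens are always kept.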

Now let's try out our Vocabulary class:

In [4]:
# We used up a few lines in the previous example, so we set up
# our data generator again.
corpus = smart_reader(train_en_path)    

# Let's create a vocabulary given our (tokenized) corpus
vocabulary = Vocabulary(corpus=corpus)
print("Original vocabulary size: {}".format(len(vocabulary)))

# Now we only keep the highest-frequency words
vocabulary_size=10000
vocabulary.trim(vocabulary_size)
print("Trimmed vocabulary size: {}".format(len(vocabulary)))

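# Now we can get token indexes using vocabulary.get_token_id().
# Note: "*PAD*" is not in the vocabulary (the padding token is "<PAD>", index 0),
# so it maps to <UNK>.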
for t in ["*PAD*", "<UNK>", "the"]:
  print("The index of \"{}\" is: {}".format(t, vocabulary.get_token_id(t)))

# And the inverse too, using vocabulary.get_token():
for i in range(10):
  print("The token with index {} is: {}".format(i, vocabulary.get_token(i)))

# Now let's try to get a word ID for a word not in the vocabulary
# we should get 1 (so, <UNK>)
for t in ["!@!_not_in_vocab_!@!"]:
  print("The index of \"{}\" is: {}".format(t, vocabulary.get_token_id(t)))

Original vocabulary size: 36640
Trimmed vocabulary size: 10005
The index of "*PAD*" is: 1
The index of "<UNK>" is: 1
The index of "the" is: 5
The token with index 0 is: <PAD>
The token with index 1 is: <UNK>
The token with index 2 is: <S>
The token with index 3 is: </S>
The token with index 4 is: <NULL>
The token with index 5 is: the
The token with index 6 is: .
The token with index 7 is: ,
The token with index 8 is: of
The token with index 9 is: to
The index of "!@!_not_in_vocab_!@!" is: 1


### Mini-batching

With our vocabulary, we still need a method that converts a whole sentence to a sequence of IDs.
And, to speed up training, we would like to get a so-called mini-batch at a time: several such sequences together. So our function takes a corpus iterator and a vocabulary, and returns a mini-batch of dimension Batch × Time, where the first dimension indexes the sentences in the batch, and the second the time steps in each sentence.

In [5]:
from utils import iterate_minibatches, prepare_batch_data

In [6]:
# Let's try it out!
corpus = smart_reader(train_en_path)          


for batch_id, batch in enumerate(iterate_minibatches(corpus, batch_size=4)):

  print("This is the batch of data that we will train on, as tokens:")
  print(batch)
  print()

  x = prepare_batch_data(batch, vocabulary)

  print("These are our inputs (i.e. words replaced by IDs):")
  print(x)
  print()
  
  print("Here is the original first sentence back again:")
  print([vocabulary.get_token(token_id) for token_id in x[0]])

  break  # stop after the first batch, this is just a demonstration

This is the batch of data that we will train on, as tokens:
[['36', 'th', 'Parliament', ',', '2', 'nd', 'Session'], ['edited', 'HANSARD', '*', 'NUMBER', '1'], ['contents'], ['Tuesday', ',', 'October', '12', ',', '1999']]

These are our inputs (i.e. words replaced by IDs):
[[   4 1203  745  325    7  262 2381 1963]
 [   4 2651 2665   67 2643  238    0    0]
 [   4 2873    0    0    0    0    0    0]
 [   4 1532    7  813  882    7  297    0]]

Here is the original first sentence back again:
['<NULL>', '36', 'th', 'Parliament', ',', '2', 'nd', 'Session']


Now, notice the following:

1. The longest sequence in the batch has no padding. Any shorter sequence is padded on the right with zeros, the ID of `<PAD>`.
2. The length tensor gives the length for each sequence in the batch, so that we can correctly calculate the loss.
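
The padding behaviour can be sketched as follows. `pad_batch` is a hypothetical helper, not the notebook's `prepare_batch_data`. The `<NULL>` prefix (ID 4), the `<UNK>` fallback (ID 1) and the padding ID 0 are assumptions read off the printed output above, and returning the lengths alongside the array is an assumption about what the length tensor refers to.

```python
import numpy as np


def pad_batch(batch, token_to_id, pad_id=0, unk_id=1, null_id=4):
    """Convert a list of token lists into a padded [Batch x Time] id array."""
    # Prepend <NULL> and map every token to its id (unknown tokens -> <UNK>).
    rows = [[null_id] + [token_to_id.get(t, unk_id) for t in sentence]
            for sentence in batch]
    lengths = np.array([len(row) for row in rows])

    # Allocate an array filled with the padding id, then copy each row in.
    x = np.full((len(rows), lengths.max()), pad_id, dtype=np.int64)
    for i, row in enumerate(rows):
        x[i, :len(row)] = row

    # True for real tokens, False for padding; useful for masking the loss.
    mask = np.arange(x.shape[1])[None, :] < lengths[:, None]
    return x, lengths, mask


toy_ids = {"a": 5, "b": 6}  # toy mapping for illustration
x, lengths, mask = pad_batch([["a", "b", "zzz"], ["b"]], toy_ids)
```

Multiplying per-token losses by `mask` before summing ensures that padding positions do not contribute to the objective.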

With our input pipeline in place, now let's create a model.

### Building our model


In [7]:
# check vae.py to see the model
from vae import VAE

### Training the model

Now that we have a model, we need to train it. To do so we define a Trainer class that takes our model as an argument and trains it, keeping track of some important information.



In [None]:
from vae_trainer import VAETrainer

tf.reset_default_graph()

with tf.Session() as sess:
#   with tf.device("/cpu:0"):   

  batch_size=64
  max_length=30

  model = VAE(vocabulary=vocabulary, batch_size=batch_size, 
              emb_dim=64, rnn_dim=128, z_dim=64)
  trainer = VAETrainer(model, train_en_path, num_epochs=10, 
                  batch_size=batch_size, max_length=max_length,
                  lr=0.001, lr_decay=0.0, session=sess)

  print("Initializing variables..")
  sess.run(tf.global_variables_initializer())

  print("Training started..")
  trainer.train()

Initializing variables..
Training started..
Iter 100 loss 990.9293823242188 ce 395.82684326171875 kl 595.1025390625 acc 0.15 127/842 lr 0.001000
Iter 200 loss 648.8838500976562 ce 358.5419921875 kl 290.34185791015625 acc 0.40 379/937 lr 0.001000
Iter 300 loss 505.9349365234375 ce 317.25982666015625 kl 188.6751251220703 acc 0.49 457/942 lr 0.001000
Iter 400 loss 363.85870361328125 ce 235.06683349609375 kl 128.7918701171875 acc 0.54 420/783 lr 0.001000
Iter 500 loss 387.63946533203125 ce 266.2981872558594 kl 121.34127807617188 acc 0.55 492/889 lr 0.001000
Iter 600 loss 369.08770751953125 ce 258.7518615722656 kl 110.33585357666016 acc 0.59 546/926 lr 0.001000
Iter 700 loss 296.1138000488281 ce 205.559814453125 kl 90.55398559570312 acc 0.60 462/774 lr 0.001000
Iter 800 loss 359.9495544433594 ce 254.54965209960938 kl 105.39990234375 acc 0.62 614/984 lr 0.001000
Iter 900 loss 310.8501892089844 ce 218.02670288085938 kl 92.823486328125 acc 0.65 582/892 lr 0.001000
Iter 1000 loss 644.7874755859