# How to implement a Variational Autoencoder (VAE)

A variational autoencoder observes data, infers a latent code for it and tries to reconstruct the data from that latent code. In contrast to regular autoencoders, the code of the VAE is **random**. That means that when presented with the same input, the VAE will produce a slightly different code each time. This makes its decoding process more robust, since it has to deal with noisy code.

Another way of looking at a VAE is as a training procedure for a probabilistic model. The model is
$$p(x) = \int p(z)p(x|z) dz$$
where $z$ is the latent code and $x$ is the data. During training we need to infer a posterior over $z$. In the case of a VAE this is done by a neural network.
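
To make the integral above concrete, here is a minimal sketch that approximates $p(x)$ by averaging $p(x|z)$ over samples from the prior. The one-pixel model, the function names and the values of `weight` and `bias` are made up for illustration and are not part of the tutorial code.

```python
import math
import random


def sigmoid(a: float) -> float:
    """Logistic function, written so that exp never overflows."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def marginal_likelihood(x: int, weight: float, bias: float,
                        num_samples: int, rng: random.Random) -> float:
    """Monte Carlo estimate of p(x) = E_{p(z)}[p(x|z)] for a toy one-pixel model."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)             # z ~ p(z) = N(0, 1)
        p_on = sigmoid(weight * z + bias)   # hypothetical "decoder": p(x = 1 | z)
        total += p_on if x == 1 else 1.0 - p_on
    return total / num_samples


rng = random.Random(0)
# With bias = 0 the toy model is symmetric in z, so the exact value of p(x = 1) is 0.5.
estimate = marginal_likelihood(x=1, weight=2.0, bias=0.0, num_samples=20000, rng=rng)
print(estimate)
```

For realistic models this naive estimator needs far too many samples, which is one reason for inferring a posterior over $z$ instead.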

Assuming that the theory of VAEs has already been presented, we now dive straight into implementing them. If you need more background on VAEs, have a look at our [tutorial slides](https://github.com/philschulz/VITutorial/tree/master/modules) and the references therein.

# The Framework

For the purpose of this tutorial we are going to use [mxnet](https://mxnet.incubator.apache.org), which is a scalable deep learning library that has interfaces for several languages, including Python. We are going to import and abbreviate it as "mx". We will use mxnet to define the computation graph. This is done using the [symbol library](https://mxnet.incubator.apache.org/api/python/symbol.html). When building the VAE, all the methods that you use should be prefixed with `mx.sym`.

In [None]:
import mxnet as mx
import urllib.request
import os, logging, sys
from os.path import join, exists
from abc import ABC
from typing import List, Tuple, Callable

Next, we specify a couple of variables that will help us to load the data.

In [None]:
DEFAULT_LEARNING_RATE = 0.0003

data_names = ['train', 'valid', 'test']
train_set = ['train', 'valid']
test_set = ['test']
data_dir = join(os.curdir, "binary_mnist")

Finally, we set up basic logging facilities to print intermediate output.

In [None]:
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s [%(levelname)s]: %(message)s")

# The Data

Throughout the tutorial we will use the binarised MNIST data set consisting of images of handwritten digits (0-9). Each pixel has been mapped to either 0 or 1, meaning that pixels are either fully on or off. We use this data set because it allows us to use a rather simple product of Bernoullis as a likelihood. We download the data into a folder called "binary_mnist".

In [None]:
if not exists(data_dir):
    os.mkdir(data_dir)
for data_set in data_names:
    file_name = "binary_mnist.{}".format(data_set)
    goal = join(data_dir, file_name)
    if exists(goal):
        logging.info("Data file {} exists".format(file_name))
    else:
        logging.info("Downloading {}".format(file_name))
        link = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat".format(
            data_set)
        urllib.request.urlretrieve(link, goal)
        logging.info("Finished")

Alright, now we have the data on disk. We will load it later for training and testing. But first, we need to build our VAE.
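
As a rough preview of what loading involves, the sketch below parses rows of whitespace-separated 0/1 values into lists of integers. It assumes that each line of a downloaded file holds one image with 784 such values; that layout is an assumption here, and `parse_amat_lines` is a hypothetical helper rather than part of the tutorial code.

```python
from typing import Iterable, List


def parse_amat_lines(lines: Iterable[str], expected_dims: int = 784) -> List[List[int]]:
    """Turn whitespace-separated 0/1 text rows into lists of ints, skipping blank lines."""
    images = []
    for line_number, line in enumerate(lines, start=1):
        fields = line.split()
        if not fields:
            continue
        # int(float(...)) accepts both "1" and "1.0" style entries
        pixels = [int(float(field)) for field in fields]
        if len(pixels) != expected_dims:
            raise ValueError("line {} has {} values, expected {}".format(
                line_number, len(pixels), expected_dims))
        images.append(pixels)
    return images


# Tiny stand-in for a real file: two "images" with 4 pixels each and a trailing blank line.
toy_rows = ["0 1 1 0", "1 0 0 1", ""]
toy_images = parse_amat_lines(toy_rows, expected_dims=4)
print(toy_images)
```

An open file handle can be passed in place of `toy_rows`, since iterating over a text file yields its lines.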

# Diagonal Gaussian VAE

The most basic VAE model is one where we assume that the latent variable is multivariate Gaussian. We fix the prior to be standard normal. During inference, we use a multivariate Gaussian variational distribution with diagonal covariance matrix. This means that we are only modelling variance but not covariance (in fact, a k-dimensional Gaussian with diagonal covariance has the same density as a product of k independent univariate Gaussians). Geometrically, this variational distribution can only account for axis-aligned elliptical densities, not for ellipses that are rotated relative to the coordinate axes. It is thus rather limited in its modelling capabilities. Still, because it uses a neural network under the hood, it is very expressive.
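
The claim that a diagonal Gaussian factorises into univariate Gaussians can be checked numerically. The sketch below evaluates the log density once through the determinant and quadratic form of the diagonal covariance and once as a sum of univariate log densities; the function names and test values are illustrative only.

```python
import math
from typing import Sequence


def univariate_log_pdf(x: float, mean: float, std: float) -> float:
    """Log density of a univariate Gaussian N(mean, std^2) at x."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2


def diagonal_gaussian_log_pdf(x: Sequence[float], mean: Sequence[float],
                              std: Sequence[float]) -> float:
    """Log density of N(mean, diag(std^2)) via log-determinant and quadratic form."""
    k = len(x)
    log_det = sum(2.0 * math.log(s) for s in std)   # log|Sigma| for Sigma = diag(std^2)
    quad = sum(((x_i - m_i) / s_i) ** 2 for x_i, m_i, s_i in zip(x, mean, std))
    return -0.5 * (k * math.log(2.0 * math.pi) + log_det + quad)


x = [0.3, -1.2, 2.0]
mean = [0.0, -1.0, 1.5]
std = [1.0, 0.5, 2.0]

joint = diagonal_gaussian_log_pdf(x, mean, std)
factorised = sum(univariate_log_pdf(x_i, m_i, s_i) for x_i, m_i, s_i in zip(x, mean, std))
print(joint, factorised)  # the two values agree up to floating point error
```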

In this tutorial, we will model the binarised MNIST digit data set. Each image is encoded as a 784-dimensional vector. We will model each of these vectors as a product of 784 Bernoullis (of course, there are better models but we want to keep it simple). Our likelihood is thus a product of independent Bernoullis. The resulting model is formally specified as 

\begin{align}z \sim \mathcal{N}(0,I) && x_i|z \sim Bernoulli(NN_{\theta}(z))~~~ i \in \{1,2,\ldots, 784\} \ .\end{align}

The variational approximation is given by $$q(z|x) = \mathcal{N}(NN_{\lambda}(x), NN_{\lambda}(x)).$$

Notice that both the Bernoulli likelihood and the Gaussian variational distribution use NNs to compute their parameters. The parameters of the NNs, however, are different ($\theta$ and $\lambda$, respectively).
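
Once the generator network has produced one probability per pixel, the product-of-Bernoullis likelihood is easy to evaluate in log space. A minimal sketch in plain Python, where `bernoulli_product_log_likelihood` and the three-pixel toy values are illustrative and not part of the tutorial code:

```python
import math
from typing import Sequence


def bernoulli_product_log_likelihood(x: Sequence[int], probs: Sequence[float],
                                     eps: float = 1e-12) -> float:
    """log p(x|z) = sum_i [x_i * log(p_i) + (1 - x_i) * log(1 - p_i)]."""
    total = 0.0
    for x_i, p_i in zip(x, probs):
        p_i = min(max(p_i, eps), 1.0 - eps)   # clamp to avoid log(0)
        total += x_i * math.log(p_i) + (1 - x_i) * math.log(1.0 - p_i)
    return total


pixels = [1, 0, 1]        # observed binary pixels
probs = [0.9, 0.2, 0.5]   # p(x_i = 1 | z), as a generator network might output them
log_lik = bernoulli_product_log_likelihood(pixels, probs)
# Equals log(0.9) + log(1 - 0.2) + log(0.5)
print(log_lik)
```

Working with sums of logs rather than products of probabilities avoids underflow when there are 784 factors.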

## Implementation

We will spread our implementation across 3 classes. This design choice is motivated by the desire to make our models as modular as possible. This will later allow us to mix and match different likelihoods and variational distributions.

* **Generator**: This class defines our likelihood. Given a latent value, it can produce a data sample or assign a density to an existing data point.
* **InferenceNetwork**: This neural network computes the parameters of the variational approximation from a data point.
* **VAE**: This is the variational autoencoder. It combines a Generator and an InferenceNetwork and trains them jointly. Once trained, it can generate random data points or try to reproduce data presented to it.

Below we have specified these classes abstractly. Make sure you understand what each method is supposed to be doing.

In [None]:
class Generator(ABC):
    """
    Generator network.

    :param data_dims: Dimensionality of the generated data.
    :param layer_sizes: Size of each layer in the network.
    :param act_type: The activation after each layer.
    """

    def __init__(self, data_dims: int, layer_sizes: List[int], act_type: str) -> None:
        self.data_dims = data_dims
        self.layer_sizes = layer_sizes
        self.act_type = act_type

    def generate_sample(self, latent_state: mx.sym.Symbol) -> mx.sym.Symbol:
        """
        Generate a data sample from a latent state.

        :param latent_state: The latent input state.
        :return: A data sample.
        """
        raise NotImplementedError()

    def train(self, latent_state: mx.sym.Symbol, label: mx.sym.Symbol) -> mx.sym.Symbol:
        """
        Train the generator from a given latent state.

        :param latent_state: The latent input state.
        :param label: The observed data that the generator should reproduce.
        :return: The loss symbol used for training.
        """
        raise NotImplementedError()

        
class InferenceNetwork(ABC):
    """
    A network to infer distributions over latent states.

    :param latent_variable_size: The dimensionality of the latent variable.
    :param layer_sizes: Size of each layer in the network.
    :param act_type: The activation after each layer.
    """

    def __init__(self, latent_variable_size: int, layer_sizes: List[int], act_type: str) -> None:
        self.latent_var_size = latent_variable_size
        self.layer_sizes = layer_sizes
        self.act_type = act_type

    def inference(self, data: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, ...]:
        """
        Infer the parameters of the distribution over latent values.

        :param data: A data sample.
        :return: The parameters of the distribution.
        """
        raise NotImplementedError()
        
        
class VAE(ABC):
    """
    A variational autoencoding model (Kingma and Welling, 2013).

    :param generator: A generator network that specifies the likelihood of the model.
    :param inference_net: An inference network that specifies the distribution over latent values.
    """

    def __init__(self, generator: Generator, inference_net: InferenceNetwork) -> None:
        self.generator = generator
        self.inference_net = inference_net

    def train(self, data: mx.sym.Symbol, label: mx.sym.Symbol) -> mx.sym.Symbol:
        """
        Train the generator and inference network jointly by optimising the ELBO.

        :param data: The training data.
        :param label: Copy of the training data.
        :return: A list of loss symbols.
        """
        raise NotImplementedError()

    def generate_reconstructions(self, data: mx.sym.Symbol, n: int) -> mx.sym.Symbol:
        """
        Generate a number of reconstructions of input data points.

        :param data: The input data.
        :param n: Number of reconstructions per data point.
        :return: The reconstructed data.
        """
        raise NotImplementedError()

    def phantasize(self, n: int) -> mx.sym.Symbol:
        """
        Generate data by randomly sampling from the prior.

        :param n: Number of sampled data points.
        :return: Randomly generated data points.
        """
        raise NotImplementedError()

## Exercise 1
Let us start by implementing the generator. This is pretty much a standard neural network. The main point of this exercise is to get you comfortable with mxnet. Complete all the TODOs below. Before starting, check the activation functions available in mxnet [here](https://mxnet.incubator.apache.org/api/python/symbol.html#mxnet.symbol.Activation).

In [None]:
class ProductOfBernoullisGenerator(Generator):
    """
    A generator that produces binary vectors whose entries are independent Bernoulli draws.

    :param data_dims: Dimensionality of the generated data.
    :param layer_sizes: Size of each layer in the network.
    :param act_type: The activation after each layer.
    """

    def __init__(self, data_dims: int, layer_sizes: List[int], act_type: str) -> None:
        super().__init__(data_dims, layer_sizes, act_type)
        # TODO choose the correct output activation for a Bernoulli variable. This should just be a string.
        self.output_act = "sigmoid"

    def _preactivation(self, latent_state: mx.sym.Symbol) -> mx.sym.Symbol:
        """
        Computes the pre-activation of the generator, i.e. the hidden state before the final output activation.

        :param latent_state: The input latent state
        :return: The pre-activation before output activation
        """
        # Feed each hidden layer with the output of the previous one, starting from the latent state
        prev_out = latent_state
        for i, hidden in enumerate(self.layer_sizes):
            fc_i = mx.sym.FullyConnected(data=prev_out, num_hidden=hidden, name="gen_fc_{}".format(i))
            act_i = mx.sym.Activation(data=fc_i, act_type=self.act_type, name="gen_act_{}".format(i))
            prev_out = act_i

        # The output layer produces two pre-activations per pixel, one for each outcome of that Bernoulli
        fc_out = mx.sym.FullyConnected(data=prev_out, num_hidden=2 * self.data_dims, name="gen_fc_out")

        return fc_out
    
    def generate_sample(self, latent_state: mx.sym.Symbol) -> mx.sym.Symbol:
        """
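        Generates a data sample by picking the most likely outcome for each pixel. The stochasticity in the sampling
        process comes from the latent_state.

        :param latent_state: The input latent state.
        :return: A binary vector containing the most likely outcome of each Bernoulli.
        """
        act = mx.sym.Activation(data=self._preactivation(latent_state=latent_state), act_type=self.output_act,
                                name="gen_act_out")
        # Reshape to (batch, 2, data_dims) so that axis 1 indexes the two outcomes of each Bernoulli, as in train()
        act = mx.sym.reshape(data=act, shape=(-1, 2, self.data_dims))
        # The sigmoid is monotonic, so the arg max over axis 1 is the more likely outcome (0 or 1) of each pixel
        out = mx.sym.argmax(data=act, axis=1)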

        return out

    def train(self, latent_state: mx.sym.Symbol, label: mx.sym.Symbol) -> mx.sym.Symbol:
        """
        Train the generator from a given latent state

        :param latent_state: The input latent state
        :param label: A binary vector (same as input for inference module)
        :return: The loss symbol used for training
        """
        output = self._preactivation(latent_state=latent_state)
        output = mx.sym.reshape(data=output, shape=(-1, 2, self.data_dims))
        # We use a multi-output softmax. This computes gradients for 784 independent softmaxes.
        # Since a softmax of a 2-dim vector is a logistic function, we get the likelihood (and gradients)
        # for 784 independent Bernoulli distributions as desired.
        return mx.sym.SoftmaxOutput(data=output, label=label, multi_output=True)
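
The generator relies on the fact that a softmax over two values is a logistic function of their difference. The following plain-Python sketch checks this for one pixel; the two logits are arbitrary illustrative numbers and the helper names are not part of the tutorial code.

```python
import math
from typing import Tuple


def softmax2(logit_off: float, logit_on: float) -> Tuple[float, float]:
    """Softmax over two logits, shifted by the maximum for numerical stability."""
    shift = max(logit_off, logit_on)
    e_off = math.exp(logit_off - shift)
    e_on = math.exp(logit_on - shift)
    return e_off / (e_off + e_on), e_on / (e_off + e_on)


def sigmoid(a: float) -> float:
    """Logistic function, written so that exp never overflows."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


logit_off, logit_on = -0.4, 1.3   # the two pre-activations belonging to one pixel
p_off, p_on = softmax2(logit_off, logit_on)
# The probability of "on" only depends on the difference between the two logits.
print(p_on, sigmoid(logit_on - logit_off))
```

This is why the output layer has `2 * data_dims` units: each pixel gets one logit per outcome.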

## Exercise 2

We now move on to the inference network. Recall that this network will return the parameters of a diagonal Gaussian. Thus, we need to return two vectors of the same size: a mean and a standard deviation vector. (Formally, the parameters of the Gaussian are the variances. However, from the derivation of the Gaussian reparametrisation we know that we need the standard deviations to generate a Gaussian random variable $z$ as a transformation of a standard Gaussian variable $\epsilon$.)

**Hint:** In this exercise you will need to draw a random Gaussian sample (see [here](https://mxnet.incubator.apache.org/api/python/symbol.html#mxnet.symbol.random_normal)). The operator requires a shape whose first entry is the batch size. The batch size is not known to you during implementation, however. You can leave it underspecified by choosing $0$ as a value. When you combine the sampling operator with another operator immediately, mxnet will infer the correct batch size for you.
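
Outside of mxnet, the reparametrisation can be sketched with nothing but the standard library: draw $\epsilon$ from a standard Gaussian and compute $z = \mu + \sigma \epsilon$ per dimension. The helper name and the toy parameters below are illustrative only.

```python
import math
import random
from typing import List, Sequence


def sample_diagonal_gaussian(mean: Sequence[float], std: Sequence[float],
                             rng: random.Random) -> List[float]:
    """Reparametrised sample: z_i = mean_i + std_i * eps_i with eps_i ~ N(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)]


rng = random.Random(1)
mean, std = [2.0, -1.0], [0.5, 3.0]
samples = [sample_diagonal_gaussian(mean, std, rng) for _ in range(20000)]

# The empirical moments of the samples should be close to the requested parameters.
num = len(samples)
empirical_mean = [sum(s[i] for s in samples) / num for i in range(2)]
empirical_std = [math.sqrt(sum((s[i] - empirical_mean[i]) ** 2 for s in samples) / num)
                 for i in range(2)]
print(empirical_mean, empirical_std)
```

Because the randomness sits entirely in $\epsilon$, the sample is a deterministic, differentiable function of the mean and standard deviation.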

In [None]:
class GaussianInferenceNetwork(InferenceNetwork):
    """
    An inference network that predicts the parameters of a diagonal Gaussian and samples from that distribution.

    :param latent_variable_size: The dimensionality of the latent variable.
    :param layer_sizes: Size of each layer in the network.
    :param act_type: The activation after each layer.
    """

    def __init__(self, latent_variable_size: int, layer_sizes: List[int], act_type: str):
        super().__init__(latent_variable_size, layer_sizes, act_type)

    def inference(self, data: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]:
        """
        Infer the mean and standard deviation.

        :param data: A data sample.
        :return: The mean and standard deviation.
        """
        # We choose to share the first layer between the networks that compute the standard deviations
        # and means. This is a fairly standard design choice.
        shared_layer = mx.sym.FullyConnected(data=data, num_hidden=self.layer_sizes[0], name="inf_joint_fc")
        shared_layer = mx.sym.Activation(data=shared_layer, act_type=self.act_type, name="inf_joint_act")

        prev_out = shared_layer
        for i, size in enumerate(self.layer_sizes[1:]):
            mean_fc_i = mx.sym.FullyConnected(data=prev_out, num_hidden=size, name="inf_mean_fc_{}".format(i))
            mean_act_i = mx.sym.Activation(data=mean_fc_i, act_type=self.act_type, name="inf_mean_act_{}".format(i))
            prev_out = mean_act_i
        mean = mx.sym.FullyConnected(data=prev_out, num_hidden=self.latent_var_size, name="inf_mean_compute")

        prev_out = shared_layer
        for i, size in enumerate(self.layer_sizes[1:]):
            var_fc_i = mx.sym.FullyConnected(data=prev_out, num_hidden=size, name="rec_var_fc_{}".format(i))
            var_act_i = mx.sym.Activation(data=var_fc_i, act_type=self.act_type, name="rec_var_act_{}".format(i))
            prev_out = var_act_i
        # soft-relu maps the pre-activation onto the positive real line, as required for a standard deviation
        std = mx.sym.Activation(
            mx.sym.FullyConnected(data=prev_out, num_hidden=self.latent_var_size, name="inf_var_compute"),
            act_type="softrelu")

        return mean, std

    def sample_latent_state(self, mean: mx.sym.Symbol, std: mx.sym.Symbol) -> mx.sym.Symbol:
        """
        Sample a latent Gaussian variable

        :param mean: The mean of the Gaussian
        :param std: The standard deviation of the Gaussian
        :return: A Gaussian sample
        """
        # TODO: This is where the magic happens! Draw a sample from the Gaussian using the Gaussian reparametrisation
        # trick and return it.
        return mean + std * mx.sym.random_normal(loc=0, scale=1, shape=(0, self.latent_var_size))
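
The `softrelu` activation used for the standard deviation computes $\log(1 + e^{a})$. A small plain-Python sketch of that function, written in a numerically stable form (the stable rewriting is a choice made for this sketch, not a statement about how mxnet implements it):

```python
import math


def softrelu(a: float) -> float:
    """log(1 + exp(a)), computed as max(a, 0) + log1p(exp(-|a|)) to avoid overflow."""
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))


for a in (-50.0, -1.0, 0.0, 1.0, 50.0):
    print(a, softrelu(a))
```

The output is positive for moderate inputs and approaches the identity for large ones. For extremely negative inputs the floating point result can underflow to zero, which matters when the logarithm of the variance is taken later.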


## Exercise 3.a

Finally, we will put it all together and build our VAE. Recall that the objective for the inference net contains a KL term. You will need to implement that KL-term. Once it is implemented, we can take advantage of autograd to get its gradients.
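
For a diagonal Gaussian $q$ and a standard normal prior $p$, the KL divergence has the closed form $\frac{1}{2}\sum_i (\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2)$. The plain-Python sketch below evaluates it on a few toy inputs; `diagonal_gaussian_kl_reference` is a hypothetical helper for checking values by hand, not the symbolic version used in the VAE.

```python
import math
from typing import Sequence


def diagonal_gaussian_kl_reference(mean: Sequence[float], std: Sequence[float]) -> float:
    """KL(N(mean, diag(std^2)) || N(0, I)) = 0.5 * sum(mean^2 + std^2 - 1 - log(std^2))."""
    return 0.5 * sum(m * m + s * s - 1.0 - math.log(s * s) for m, s in zip(mean, std))


kl_same = diagonal_gaussian_kl_reference([0.0, 0.0], [1.0, 1.0])   # q equals the prior
kl_shifted = diagonal_gaussian_kl_reference([1.0], [1.0])          # only the mean differs
kl_wide = diagonal_gaussian_kl_reference([0.3], [2.0])             # mean and variance differ
print(kl_same, kl_shifted, kl_wide)
```

The divergence is zero exactly when $q$ coincides with the prior and positive otherwise, which is a quick sanity check for the sign of any implementation.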

In [None]:
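def diagonal_gaussian_kl(mean: mx.sym.Symbol, std: mx.sym.Symbol) -> mx.sym.Symbol:
    """KL divergence of a diagonal Gaussian from a standard normal, summed over all entries."""
    var = std ** 2
    # KL(q || p) = -0.5 * sum(1 + log(var) - mean^2 - var); the minus sign makes the loss non-negative
    return -0.5 * mx.sym.sum(1 + mx.sym.log(var) - mean ** 2 - var)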

## Exercise 3.b

The only thing that is left to do is to implement VAE training. This is where you will have to define an additional loss with respect to the KL divergence for the inference net. In mxnet losses are defined using the [MakeLoss](https://mxnet.incubator.apache.org/api/python/symbol.html#mxnet.symbol.MakeLoss) symbol.
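
For reference, the objective being optimised can be written in the notation used earlier. This is the standard form of the ELBO for this kind of model, stated here on the assumption that it matches the derivation in the tutorial slides:

\begin{align}
\mathcal{L}(\theta, \lambda; x) = \mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] - \mathrm{KL}\left(q(z|x) \,\|\, p(z)\right) \ .
\end{align}

Training minimises the negative of this quantity. The first term corresponds to the reconstruction loss produced by the generator and the second term to the additional KL loss.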

In [None]:
class GaussianVAE(VAE):
    """
    A VAE with Gaussian latent variables. It assumes a standard normal prior on the latent values.

    :param generator: A generator network that specifies the likelihood of the model.
    :param inference_net: An inference network that specifies the Gaussian over latent values.
    """

    def __init__(self,
                 generator: Generator,
                 inference_net: GaussianInferenceNetwork,
                 kl_divergence: Callable) -> None:
        self.generator = generator
        self.inference_net = inference_net
        self.kl_divergence = kl_divergence

    def train(self, data: mx.sym.Symbol, label: mx.sym.Symbol) -> mx.sym.Symbol:
        """
        Train the generator and inference network jointly by optimising the ELBO.

        :param data: The training data.
        :param label: Copy of the training data.
        :return: A list of loss symbols.
        """
        mean, std = self.inference_net.inference(data=data)
        latent_state = self.inference_net.sample_latent_state(mean, std)
        kl_loss = mx.sym.MakeLoss(self.kl_divergence(mean, std))
        return mx.sym.Group([self.generator.train(latent_state=latent_state, label=label), kl_loss])

    def generate_reconstructions(self, data: mx.sym.Symbol, n: int) -> mx.sym.Symbol:
        """
        Generate a number of reconstructions of input data points.

        :param data: The input data.
        :param n: Number of reconstructions per data point.
        :return: The reconstructed data.
        """
        mean, std = self.inference_net.inference(data=data)
        mean = mx.sym.tile(data=mean, reps=(n, 1))
        std = mx.sym.tile(data=std, reps=(n, 1))
        latent_state = self.inference_net.sample_latent_state(mean, std)
        return self.generator.generate_sample(latent_state=latent_state)

    def phantasize(self, n: int) -> mx.sym.Symbol:
        """
        Generate data by randomly sampling from the prior.

        :param n: Number of sampled data points.
        :return: Randomly generated data points.
        """
        latent_state = mx.sym.random_normal(loc=0, scale=1, shape=(n, self.inference_net.latent_var_size))
        return self.generator.generate_sample(latent_state=latent_state)

# Running the VAE

We have now fully specified a VAE. To get an impression of what its computation graph looks like, we can use mxnet's built-in visualisation. 

In [None]:
def construct_vae(latent_type: str,
                  likelihood: str,
                  generator_layer_sizes: List[int],
                  infer_layer_sizes: List[int],
                  latent_variable_size: int,
                  data_dims: int,
                  generator_act_type: str = "tanh",
                  infer_act_type: str = "tanh") -> VAE:
    """
    Construct a variational autoencoder

    :param latent_type: Distribution of latent variable.
    :param likelihood: Type of likelihood.
    :param generator_layer_sizes: Sizes of generator hidden layers.
    :param infer_layer_sizes: Sizes of inference network hidden layers.
    :param latent_variable_size: Size of the latent variable.
    :param data_dims: Dimensionality of the data.
    :param generator_act_type: Activation function for generator hidden layers.
    :param infer_act_type: Activation function for inference network hidden layers.
    :return: A variational autoencoder.
    """
    if likelihood == "bernoulliProd":
        generator = ProductOfBernoullisGenerator(data_dims=data_dims, layer_sizes=generator_layer_sizes,
                                                 act_type=generator_act_type)
    else:
        raise Exception("{} is an invalid likelihood type.".format(likelihood))

    if latent_type == "gaussian":
        inference_net = GaussianInferenceNetwork(latent_variable_size=latent_variable_size,
                                                 layer_sizes=infer_layer_sizes,
                                                 act_type=infer_act_type)
        return GaussianVAE(generator=generator, inference_net=inference_net, kl_divergence=diagonal_gaussian_kl)
    else:
        raise Exception("{} is an invalid latent variable type.".format(latent_type))


In [None]:
vae = construct_vae(latent_type="gaussian", likelihood="bernoulliProd", generator_layer_sizes=[200,500],
                   infer_layer_sizes=[500,200], latent_variable_size=100, data_dims=784, generator_act_type="tanh",
                   infer_act_type="tanh")

In [None]:
data = mx.sym.Variable("data")
label = mx.sym.Variable("label")
mx.viz.plot_network(vae.train(data, label), title="bla.jpg", save_format='jpg')