In [1]:
import numpy as np
import tensorflow as tf

from hyperspherical_vae.distributions import VonMisesFisher
from hyperspherical_vae.distributions import HypersphericalUniform

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data/', one_hot=False)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [10]:
class ModelVAE(object):

    def __init__(self, x, h_dim, z_dim, activation=tf.nn.relu, distribution='normal'):
        """
        ModelVAE initializer

        :param x: placeholder for input
        :param h_dim: dimension of the hidden layers
        :param z_dim: dimension of the latent representation
        :param activation: callable activation function
        :param distribution: string either `normal` or `vmf`, indicates which distribution to use
        """
        self.x, self.h_dim, self.z_dim, self.activation, self.distribution = x, h_dim, z_dim, activation, distribution

        self.z_mean, self.z_var = self._encoder(self.x)

        if distribution == 'normal':
            self.q_z = tf.distributions.Normal(self.z_mean, self.z_var)
        elif distribution == 'vmf':
            self.q_z = VonMisesFisher(self.z_mean, self.z_var)
        else:
            raise NotImplemented

        self.z = self.q_z.sample()

        self.logits = self._decoder(self.z)

    def _encoder(self, x):
        """
        Encoder network

        :param x: placeholder for input
        :return: tuple `(z_mean, z_var)` with mean and concentration around the mean
        """
        # dynamic binarization
        x = tf.cast(tf.greater(x, tf.random_uniform(shape=tf.shape(x), dtype=x.dtype)), dtype=x.dtype)
        
        # 2 hidden layers encoder
        h0 = tf.layers.dense(x, units=self.h_dim * 2, activation=self.activation)
        h1 = tf.layers.dense(h0, units=self.h_dim, activation=self.activation)

        if self.distribution == 'normal':
            # compute mean and std of the normal distribution
            z_mean = tf.layers.dense(h1, units=self.z_dim, activation=None)
            z_var = tf.layers.dense(h1, units=self.z_dim, activation=tf.nn.softplus)
        elif self.distribution == 'vmf':
            # compute mean and concentration of the von Mises-Fisher
            z_mean = tf.layers.dense(h1, units=self.z_dim, activation=lambda x: tf.nn.l2_normalize(x, axis=-1))
            # the `+ 1` prevent collapsing behaviors
            z_var = tf.layers.dense(h1, units=1, activation=tf.nn.softplus) + 1
        else:
            raise NotImplemented

        return z_mean, z_var

    def _decoder(self, z):
        """
        Decoder network

        :param z: tensor, latent representation of input (x)
        :return: logits, `reconstruction = sigmoid(logits)`
        """
        # 2 hidden layers decoder
        h2 = tf.layers.dense(z, units=self.h_dim, activation=self.activation)
        h2 = tf.layers.dense(h2, units=self.h_dim * 2, activation=self.activation)
        logits = tf.layers.dense(h2, units=self.x.shape[-1], activation=None)

        return logits


In [11]:
class OptimizerVAE(object):

    def __init__(self, model, learning_rate=1e-3):
        """
        OptimizerVAE initializer

        :param model: a model object
        :param learning_rate: float, learning rate of the optimizer
        """

        # binary cross entropy error
        self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
        self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))

        if model.distribution == 'normal':
            # KL divergence between normal approximate posterior and standard normal prior
            self.p_z = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            kl = model.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(tf.reduce_sum(kl, axis=-1))
        elif model.distribution == 'vmf':
            # KL divergence between vMF approximate posterior and uniform hyper-spherical prior
            self.p_z = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            kl = model.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(kl)
        else:
            raise NotImplemented

        self.ELBO = - self.reconstruction_loss - self.kl

        self.train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(-self.ELBO)

        self.print = {'recon loss': self.reconstruction_loss, 'ELBO': self.ELBO, 'KL': self.kl}

In [12]:
def log_likelihood(model, optimizer, n=10):
    """

    :param model: model object
    :param optimizer: optimizer object
    :param n: number of MC samples
    :return: MC estimate of log-likelihood
    """

    z = model.q_z.sample(n)

    log_p_z = optimizer.p_z.log_prob(z)

    if model.distribution == 'normal':
        log_p_z = tf.reduce_sum(log_p_z, axis=-1)

    log_p_x_z = -tf.reduce_sum(optimizer.bce, axis=-1)

    log_q_z_x = model.q_z.log_prob(z)

    if model.distribution == 'normal':
        log_q_z_x = tf.reduce_sum(log_q_z_x, axis=-1)

    return tf.reduce_mean(tf.reduce_logsumexp(
        tf.transpose(log_p_x_z + log_p_z - log_q_z_x) - np.log(n), axis=-1))

In [13]:

# hidden dimension and dimension of latent space
H_DIM = 128
Z_DIM = 2

# digit placeholder
x = tf.placeholder(tf.float32, shape=(None, 784))

# normal VAE
modelN = ModelVAE(x=x, h_dim=H_DIM, z_dim=Z_DIM, distribution='normal')
optimizerN = OptimizerVAE(modelN)

# hyper-spherical VAE
modelS = ModelVAE(x=x, h_dim=H_DIM, z_dim=Z_DIM + 1, distribution='vmf')
optimizerS = OptimizerVAE(modelS)

session = tf.Session()
session.run(tf.global_variables_initializer())

In [14]:

print('##### Normal VAE #####')
for i in range(1000):
    # training
    x_mb, _ = mnist.train.next_batch(64)
    session.run(optimizerN.train_step, {modelN.x: x_mb})

    # every 100 iteration plot validation
    if i % 100 == 0:
        x_mb = mnist.validation.images
        print(i, session.run({**optimizerN.print}, {modelN.x: x_mb}))

print('Test set:')
x_mb = mnist.test.images
print_ = {**optimizerN.print}
print_['LL'] = log_likelihood(modelN, optimizerN, n=100)
print(session.run(print_, {modelN.x: x_mb}))

print()
print('##### Hyper-spherical VAE #####')
for i in range(1000):
    # training
    x_mb, _ = mnist.train.next_batch(64)
    session.run(optimizerS.train_step, {modelS.x: x_mb})

    # every 100 iteration plot validation
    if i % 100 == 0:
        x_mb = mnist.validation.images
        print(i, session.run({**optimizerS.print}, {modelS.x: x_mb}))

print('Test set:')
x_mb = mnist.test.images
print_ = {**optimizerS.print}
print_['LL'] = log_likelihood(modelS, optimizerS, n=100)
print(session.run(print_, {modelS.x: x_mb}))


##### Normal VAE #####
0 {'KL': 0.16437618, 'recon loss': 539.3147, 'ELBO': -539.47906}
100 {'KL': 2.8173559, 'recon loss': 198.61513, 'ELBO': -201.43248}
200 {'KL': 3.7334108, 'recon loss': 183.73953, 'ELBO': -187.47295}
300 {'KL': 3.9201128, 'recon loss': 176.92807, 'ELBO': -180.84819}
400 {'KL': 4.118131, 'recon loss': 174.5309, 'ELBO': -178.64903}
500 {'KL': 4.4276586, 'recon loss': 171.1067, 'ELBO': -175.53436}
600 {'KL': 4.606196, 'recon loss': 168.9398, 'ELBO': -173.546}
700 {'KL': 4.4683576, 'recon loss': 167.64552, 'ELBO': -172.11388}
800 {'KL': 4.7288294, 'recon loss': 166.15912, 'ELBO': -170.88795}
900 {'KL': 4.636322, 'recon loss': 164.90408, 'ELBO': -169.5404}
Test set:
{'KL': 4.4319677, 'recon loss': 165.60225, 'ELBO': -170.03421, 'LL': -168.92137}

##### Hyper-spherical VAE #####
0 {'KL': 0.31487006, 'recon loss': 540.8928, 'ELBO': -541.2077}
100 {'KL': 1.6287001, 'recon loss': 200.90146, 'ELBO': -202.53015}
200 {'KL': 2.736233, 'recon loss': 184.6564, 'ELBO': -187.39264