### Week 9: Normalising flows, part 3 - improved variational posterior with inverse autoregressive flow (IAF)

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

import numpy as np
import matplotlib.pyplot as plt
from IPython import display
%matplotlib inline

## Improved variational posterior

In [None]:
# Fix TensorFlow's graph-level random seed so that the random ops created below
# (weight initialisers, posterior sampling) are repeatable from run to run.
tf.set_random_seed(100)

In [None]:
class VAE():
    """Small variational autoencoder for 2-D inputs with an optional IAF posterior.

    The encoder is two 8-unit tanh layers followed by linear heads for the mean
    and log standard deviation of a 2-D diagonal Gaussian. When ``use_iaf`` is
    true, samples of that Gaussian are additionally passed through an inverse
    autoregressive flow (see ``build_iaf_flow``). The decoder is two 8-unit
    sigmoid layers followed by a linear 2-unit output.

    The training cost is the batch mean of (squared reconstruction error + KL
    term) plus an L2 penalty on the non-decoder kernels, minimised with Adam
    using gradients clipped element-wise to [-0.1, 0.1].
    """

    def __init__(self, use_iaf=False):
        """Create a session, build the graph and initialise all variables.

        Args:
            use_iaf: if True, the posterior sample is drawn from the Gaussian
                transformed by the IAF; otherwise from the Gaussian directly.
        """
        self.sess = tf.Session()
        self.lambda_l2_reg = 0.01
        self.learning_rate = 0.001
        self.dropout = 1.
        self.use_iaf = use_iaf

        handles = self._buildGraph()
        self.sess.run(tf.global_variables_initializer())

        (self.x_in, self.dropout_, self.z_mean, self.z_log_sigma, self.z_sample,
         self.x_reconstructed, self.cost, self.global_step, self.train_op,
         self.rec_loss, self.kl_loss) = handles

    def _buildGraph(self):
        """Build the encoder, posterior, decoder, losses and training op.

        Returns:
            Tuple ``(x_in, dropout, z_mean, z_log_sigma, z_sample,
            x_reconstructed, cost, global_step, train_op, mean_rec_loss,
            mean_kl_loss)`` of graph handles, in the order unpacked by
            ``__init__``.
        """
        x_in = tf.placeholder(tf.float32, shape=[None, 2], name="x")
        # NOTE: this placeholder is returned and fed by train(), but no layer
        # in this graph consumes it, so no dropout is actually applied.
        dropout = tf.placeholder_with_default(1., shape=[], name="dropout")

        # Encoder.
        h = tf.layers.Dense(8, activation=tf.nn.tanh, name="encoding/1")(x_in)
        h = tf.layers.Dense(8, activation=tf.nn.tanh, name="encoding/2")(h)

        z_mean = tf.layers.Dense(2, activation=None, name="z_mean")(h)
        z_log_sigma = tf.layers.Dense(2, activation=None, name="z_log_sigma")(h)

        # Factorised Gaussian posterior (the base distribution of the flow).
        z = tfd.MultivariateNormalDiag(loc=z_mean, scale_diag=tf.exp(z_log_sigma))

        if not self.use_iaf:
            z_sample = z.sample()
        else:
            iaf_flow = self.build_iaf_flow(z)
            z_sample = iaf_flow.sample()

        # Decoder.
        h = tf.layers.Dense(8, activation=tf.nn.sigmoid, name="decoding/1")(z_sample)
        h = tf.layers.Dense(8, activation=tf.nn.sigmoid, name="decoding/2")(h)

        x_reconstructed = tf.layers.Dense(2, activation=None, name="decoding/out")(h)

        with tf.name_scope("l2_loss"):
            # Per-example squared reconstruction error.
            rec_loss = tf.reduce_sum(tf.square(x_reconstructed - x_in), 1)

        if not self.use_iaf:
            # Closed-form KL between the diagonal Gaussian and N(0, I).
            kl_loss = VAE.kullbackLeibler(z_mean, z_log_sigma)
        else:
            # Single-sample estimate log q(z|x) - log p(z). Use log_prob
            # directly: taking log(prob + eps) underflows for samples far from
            # the origin, which clamps the term at log(eps) and removes the
            # prior's gradient exactly where it is needed.
            prior = tfd.MultivariateNormalDiag(loc=tf.zeros([2]))
            kl_loss = iaf_flow.log_prob(z_sample) - prior.log_prob(z_sample)

        with tf.name_scope("l2_regularization"):
            # Penalise every trainable kernel except the decoder's.
            regularizers = [tf.nn.l2_loss(var) for var in self.sess.graph.get_collection(
                "trainable_variables") if ("kernel" in var.name and "decoding" not in var.name)]
            l2_reg = self.lambda_l2_reg * tf.add_n(regularizers)

        with tf.name_scope("cost"):
            cost = tf.reduce_mean(rec_loss + kl_loss, name="vae_cost")
            cost += l2_reg

        global_step = tf.Variable(0, trainable=False)
        with tf.name_scope("Adam_optimizer"):
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            tvars = tf.trainable_variables()
            self.grads_and_vars = optimizer.compute_gradients(cost, tvars)
            # Variables that do not influence the cost come back with a None
            # gradient, which clip_by_value cannot handle; skip them.
            clipped = [(tf.clip_by_value(grad, -0.1, 0.1), tvar)
                       for grad, tvar in self.grads_and_vars
                       if grad is not None]
            train_op = optimizer.apply_gradients(clipped, global_step=global_step,
                                                 name="minimize_cost")

        return (x_in, dropout, z_mean, z_log_sigma, z_sample, x_reconstructed,
                cost, global_step, train_op, tf.reduce_mean(rec_loss), tf.reduce_mean(kl_loss))

    @staticmethod
    def kullbackLeibler(mu, log_sigma):
        """Return the per-example KL divergence from N(mu, exp(log_sigma)^2) to N(0, I).

        Computes ``-0.5 * sum(1 + 2*log_sigma - mu^2 - exp(2*log_sigma))``
        over axis 1.
        """
        with tf.name_scope("KL_divergence"):
            return -0.5 * tf.reduce_sum(1 + 2 * log_sigma - mu**2 -
                                        tf.exp(2 * log_sigma), 1)

    def build_iaf_flow(self, base_dist):
        """Wrap ``base_dist`` in a four-layer inverse autoregressive flow.

        Four masked autoregressive layers (two hidden layers of 64 units each)
        are interleaved with permutations that swap the two latent dimensions.
        The chain is built as a MAF bijector and then inverted, which turns it
        into an IAF.

        Args:
            base_dist: distribution whose samples are transformed by the flow.

        Returns:
            A ``tfd.TransformedDistribution`` over the flowed latent variable.
        """
        num_layers = 4
        bijectors = []
        for layer in range(num_layers):
            if layer > 0:
                # Swap the two dimensions between consecutive autoregressive layers.
                bijectors.append(tfb.Permute(permutation=[1, 0]))
            bijectors.append(tfb.MaskedAutoregressiveFlow(
                shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(
                    hidden_layers=[64, 64])))

        maf_bijector = tfb.Chain(list(reversed(bijectors)), name='maf_bijector')
        return tfd.TransformedDistribution(distribution=base_dist, bijector=tfb.Invert(maf_bijector))

    def encode(self, x):
        """Return ``[z_mean, z_log_sigma]`` of the factorised Gaussian for ``x``.

        These are the parameters before the IAF flow (if used) is applied.
        """
        return self.sess.run([self.z_mean, self.z_log_sigma], feed_dict={self.x_in: x})

    def posterior_sample(self, x):
        """Return one sample per row of ``x`` from the full posterior (after the IAF if used)."""
        return self.sess.run(self.z_sample, feed_dict={self.x_in: x})

    def decode(self, zs):
        """Return reconstructions for latent points ``zs`` by feeding them in place of ``z_sample``."""
        return self.sess.run(self.x_reconstructed, feed_dict={self.z_sample: zs})

    @staticmethod
    def plot_posterior_distribution(X, vae=None):
        """Scatter-plot posterior samples for four consecutive groups of rows of ``X``.

        Rows 0-63, 64-127, 128-191 and 192-end are drawn in red, blue, green
        and purple respectively. The figure is displayed inline and the cell
        output is marked to be cleared when the next output arrives.

        Args:
            X: array of inputs, sliced by row into the four groups above.
            vae: model used to draw the samples. When omitted, the
                module-level ``model`` variable is used, which is the
                behaviour older callers rely on.
        """
        if vae is None:
            # Legacy fallback for callers that pass only X.
            vae = model  # noqa: F821

        groups = (X[:64, :], X[64:128, :], X[128:192, :], X[192:, :])
        colours = ('red', 'blue', 'green', 'purple')
        group_samples = [vae.posterior_sample(group) for group in groups]

        plt.close()
        plt.figure()
        for samples, colour in zip(group_samples, colours):
            plt.scatter(samples[:, 0], samples[:, 1], color=colour, s=5)
        plt.title("Posterior distributions")
        display.display(plt.gcf())
        display.clear_output(wait=True)

    def train(self, x, max_iter=np.inf):
        """Run full-batch training steps on ``x`` until ``max_iter`` is reached.

        Every 500 steps (when the fetched step counter is 1 modulo 500) the
        losses are printed and recorded and the posterior is plotted.

        Args:
            x: training inputs fed to the ``x`` placeholder on every step.
            max_iter: stop once the fetched global step is at least this
                value. With the default of infinity the loop never stops by
                itself.

        Returns:
            ``(losses, iterations)``: the recorded costs and the step numbers
            at which they were recorded, including the final step.
        """
        losses = []
        iterations = []
        while True:
            feed_dict = {self.x_in: x, self.dropout_: self.dropout}
            x_reconstructed, cost, rec_loss, kl_loss, _, i = self.sess.run(
                [self.x_reconstructed, self.cost, self.rec_loss,
                 self.kl_loss, self.train_op, self.global_step], feed_dict
            )

            if i % 500 == 1:
                print("Iteration {}, cost: ".format(i), cost)
                print("   rec_loss: {}, kl_loss: {}".format(rec_loss, kl_loss))
                losses.append(cost)
                iterations.append(i)
                # Plot with this instance rather than whatever `model` is bound globally.
                VAE.plot_posterior_distribution(x, self)

            if i >= max_iter:
                print("Finished training. Final cost at iteration {}: {}".format(i, cost))
                print("   rec_loss: {}, kl_loss: {}".format(rec_loss, kl_loss))
                losses.append(cost)
                iterations.append(i)
                break
        return losses, iterations

In [None]:
# Four cluster centres, one in each quadrant of the plane.
x1, x2, x3, x4 = (
    np.array(centre)
    for centre in ((5., 5.), (-5., 5.), (-5., -5.), (5., -5.))
)

# 64 identical copies of each centre, one centre per row.
X1, X2, X3, X4 = (np.tile(centre, (64, 1)) for centre in (x1, x2, x3, x4))

# Training set: the four groups stacked in order.
X_train = np.concatenate((X1, X2, X3, X4), axis=0)
X_train.shape

In [None]:
# Build the VAE with the plain factorised Gaussian posterior. Pass use_iaf=True
# instead to draw posterior samples through the IAF flow.
model = VAE(use_iaf=False)

In [None]:
import time

# Wall-clock timestamps taken immediately around the training call.
start_time = time.time()
losses, iterations = model.train(X_train, max_iter=10000)
end_time = time.time()

elapsed_seconds = end_time - start_time
print("Training time: {}".format(elapsed_seconds))

In [None]:
# Recorded cost against the iteration at which it was recorded.
curve_axes = plt.gca()
curve_axes.plot(iterations, losses)
curve_axes.set_title("Training curve")
plt.show()

In [None]:
# Plot the posterior distributions before passing through the IAF flow

num_samples = 256

# Encode each cluster centre (as a batch of one) to get its Gaussian parameters.
encoded = [model.encode(np.expand_dims(centre, 0)) for centre in (x1, x2, x3, x4)]
((x1_mean, x1_log_sigma), (x2_mean, x2_log_sigma),
 (x3_mean, x3_log_sigma), (x4_mean, x4_log_sigma)) = encoded

# Draw num_samples points per centre from the encoded Gaussian using NumPy,
# i.e. without going through the model's own sampling op.
x1_samples, x2_samples, x3_samples, x4_samples = [
    np.random.normal(loc=np.repeat(mean, num_samples, axis=0),
                     scale=np.repeat(np.exp(log_sigma), num_samples, axis=0))
    for mean, log_sigma in encoded
]

for samples, colour in zip((x1_samples, x2_samples, x3_samples, x4_samples),
                           ('red', 'blue', 'green', 'purple')):
    plt.scatter(samples[:, 0], samples[:, 1], color=colour, s=5)
plt.title("Posterior distributions before IAF flow")
plt.show()

In [None]:
num_samples = 256

# Feed num_samples copies of each centre so that a single call returns
# num_samples posterior samples for that centre.
(x1_posterior_samples, x2_posterior_samples,
 x3_posterior_samples, x4_posterior_samples) = [
    model.posterior_sample(np.tile(centre, (num_samples, 1)))
    for centre in (x1, x2, x3, x4)
]

In [None]:
# Plot the posterior distributions after passing through the IAF flow

posterior_groups = (
    (x1_posterior_samples, 'red'),
    (x2_posterior_samples, 'blue'),
    (x3_posterior_samples, 'green'),
    (x4_posterior_samples, 'purple'),
)
for samples, colour in posterior_groups:
    plt.scatter(samples[:, 0], samples[:, 1], color=colour, s=5)
plt.title("Posterior distributions after IAF flow")
plt.show()

In [None]:
# Decode each group of posterior samples back to data space and plot the results.
x1_decoded, x2_decoded, x3_decoded, x4_decoded = [
    model.decode(latents)
    for latents in (x1_posterior_samples, x2_posterior_samples,
                    x3_posterior_samples, x4_posterior_samples)
]

for decoded, colour in zip((x1_decoded, x2_decoded, x3_decoded, x4_decoded),
                           ('red', 'blue', 'green', 'purple')):
    plt.scatter(decoded[:, 0], decoded[:, 1], color=colour, s=5)
plt.title("Reconstructions of data points")
plt.show()

In [None]:
# Close the model's TensorFlow session to release the resources it holds.
model.sess.close()