### Week 9: Normalising flows

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

import numpy as np
import matplotlib.pyplot as plt

## TensorFlow bijectors

### Base distribution

In [None]:
# Base distribution of the flow: a 2-D Gaussian with zero mean and unit
# diagonal scale.
base_dist = tfd.MultivariateNormalDiag(
    loc=tf.zeros([2], dtype=tf.float32),
    scale_diag=tf.constant([1, 1], dtype=tf.float32),
)

In [None]:
# Number of points drawn per sample batch throughout the notebook.
SAMPLE_BATCH_SIZE = 512

# Seed every source of randomness the notebook uses: TensorFlow's graph-level
# seed for the sampling ops, and NumPy's global generator for the
# `np.random` draw used in the round-trip check.
SEED = 1000
tf.set_random_seed(SEED)
np.random.seed(SEED)

In [None]:
# Symbolic tensor for SAMPLE_BATCH_SIZE draws from the base distribution.
# Nothing has been evaluated yet, so the print shows the tensor itself rather
# than sampled values (those are obtained through the session in later cells).
z = base_dist.sample(SAMPLE_BATCH_SIZE)
print(z)

In [None]:
# Session used by every later cell to evaluate tensors (`z.eval()`,
# `sess.run(...)`); it is closed in the final cell.
sess = tf.InteractiveSession()

In [None]:
# Evaluate the symbolic samples into a concrete value, then report its type
# and shape on separate lines.
z_samples = z.eval()
print(type(z_samples), z_samples.shape, sep="\n")

In [None]:
# Scatter plot of the evaluated base-distribution samples on a square canvas.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(z_samples[:, 0], z_samples[:, 1], s=10)
ax.set_title("Base distribution: standard normal")
ax.set_xlim([-4, 4])
ax.set_ylim([-4, 4])
plt.show()

### Transform the distribution

A Bijector is used to transform distributions. Bijectors are the building blocks for a normalising flow. 
They are characterised by the following three main methods:
    1. forward
    2. inverse
    3. log_det_jacobian

Conventionally, the `forward` operation acts on samples from the base distribution (it is used to generate samples), while the `inverse` operation maps data back to the base space (it is used to calculate probabilities).

For example, the Affine Bijector:

In [None]:
# Affine bijector with a per-dimension shift of [1, -1] and a diagonal scale
# of [0.5, 1.5].
affine_bijector = tfb.Affine(shift=[1., -1.], scale_diag=[0.5, 1.5])

In [None]:
# Push the symbolic base samples through the bijector's forward transformation.
fwd_z = affine_bijector.forward(z)

In [None]:
# Fetch both tensors in a single run so that x_samples are the transforms of
# exactly these z_samples (z is a sampling op).
z_samples, x_samples = sess.run([z, fwd_z])

In [None]:
# Side-by-side scatter plots: base samples on the left, their affine
# transform on the right, both on the same axis limits for comparison.
fig, (ax, ax2) = plt.subplots(1, 2, figsize=(12, 5))

panels = [
    (ax, z_samples, "Base distribution: standard normal", {}),
    (ax2, x_samples, "Transformed distribution: shift [1, -1], scale [0.5, 1.5]", {"color": 'r'}),
]
for axis, points, title, scatter_kwargs in panels:
    axis.scatter(points[:, 0], points[:, 1], s=10, **scatter_kwargs)
    axis.set_title(title)
    axis.set_xlim([-5, 5])
    axis.set_ylim([-5, 5])
plt.show()

In [None]:
# Apply the inverse to the forward-transformed samples; the next cell checks
# that this round trip reproduces the values fed in for z.
fwd_inv_z = affine_bijector.inverse(fwd_z)

In [None]:
# Round-trip check: feed known values in place of z and confirm that
# inverse(forward(z)) reproduces them to within floating-point tolerance.
latents = np.random.random((SAMPLE_BATCH_SIZE, 2))
recovered_latents = sess.run(fwd_inv_z, feed_dict={z: latents})
print(np.allclose(latents, recovered_latents))

### Computing probabilities

In [None]:
# Placeholder for a single 2-D point at which quantities will be evaluated.
x = tf.placeholder(shape=(1, 2), dtype=tf.float32)

# Log-determinant of the Jacobian of the inverse map at x; event_ndims=1
# matches the vector-valued (last-axis) events used in this notebook.
log_det_dzdx = affine_bijector.inverse_log_det_jacobian(x, event_ndims=1)
log_det_dzdx

In [None]:
# Map x back to the base space with the bijector's inverse.
inv_x = affine_bijector.inverse(x)
inv_x

In [None]:
# Log-density of the inverted point under the base distribution.
log_prob_inv_x = base_dist.log_prob(inv_x)
log_prob_inv_x

In [None]:
x_fixed_sample = np.array([[1., -1.]])  # Mode of the transformed distribution

# Evaluate the inverse log-det-Jacobian at that point.
sess.run(log_det_dzdx, feed_dict={x: x_fixed_sample})

Check: the Jacobian determinant of the forward map is just the product of the scaling factors, so the inverse log-determinant is minus the sum of their logs

In [None]:
# Expected value: minus the sum of the logs of the two scale factors (0.5, 1.5).
- np.log(0.5) - np.log(1.5)

Calculate log probability of `x`:

In [None]:
# Change of variables: log p(x) = log p_base(inverse(x)) + log|det dz/dx|.
# Evaluated at `x_fixed_sample`, the same point used for the Jacobian check,
# instead of repeating the literal so the two checks cannot drift apart.
sess.run(log_prob_inv_x + log_det_dzdx, feed_dict={x: x_fixed_sample})

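Check: the 2-D standard normal has log-density $\log\frac{1}{2\pi}$ at the origin (where the mode maps under the inverse), to which we add the inverse log-det-Jacobian: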

In [None]:
# Expected value: the first term simplifies to log(1 / (2*pi)); the remaining
# terms are minus the logs of the two scale factors.
np.log(np.sqrt(1 / (2 * np.pi)**2)) - np.log(0.5) - np.log(1.5)

## Learned flow example

### Target distribution

In [None]:
# Two-stage sampling of the target:
#   x2 ~ Normal(0, 4)
#   x1 | x2 ~ Normal(0.25 * x2**2, 1)
x2_dist = tfd.Normal(loc=0., scale=4.)
x2 = x2_dist.sample(SAMPLE_BATCH_SIZE)
x1_dist = tfd.Normal(loc=.25 * tf.square(x2), scale=tf.ones(SAMPLE_BATCH_SIZE, dtype=tf.float32))
x1 = x1_dist.sample()
# NOTE: this rebinds `x`, which the "Computing probabilities" cells use as a
# placeholder; re-running those cells after this one would pick up this
# sampled tensor instead.
x = tf.stack([x1, x2], axis=1)

In [None]:
# Draw one batch from the target distribution and show it as a scatter plot.
np_samples = sess.run(x)
target_fig, target_ax = plt.subplots()
target_ax.scatter(np_samples[:, 0], np_samples[:, 1], s=10)
target_ax.set_xlim([-5, 30])
target_ax.set_ylim([-10, 10])
target_ax.set_title("Target distribution")
plt.show()

### Set up the normalising flow

In [None]:
class LeakyReLU(tfb.Bijector):
    """Element-wise leaky ReLU as a bijector.

    Non-negative inputs pass through unchanged; negative inputs are
    multiplied by `alpha` in the forward direction and by `1 / alpha` in the
    inverse direction.
    """

    def __init__(self, alpha=0.5, validate_args=False, name="leaky_relu"):
        """Store the negative-side slope and configure the base bijector.

        Args:
            alpha: slope applied to negative inputs. The inverse divides by
                it, so it must be non-zero.
            validate_args: forwarded to `tfb.Bijector`.
            name: forwarded to `tfb.Bijector`.
        """
        super().__init__(
            forward_min_event_ndims=1,
            validate_args=validate_args,
            name=name,
        )
        self.alpha = alpha

    def _forward(self, x):
        """Return `x` where `x >= 0` and `alpha * x` elsewhere."""
        non_negative = tf.greater_equal(x, 0)
        scaled = self.alpha * x
        return tf.where(non_negative, x, scaled)

    def _inverse(self, y):
        """Return `y` where `y >= 0` and `y / alpha` elsewhere."""
        non_negative = tf.greater_equal(y, 0)
        unscaled = 1. / self.alpha * y
        return tf.where(non_negative, y, unscaled)

    def _inverse_log_det_jacobian(self, y):
        """Return the log |det| of the inverse Jacobian, summed over the last axis.

        The inverse is element-wise, so its Jacobian is diagonal with entries
        1 (where `y >= 0`) or `1 / alpha` (elsewhere); the log-determinant is
        the sum of the logs of the absolute values of those entries.
        """
        ones = tf.ones_like(y)
        non_negative = tf.greater_equal(y, 0)
        inverse_slopes = tf.where(non_negative, ones, 1.0 / self.alpha * ones)
        log_abs_slopes = tf.log(tf.abs(inverse_slopes))
        return tf.reduce_sum(log_abs_slopes, axis=-1)

In [None]:
# Stores the Bijector layers that will make up our normalising flow
bijectors_list = []
num_layers = 6
d, r = 2, 2

for i in range(num_layers):
    with tf.variable_scope('bijector_%d' % i):
        V = tf.get_variable('V', [d, r], dtype=tf.float32)  # factor loading
        shift = tf.get_variable('shift', [d], dtype=tf.float32)  # affine shift
        L = tf.get_variable('L', [d * (d + 1) / 2],
                            dtype=tf.float32)  # lower triangular
        bijectors_list.append(tfb.Affine(
            scale_tril=tfd.fill_triangular(L),
            scale_perturb_factor=V,
            shift=shift,
            name="affine_{}".format(i)
        ))
        
        if i != num_layers - 1:
            alpha = tf.abs(tf.get_variable('alpha', [], dtype=tf.float32)) + .01
            bijectors_list.append(LeakyReLU(alpha=alpha, name="leaky_relu_{}".format(i)))

Build the network from the list of Bijectors.

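Note that `tfb.Chain` applies its list of bijectors in reverse order in the forward direction (the last bijector in the list is applied first), so we reverse our list before passing it in.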

In [None]:
# bijectors_list was built in application order (first layer first); it is
# reversed here before being handed to Chain.
mlp_bijector = tfb.Chain(list(reversed(bijectors_list)), name='mlp_bijector')

A `TransformedDistribution` takes a base distribution and a bijector (the `Chain` above) to define the flow. 

We will use the same base distribution defined earlier (multivariate normal).

In [None]:
# The flow: the base distribution pushed through the chained bijector.
flow = tfd.TransformedDistribution(distribution=base_dist, bijector=mlp_bijector)

### Visualise the layer transformations before training

In [None]:
# NOTE: this rebinds `z` to a new sampling op; tensors built earlier from the
# previous `z` (e.g. fwd_z, fwd_inv_z) still refer to the old one.
z = base_dist.sample(SAMPLE_BATCH_SIZE)
samples = [z]  # Collect samples for each layer in the network
names = [base_dist.name]  # Collect the names for each layer in the network
h = z
# The chain was built from the reversed list, so iterate it in reverse to
# apply the layers in the order they were appended to bijectors_list.
for bijector in reversed(flow.bijector.bijectors):
    h = bijector.forward(h)
    samples.append(h)
    names.append(bijector.name)

In [None]:
# Initialise the variables created so far (the bijector parameters), so the
# untrained flow can be evaluated.
sess.run(tf.global_variables_initializer())

In [None]:
# Evaluate every intermediate tensor in a single run so that all panels show
# the same batch of base samples at successive stages of the flow.
layer_samples = sess.run(samples)

# The panel count gets its own name: reassigning `num_layers` here would
# clobber the flow depth used when the bijectors were built.
num_plot_rows = 3
num_panels = len(layer_samples)  # base distribution + one entry per bijector
num_plot_cols = -(-num_panels // num_plot_rows)  # ceiling division: no panel is dropped
f, arr = plt.subplots(num_plot_rows, num_plot_cols,
                      figsize=(15, num_plot_rows * 3), squeeze=False)

panel_axes = list(arr.flat)  # row-major order, matching the layer order
for panel_ax, X, panel_title in zip(panel_axes, layer_samples, names):
    panel_ax.scatter(X[:, 0], X[:, 1], s=8)
    panel_ax.set_xlim([-5, 5])
    panel_ax.set_ylim([-5, 5])
    panel_ax.set_title(panel_title)
# Hide any grid cells left over when the panel count does not fill the grid.
for panel_ax in panel_axes[num_panels:]:
    panel_ax.axis('off')
plt.show()

### Optimise the network

In [None]:
# Maximum-likelihood objective: the mean negative log-probability that the
# flow assigns to the target samples `x`.
loss = -tf.reduce_mean(flow.log_prob(x))
# Adam optimiser with a learning rate of 1e-3.
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

In [None]:
# # Train a new model (uncomment this cell to retrain; it writes the checkpoint and .npy files that the next cell loads)

# import time

# # Need to initialise the optimizer variables
# sess.run(tf.global_variables_initializer())

# saver = tf.train.Saver()
# NUM_STEPS = int(50000)
# global_step = []
# np_losses = []
# start_time = time.time()
# for i in range(NUM_STEPS):
#     _, np_loss = sess.run([train_op, loss])
#     if i % 1000 == 0:
#         global_step.append(i)
#         np_losses.append(np_loss)
#     if i % int(1e4) == 0:
#         print(i, np_loss)
# end_time = time.time()
# saver.save(sess, './', global_step=global_step[-1])
# np.save('./np_losses.npy', np_losses)
# np.save('./global_step.npy', global_step)
# print("Training time: {}".format(end_time - start_time))

In [None]:
# Load a previously saved model

saver = tf.train.Saver()
# Despite its name, `meta_graph` holds whatever tf.train.latest_checkpoint
# returns for './' and is passed to restore as the checkpoint to load.
meta_graph = tf.train.latest_checkpoint('./')
saver.restore(sess, meta_graph)
# Loss history and the matching step indices, as saved by the training cell.
np_losses = np.load('./np_losses.npy')
global_step = np.load('./global_step.npy')

In [None]:
# Training curve, skipping the first `start` recorded points.
start = 10
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(global_step[start:], np_losses[start:])
loss_ax.set_ylabel("Loss")
loss_ax.set_xlabel("Iteration")
plt.show()

In [None]:
# Re-evaluate every intermediate tensor with the restored (trained) weights,
# in a single run so that all panels show the same batch of base samples.
layer_samples = sess.run(samples)

# The panel count gets its own name: reassigning `num_layers` here would
# clobber the flow depth used when the bijectors were built.
num_plot_rows = 3
num_panels = len(layer_samples)  # base distribution + one entry per bijector
num_plot_cols = -(-num_panels // num_plot_rows)  # ceiling division: no panel is dropped
f, arr = plt.subplots(num_plot_rows, num_plot_cols,
                      figsize=(15, num_plot_rows * 3), squeeze=False)

panel_axes = list(arr.flat)  # row-major order, matching the layer order
for panel_ax, X, panel_title in zip(panel_axes, layer_samples, names):
    panel_ax.scatter(X[:, 0], X[:, 1], s=8)
    panel_ax.set_xlim([-5, 20])
    panel_ax.set_ylim([-10, 10])
    panel_ax.set_title(panel_title)
# Hide any grid cells left over when the panel count does not fill the grid.
for panel_ax in panel_axes[num_panels:]:
    panel_ax.axis('off')

plt.show()

In [None]:
# Base samples and their image under the whole trained flow, fetched in one
# run so both panels show the same batch. Dedicated names are used here:
# rebinding `samples` / `layer_samples` to this 2-element list would break
# the per-layer grid cell if it were re-run afterwards.
end_to_end_tensors = [z, mlp_bijector.forward(z)]
base_points, flow_points = sess.run(end_to_end_tensors)
f, arr = plt.subplots(1, 2, figsize=(15, 6))

X = base_points
arr[0].scatter(X[:, 0], X[:, 1], s=8)
arr[0].set_xlim([-5, 5])
arr[0].set_ylim([-5, 5])
arr[0].set_title('Base distribution')
X = flow_points
arr[1].scatter(X[:, 0], X[:, 1], s=8, color='red')
arr[1].set_xlim([-5, 30])
arr[1].set_ylim([-10, 10])
arr[1].set_title('Transformed distribution')

plt.show()

In [None]:
# Close the session now that all evaluation is done; cells that use `sess`
# cannot be re-run after this without creating a new session.
sess.close()