# RealNVP on MNIST

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/salu133445/flows/blob/main/realnvp_mnist.ipynb)

Code adapted from https://github.com/LukasRinder/normalizing-flows.

## Imports

In [1]:
import os
import sys
import time

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

tfb = tfp.bijectors
tfd = tfp.distributions
tf.keras.backend.set_floatx("float32")

## Functions

In [2]:
def logit(z, beta=10e-6):
    """
    Conversion to logit space according to equation (24) in [Papamakarios et
    al. (2017)]. Includes scaling the input image to [0, 1] and conversion to
    logit space.

    :param z: Input tensor, e.g. image. Type: tf.float32.
    :param beta: Small value. Default: 10e-6.
    :return: Input tensor in logit space.
    """
    inter = beta + (1 - 2 * beta) * (z / 256)
    return tf.math.log(inter / (1 - inter))  # logit function


def inverse_logit(x, beta=10e-6):
    """
    Reverts the preprocessing steps and conversion to logit space and outputs
    an image in range [0, 256]. Inverse of equation (24) in [Papamakarios et
    al. (2017)].
    
    :param x: Input tensor in logit space. Type: tf.float32.
    :param beta: Small value. Default: 10e-6.
    :return: Input tensor in logit space.
    """
    x = tf.math.sigmoid(x)
    return (x - beta) * 256 / (1 - 2 * beta)

def load_and_preprocess_mnist(
    logit_space=True, batch_size=128, shuffle=True, classes=-1, channels=False
):
    """
    Loads and preprocesses the MNIST dataset. Train set: 50000, val set: 10000,
    test set: 10000.

    :param logit_space: If True, the data is converted to logit space.
    :param batch_size: batch size
    :param shuffle: bool. If True, dataset will be shuffled.
    :param classes: int of class to take, defaults to -1 = ALL
    :return: Three batched TensorFlow datasets:
      batched_train_data, batched_val_data, batched_test_data.
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # reserve last 10000 training samples as validation set
    x_train, x_val = x_train[:-10000], x_train[-10000:]
    y_train, y_val = y_train[:-10000], y_train[-10000:]

    # if logit_space: convert to logit space, else: scale to [0, 1]
    if logit_space:
        x_train = logit(tf.cast(x_train, tf.float32))
        x_test = logit(tf.cast(x_test, tf.float32))
        x_val = logit(tf.cast(x_val, tf.float32))
        interval = 256
    else:
        x_train = tf.cast(x_train / 256, tf.float32)
        x_test = tf.cast(x_test / 256, tf.float32)
        x_val = tf.cast(x_val / 256, tf.float32)
        interval = 1

    if classes == -1:
        pass
    else:
        # TODO: Extract Multiple classes: How to to the train,val split,
        # Do we need to to a class balance???
        x_train = np.take(x_train, tf.where(y_train == classes), axis=0)
        x_val = np.take(x_val, tf.where(y_val == classes), axis=0)
        x_test = np.take(x_test, tf.where(y_test == classes), axis=0)

    # reshape if necessary
    if channels:
        x_train = tf.reshape(x_train, (x_train.shape[0], 28, 28, 1))
        x_val = tf.reshape(x_val, (x_val.shape[0], 28, 28, 1))
        x_test = tf.reshape(x_test, (x_test.shape[0], 28, 28, 1))
    else:
        x_train = tf.reshape(x_train, (x_train.shape[0], 28, 28))
        x_val = tf.reshape(x_val, (x_val.shape[0], 28, 28))
        x_test = tf.reshape(x_test, (x_test.shape[0], 28, 28))

    if shuffle:
        shuffled_train_data = tf.data.Dataset.from_tensor_slices(
            x_train
        ).shuffle(1000)

    batched_train_data = shuffled_train_data.batch(batch_size)
    batched_val_data = tf.data.Dataset.from_tensor_slices(x_val).batch(
        batch_size
    )
    batched_test_data = tf.data.Dataset.from_tensor_slices(x_test).batch(
        batch_size
    )

    return batched_train_data, batched_val_data, batched_test_data, interval

In [3]:
@tf.function
def nll(distribution, data):
    """
    Computes the negative log liklihood loss for a given distribution and given
    data.
    
    :param distribution: TensorFlow distribution, e.g.
      tf.TransformedDistribution.
    :param data: Data or a batch from data.
    :return: Negative Log Likelihodd loss.
    """
    return -tf.reduce_mean(distribution.log_prob(data))

@tf.function
def train_density_estimation(distribution, optimizer, batch):
    """
    Train function for density estimation normalizing flows.
    
    :param distribution: TensorFlow distribution, e.g.
      tf.TransformedDistribution.
    :param optimizer: TensorFlow keras optimizer, e.g.
      tf.keras.optimizers.Adam.
    :param batch: Batch of the train data.
    :return: loss.
    """
    with tf.GradientTape() as tape:
        tape.watch(distribution.trainable_variables)
        loss = -tf.reduce_mean(
            distribution.log_prob(batch)
        )  # negative log likelihood
    gradients = tape.gradient(loss, distribution.trainable_variables)
    optimizer.apply_gradients(zip(gradients, distribution.trainable_variables))

    return loss

In [4]:
class NN(tf.keras.layers.Layer):
    """
    Neural Network Architecture for calcualting s and t for Real-NVP

    :param input_shape: shape of the data coming in the layer
    :param hidden_units: Python list-like of non-negative integers, specifying
      the number of units in each hidden layer.
    :param activation: Activation of the hidden units
    """

    def __init__(
        self, input_shape, n_hidden=[512, 512], activation="relu", name="nn"
    ):
        super().__init__(name=name)
        layer_list = []
        for i, hidden in enumerate(n_hidden):
            layer_list.append(
                tf.keras.layers.Dense(hidden, activation=activation)
            )
        self.layer_list = layer_list
        self.log_s_layer = tf.keras.layers.Dense(
            input_shape, activation="tanh", name="log_s")
        self.t_layer = tf.keras.layers.Dense(input_shape, name="t")

    def call(self, x):
        y = x
        for layer in self.layer_list:
            y = layer(y)
        log_s = self.log_s_layer(y)
        t = self.t_layer(y)
        return log_s, t


class RealNVP(tfb.Bijector):
    """
    Implementation of a Real-NVP for Denisty Estimation. L. Dinh “Density
    estimation using Real NVP,” 2016. This implementation only works for 1D
    arrays.

    :param input_shape: shape of the data coming in the layer
    :param hidden_units: Python list-like of non-negative integers, specifying
      the number of units in each hidden layer.
    """

    def __init__(
        self,
        input_shape,
        n_hidden=[512, 512],
        forward_min_event_ndims=1,
        validate_args: bool = False,
        name="real_nvp",
    ):
        super().__init__(
            validate_args=validate_args,
            forward_min_event_ndims=forward_min_event_ndims,
            name=name,
        )

        assert input_shape % 2 == 0
        input_shape = input_shape // 2
        nn_layer = NN(input_shape, n_hidden)
        x = tf.keras.Input(input_shape)
        log_s, t = nn_layer(x)
        self.nn = tf.keras.Model(x, [log_s, t], name="nn")

    def _bijector_fn(self, x):
        log_s, t = self.nn(x)
        return tfb.AffineScalar(shift=t, log_scale=log_s)

    def _forward(self, x):
        x_a, x_b = tf.split(x, 2, axis=-1)
        y_b = x_b
        y_a = self._bijector_fn(x_b).forward(x_a)
        y = tf.concat([y_a, y_b], axis=-1)
        return y

    def _inverse(self, y):
        y_a, y_b = tf.split(y, 2, axis=-1)
        x_b = y_b
        x_a = self._bijector_fn(y_b).inverse(y_a)
        x = tf.concat([x_a, x_b], axis=-1)
        return x

    def _forward_log_det_jacobian(self, x):
        x_a, x_b = tf.split(x, 2, axis=-1)
        return self._bijector_fn(x_b).forward_log_det_jacobian(
            x_a, event_ndims=1
        )

    def _inverse_log_det_jacobian(self, y):
        y_a, y_b = tf.split(y, 2, axis=-1)
        return self._bijector_fn(y_b).inverse_log_det_jacobian(
            y_a, event_ndims=1
        )

## Load and process data

In [5]:
category = 2
train_data, val_data, test_data, _ = load_and_preprocess_mnist(
    logit_space=True, batch_size=128, shuffle=True, classes=category
)

In [6]:
category = "all"
train_data, val_data, test_data, _ = load_and_preprocess_mnist(
    logit_space=True, batch_size=128, shuffle=True
)

In [7]:
mnist_shape = (28, 28, 1)
size = 28
input_shape = size*size
permutation = tf.cast(
    np.concatenate(
        (np.arange(input_shape/2,input_shape),np.arange(0,input_shape/2))
    ),
    tf.int32
)
base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros(input_shape, tf.float32))

## Configuration

In [8]:
n_images = 10
dataset = "mnist"
exp_number = 1
max_epochs = 200
layers = 5
shape = [256, 256]
base_lr = 1e-4
end_lr = 1e-5

## Build model

In [9]:
bijectors = []
alpha = 1e-3

for i in range(layers):
    bijectors.append(tfb.BatchNormalization())
    bijectors.append(RealNVP(input_shape=input_shape, n_hidden=shape))
    bijectors.append(tfb.Permute(permutation))
    
bijectors.append(
    tfb.Reshape(event_shape_out=(size, size), event_shape_in=(size * size,))
)

bijector = tfb.Chain(
    bijectors=list(reversed(bijectors)), name='chain_of_real_nvp'
)

flow = tfd.TransformedDistribution(
    distribution=base_dist,
    bijector=bijector
)

# number of trainable variables
n_trainable_variables = len(flow.trainable_variables)

In [10]:
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    base_lr, max_epochs, end_lr, power=0.5)

checkpoint_directory = "{}/tmp_{}_{}_{}_{}_{}".format(
    dataset, layers, shape[0], shape[1], exp_number, category
)
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

opt = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)
checkpoint = tf.train.Checkpoint(optimizer=opt, model=flow)

## Training

In [11]:
global_step = []
train_losses = []
val_losses = []
# high value to ensure that first loss < min_loss
min_val_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)  
min_train_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)
min_val_epoch = 0
min_train_epoch = 0
delta_stop = 50  # threshold for early stopping

t_start = time.time()  # start time

# start training
for i in range(max_epochs):

    train_data.shuffle(buffer_size=1000)
    batch_train_losses = []
    for batch in train_data:
        batch_loss = train_density_estimation(flow, opt, batch)
        batch_train_losses.append(batch_loss)

    train_loss = tf.reduce_mean(batch_train_losses)

    if i % int(1) == 0:
        batch_val_losses = []
        for batch in val_data:
            batch_loss = nll(flow, batch)
            batch_val_losses.append(batch_loss)

        val_loss = tf.reduce_mean(batch_val_losses)

        global_step.append(i)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        if i % 10 == 0:
            print(
                f"{i:3d}, train_loss: {train_loss:12.6f}, "
                f"val_loss: {val_loss:12.6f}"
            )

        if train_loss < min_train_loss:
            min_train_loss = train_loss
            min_train_epoch = i

        if val_loss < min_val_loss:
            min_val_loss = val_loss
            min_val_epoch = i
            checkpoint.write(file_prefix=checkpoint_prefix)

        # no decrease in min_val_loss for "delta_stop epochs"
        elif i - min_val_epoch > delta_stop:  
            break

train_time = time.time() - t_start

print(f"Training time: {train_time}")
print(f"Min val loss: {min_val_loss} at epoch: {min_val_epoch}")
print(f"Last val loss: {val_loss} at epoch: {i}")
print(f"Min train loss: {min_train_loss} at epoch: {min_train_epoch}")
print(f"Last train loss: {train_loss} at epoch: {i}")

Instructions for updating:
`AffineScalar` bijector is deprecated; please use `tfb.Shift(loc)(tfb.Scale(...))` instead.




 0, train_loss:   934.622498, val_loss:   633.685547
10, train_loss:  -301.913330, val_loss:  -241.556763
20, train_loss:  -531.578247, val_loss:  -461.721375
30, train_loss:  -683.400574, val_loss:  -583.579407
40, train_loss:  -807.115479, val_loss:  -682.403748
50, train_loss:  -912.496155, val_loss:  -804.089294
60, train_loss: -1011.752502, val_loss:  -875.467468
70, train_loss: -1085.121948, val_loss:  -965.658447
80, train_loss: -1165.078491, val_loss: -1045.054932
90, train_loss: -1241.736328, val_loss: -1112.344116
100, train_loss: -1313.417969, val_loss: -1167.990479
110, train_loss: -1383.647217, val_loss: -1217.543457
120, train_loss: -1449.916382, val_loss: -1257.981445
130, train_loss: -1505.733887, val_loss: -1302.328247
140, train_loss: -1574.562988, val_loss: -1406.439575
150, train_loss: -1633.193359, val_loss: -1465.922119
160, train_loss: -1696.783569, val_loss: -1514.354980
170, train_loss: -1749.629272, val_loss: -1565.940430
180, train_loss: -1790.681763, val_los

## Test

In [13]:
# load best model with min validation loss
checkpoint.restore(checkpoint_prefix)

# perform on test dataset
t_start = time.time()

test_losses = []
for batch in test_data:
    batch_loss = nll(flow, batch)
    test_losses.append(batch_loss)

test_loss = tf.reduce_mean(test_losses)

test_time = time.time() - t_start

save_dir = "{}/sampling_{}_{}_{}_{}/".format(
    dataset, layers, shape[0], shape[1], category
)

if not os.path.isdir(save_dir):
    os.mkdir(save_dir)
for j in range(n_images):
    plt.figure()
    data = flow.sample(1)
    data = inverse_logit(data)
    data = tf.reshape(data, (1, size, size))
    plt.imshow(data[0], cmap='gray')
    plt.savefig(
        "{}/{}_{}_i{}.png".format(save_dir, exp_number, min_val_epoch, j)
    )
    plt.close()

# remove checkpoint
filelist = [f for f in os.listdir(checkpoint_directory)]
for f in filelist:
    os.remove(os.path.join(checkpoint_directory, f))
os.removedirs(checkpoint_directory)

print(f"Test loss: {test_loss} at epoch: {i}")
print(f"Average test log likelihood: {-test_loss} at epoch: {i}")
print(f"Test time: {test_time}")

Test loss: -1677.56689453125 at epoch: 199
Average test log likelihood: 1677.56689453125 at epoch: 199
Test time: 0.2717106342315674
