In [9]:
# loading libraries for data manipulation
import numpy as np
import pandas as pd

# loading libraries for data visualization
import matplotlib.pyplot as plt

# import tensorflow and keras packages
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


We will use the MNIST data to create a Convolutional Variational Autoencoder that encodes images of digits as distributions in a latent space.

We will first define a custom layer, `Sampling`, that draws a vector from the latent space. Recall that, in theory, the encoder learns the mean and standard-deviation vectors of the hidden representation. In the implementation, we use the log of the variance vector instead of the standard-deviation vector: exponentiating it always yields a positive variance, which avoids negative-variance problems and makes the gradients more stable.

In [2]:
# class uses mean and log variance vectors + noise to get a sample z
# there is only one function in the class that will generate a sample
class Sampling(layers.Layer):
    """Reparameterization layer: draws z from N(z_mean, exp(z_log_var)).

    The sample is written as ``z_mean + sigma * epsilon`` with
    ``epsilon ~ N(0, I)``, so the randomness enters only through ``epsilon``
    and the result stays differentiable with respect to both inputs.
    """

    def call(self, inputs):
        """Return one latent sample for every entry of the batch.

        Args:
            inputs: pair ``(z_mean, z_log_var)`` of tensors that are
                expected to share the same shape.

        Returns:
            Tensor with the shape and dtype of ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Draw the noise with z_mean's full dynamic shape and dtype, so the
        # layer is not tied to rank-2 inputs or to the default float type.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log(variance)) is the standard deviation
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [3]:
# size of the latent vector z; the encoder's two output heads and the decoder's
# input are all built with this value
# two dimensions are used so the latent space can be drawn directly on a 2-D plot
latent_dim = 2

Let's define our Encoder. We will take each image as input, apply some Convolution layers, flatten, and then create our mean and log_var vectors that represent our latent space. 

In [None]:
# every input is a single-channel 28x28 image
encoder_inputs = keras.Input(shape=(28, 28, 1))

# feature extractor: there is no max pooling, the stride-2 convolutions
# do the downsampling themselves
x = layers.Conv2D(
    filters=32, kernel_size=3, strides=2, padding="same", activation="relu"
)(encoder_inputs)
x = layers.Conv2D(
    filters=64, kernel_size=3, strides=2, padding="same", activation="relu"
)(x)

# flatten the feature maps and squeeze them through a 16-unit dense layer
x = layers.Flatten()(x)
x = layers.Dense(units=16, activation="relu")(x)

# two separate heads read that 16-unit layer:
# one predicts the mean, the other the log-variance of the latent distribution
z_mean = layers.Dense(units=latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(units=latent_dim, name="z_log_var")(x)

# draw z from the distribution described by the two heads
z = Sampling()([z_mean, z_log_var])

# the encoder exposes the distribution parameters as well as the sample
encoder = keras.Model(
    inputs=encoder_inputs,
    outputs=[z_mean, z_log_var, z],
    name="encoder",
)
encoder.summary()

Similarly, our Decoder will take the z vector and use it to reconstruct the input image. 

In [None]:
# the decoder consumes one latent vector of length latent_dim
latent_inputs = keras.Input(shape=(latent_dim,))

# project the latent vector onto a 7x7 feature map with 64 channels
x = layers.Dense(units=7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape(target_shape=(7, 7, 64))(x)

# Conv2DTranspose with stride 2 upsamples the feature map
x = layers.Conv2DTranspose(
    filters=64, kernel_size=3, strides=2, padding="same", activation="relu"
)(x)
x = layers.Conv2DTranspose(
    filters=32, kernel_size=3, strides=2, padding="same", activation="relu"
)(x)

# last layer collapses to a single channel; sigmoid keeps every pixel in (0, 1)
decoder_outputs = layers.Conv2DTranspose(
    filters=1, kernel_size=3, padding="same", activation="sigmoid"
)(x)

# wrap input and output into the decoder model
decoder = keras.Model(
    inputs=latent_inputs,
    outputs=decoder_outputs,
    name="decoder",
)
decoder.summary()

Because VAEs require a custom loss function (reconstruction loss + KL loss), we will create a VAE class where we explicitly define our loss and training mechanism.

In [6]:
# VAE class
class VAE(keras.Model):
    """Variational autoencoder that trains an encoder/decoder pair jointly.

    The training loss is the sum of a reconstruction term (binary
    cross-entropy) and a KL-divergence term. Both terms and their total are
    tracked as running means and reported by ``fit``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        """Store the two sub-models and create the three loss trackers.

        Args:
            encoder: model returning ``(z_mean, z_log_var, z)`` for a batch
                of images.
            decoder: model mapping latent vectors ``z`` to reconstructed
                images.
            **kwargs: forwarded to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # running means of the three loss values
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Trackers owned by this model, in reporting order."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimization step on a batch and return the running losses.

        Args:
            data: batch of images, or a tuple/list whose first element is the
                batch of images (as produced by ``fit(x, y)`` or a dataset of
                ``(x, y)`` pairs); any further elements are ignored.

        Returns:
            Dict with the running means under the keys ``"loss"``,
            ``"reconstruction_loss"`` and ``"kl_loss"``.
        """
        # The model is unsupervised: when labels (or sample weights) come
        # along with the images, keep only the images.
        if isinstance(data, (tuple, list)):
            data = data[0]

        with tf.GradientTape() as tape:
            # encode to distribution parameters plus one sample, then decode it
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)

            # binary cross-entropy summed over the two spatial axes,
            # then averaged over the batch
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )

            # KL divergence between N(z_mean, exp(z_log_var)) and the standard
            # normal, summed over latent dimensions and averaged over the batch
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

            total_loss = reconstruction_loss + kl_loss

        # compute the gradients explicitly and let the optimizer apply them
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }


Now we can train our VAE on MNIST digits data

In [None]:
# labels are discarded: the VAE only needs the images
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

# pool train and test images, append a channel axis and divide by 255
mnist_digits = (
    np.concatenate((x_train, x_test), axis=0)[..., np.newaxis].astype("float32") / 255
)

# wire the encoder and decoder built above into the VAE and train it
vae = VAE(encoder=encoder, decoder=decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(x=mnist_digits, batch_size=128, epochs=10, verbose=1)

Now we can sample a random noise vector from a standard normal distribution and ask the Decoder to construct the digit image associated with that vector in the latent space.

In [None]:
# draw a single point from a standard normal; its length is latent_dim
random_latent_vector = np.random.normal(size=(1, latent_dim))
print(random_latent_vector)

# decode the point into an image (predict returns a batch of one)
generated_digit = decoder.predict(random_latent_vector)

# drop the batch and channel axes to get a plain 2-D image for display
plt.imshow(generated_digit[0, ..., 0], cmap="gray")
plt.axis("off")
plt.show()

To get a good sense of the learned latent space, we can sample points in that space and ask the Decoder to generate digit images for them

In [None]:
# function will display a two-dimensional manifold of digits
def plot_latent_space(vae, n=30, figsize=15, scale=1.0, digit_size=28):
    """Plot an n x n sheet of digits decoded from a regular 2-D latent grid.

    Args:
        vae: object exposing ``decoder.predict``; the decoder must map a
            ``(batch, 2)`` array of latent points to one image per point,
            each reshapeable to ``(digit_size, digit_size)``.
        n: number of grid points along each latent axis.
        figsize: side length of the square figure, in inches.
        scale: the grid spans ``[-scale, scale]`` on both latent axes.
        digit_size: height and width, in pixels, of one decoded image.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # linearly spaced latent coordinates; grid_y is reversed so that z[1]
    # grows from the bottom of the picture to the top
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # every (x, y) pair of the grid in row-major order:
    # row i uses grid_y[i], column j uses grid_x[j]
    xx, yy = np.meshgrid(grid_x, grid_y)
    z_grid = np.column_stack([xx.ravel(), yy.ravel()])

    # decode the whole grid in one batched call instead of n*n separate ones
    decoded = np.asarray(vae.decoder.predict(z_grid, verbose=0))

    # (n*n, H, W, ...) -> (row, col, H, W) -> (row, H, col, W) -> one canvas
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    figure = tiles.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

    plt.figure(figsize=(figsize, figsize))
    # one tick at the centre of every tile, labelled with its latent coordinate
    pixel_range = np.arange(n) * digit_size + digit_size // 2
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


# draw the manifold for the trained model using the function's default arguments
plot_latent_space(vae)

We can also interpolate between two digits to observe how a smooth, continuous transition is formed between them in the latent space.

In [None]:
# take the first two images of the dataset; slicing (rather than indexing)
# keeps the batch axis
digit1 = mnist_digits[:1]
digit2 = mnist_digits[1:2]

# encode both images; z1 and z2 are the encoder's sampled latent vectors
# (its third output), so they include the sampling noise
z_mean1, z_log_var1, z1 = encoder.predict(digit1)
z_mean2, z_log_var2, z2 = encoder.predict(digit2)


num_steps = 10  # number of frames, both endpoints included

# linear blend from z1 (alpha = 0) to z2 (alpha = 1), stacked along the batch axis
interpolated_latents = np.concatenate(
    [(1 - alpha) * z1 + alpha * z2 for alpha in np.linspace(0, 1, num_steps)],
    axis=0,
)

# decode every blended latent vector in one batch
interpolated_digits = decoder.predict(interpolated_latents)

# show the frames side by side in a single row
plt.figure(figsize=(20, 2))
for position, frame in enumerate(interpolated_digits, start=1):
    plt.subplot(1, num_steps, position)
    plt.imshow(frame[:, :, 0], cmap="gray")
    plt.axis("off")
plt.show()