# Variational Autoencoder with MNIST and FashionMNIST

We will use MNIST and the Zalando Fashion-MNIST dataset again to train a variational autoencoder, which will also be able to generate new clothing images.

As a source for this notebook, see [https://blog.keras.io/building-autoencoders-in-keras.html]. Another accurate example can be found here: [https://towardsdatascience.com/teaching-a-variational-autoencoder-vae-to-draw-mnist-characters-978675c95776].

To begin, we need to load some python modules including common layers from keras.

In [5]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

"""
## Create a sampling layer
"""


class Sampling(layers.Layer):
    """Reparameterisation-trick layer producing the latent vector z.

    Given the pair ``(z_mean, z_log_var)`` it returns
    ``z_mean + exp(0.5 * z_log_var) * eps``, where ``eps`` is standard-normal
    noise with the same (batch, dim) shape as ``z_mean``. Writing the sample
    this way keeps it differentiable w.r.t. the mean and log-variance.
    """

    def call(self, inputs):
        mu, log_var = inputs
        mu_shape = tf.shape(mu)
        noise = tf.keras.backend.random_normal(shape=(mu_shape[0], mu_shape[1]))
        sigma = tf.exp(0.5 * log_var)  # std = exp(log_var / 2)
        return mu + sigma * noise


"""
## Build the encoder
"""

# Size of the latent space; 2 so that the latent space can be plotted directly
# (see the 2-D grid / scatter plots further down).
latent_dim = 2

# Encoder: (28, 28, 1) image -> two stride-2 convolutions (14x14x32, then
# 7x7x64) -> flatten -> 16-unit dense bottleneck -> Gaussian parameters.
encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
# Mean and log-variance of q(z|x), one value per latent dimension.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
# Reparameterised sample that the decoder consumes during training.
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

"""
## Build the decoder
"""

# Decoder: latent vector -> dense layer reshaped to a 7x7x64 feature map ->
# two stride-2 transposed convolutions (14x14x64, then 28x28x32) -> a final
# 1-filter transposed convolution giving a (28, 28, 1) image whose sigmoid
# outputs lie in (0, 1), matching the [0, 1]-scaled training pixels.
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

"""
## Define the VAE as a `Model` with a custom `train_step`
"""


class VAE(keras.Model):
    """Variational autoencoder wrapping separately built encoder/decoder models.

    ``encoder`` must map an image batch to ``[z_mean, z_log_var, z]`` and
    ``decoder`` must map ``z`` back to an image batch. Training uses the custom
    :meth:`train_step`, which minimises reconstruction loss + KL divergence.

    NOTE: overriding ``train_step`` is only honoured by ``fit`` from
    TensorFlow 2.2 on; older versions reject ``compile()`` without a loss
    ("The model cannot be compiled because it has no loss to optimize").
    """

    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        """Encode ``inputs``, take the sampled ``z`` and return its reconstruction.

        ``fit`` only relies on :meth:`train_step`; defining ``call`` makes
        ``vae(x)`` and ``vae.predict(x)`` usable as well.
        """
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    def train_step(self, data):
        """Run one optimisation step on a batch and return the loss terms."""
        # fit() may hand over an (x, y) tuple; only the images are needed.
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Mean per-pixel BCE over the batch, rescaled by the 28*28 pixel
            # count so it equals the per-image summed BCE averaged over images.
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            reconstruction_loss *= 28 * 28
            # KL(q(z|x) || N(0, I)): a SUM over independent latent dimensions,
            # then averaged over the batch. Averaging over the latent axis
            # would shrink the KL term by a factor of latent_dim relative to
            # the pixel-summed reconstruction term.
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = -0.5 * tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }


"""
## Train the VAE
"""

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
# Train and test images are pooled: training here uses neither labels nor a
# held-out set. Add a channel axis and scale pixels to [0, 1].
mnist_digits = np.concatenate((x_train, x_test))[..., np.newaxis].astype("float32") / 255

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(mnist_digits, batch_size=128, epochs=30)

"""
## Display a grid of sampled digits
"""

import matplotlib.pyplot as plt


def plot_latent(encoder, decoder):
    """Decode a regular 30x30 grid of 2-D latent points and show them as one mosaic.

    Latent coordinates span [-2, 2] on both axes; the top row of the mosaic
    corresponds to the largest z[1]. ``encoder`` is accepted but not used.
    """
    cells = 30
    side = 28
    extent = 2.0
    inches = 15
    xs = np.linspace(-extent, extent, cells)
    ys = np.linspace(-extent, extent, cells)[::-1]
    canvas = np.zeros((side * cells, side * cells))

    for row, y_val in enumerate(ys):
        top = row * side
        for col, x_val in enumerate(xs):
            left = col * side
            decoded = decoder.predict(np.array([[x_val, y_val]]))
            canvas[top : top + side, left : left + side] = decoded[0].reshape(side, side)

    plt.figure(figsize=(inches, inches))
    # One tick at the centre of every tile (exactly `cells` positions).
    centres = np.arange(side // 2, cells * side + side // 2, side)
    plt.xticks(centres, np.round(xs, 1))
    plt.yticks(centres, np.round(ys, 1))
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(canvas, cmap="Greys_r")
    plt.show()


plot_latent(encoder, decoder)

"""
## Display how the latent space clusters different digit classes
"""


def plot_label_clusters(encoder, decoder, data, labels):
    """Scatter-plot the encoder's 2-D latent means for ``data``, coloured by ``labels``.

    ``decoder`` is accepted but not used.
    """
    means, _log_var, _z = encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(means[:, 0], means[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


(x_train, y_train), _ = keras.datasets.mnist.load_data()
# Add a channel axis and scale pixels to [0, 1], as during training.
x_train = x_train[..., np.newaxis].astype("float32") / 255

plot_label_clusters(encoder, decoder, x_train, y_train)

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 14, 14, 32)   320         input_4[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 7, 7, 64)     18496       conv2d[0][0]                     
__________________________________________________________________________________________________
flatten (Flatten)               (None, 3136)         0           conv2d_1[0][0]                   
____________________________________________________________________________________________

ValueError: The model cannot be compiled because it has no loss to optimize.

In [1]:
# MNIST dataset
from keras.datasets import mnist
import tensorflow as tf

# numpy and pyplot
import numpy as np
import matplotlib.pyplot as plt

# keras
import keras
from keras.layers import Input, Dense, Flatten, Reshape, Conv2D, MaxPooling2D, UpSampling2D, Dropout, BatchNormalization
from keras.layers import Multiply, Add, GaussianNoise, Lambda
from keras.models import Model
from keras.losses import binary_crossentropy
import keras.backend as K

Using TensorFlow backend.


We prepare the data by normalizing it.

Since we are doing unsupervised learning here, we will not need the labels provided by the dataset for now. We keep them, however, as they will help with visualizing the results later.

There are 60k training and 10k test examples.

In [2]:
# Choose the dataset to train on: "fashion_mnist" (Zalando clothing images) or
# "mnist" (handwritten digits). Both provide 28x28 grayscale images, so the
# rest of the notebook works with either; only the selected one is loaded
# (previously MNIST was loaded and then immediately overwritten).
fashion_mnist = keras.datasets.fashion_mnist
dataset_name = "fashion_mnist"
(x_train, y_train), (x_test, y_test) = {
    "mnist": mnist,
    "fashion_mnist": fashion_mnist,
}[dataset_name].load_data()

# Scale pixel values from [0, 255] to [0, 1] so they can serve as
# binary-cross-entropy targets for the decoder's sigmoid output.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
print(x_train.shape)
print(x_test.shape)

(60000, 28, 28)
(10000, 28, 28)


In principle, an autoencoder consists of two models: the encoder and the decoder. To represent this, we use Keras' functional API, which lets us easily define models from component models.

We start by defining the encoder, whose output will not be the latent vector itself, but the mean and the log of the standard deviation of the latent distribution.

The next submodel is one that samples a latent vector from the Gaussian distribution whose parameters the encoder produces.

Then we define the decoder, which takes the sampled latent vector as input and produces full-size images again.

Finally, we concatenate everything to get our variational autoencoder.

In [3]:
def makeVAE(encodingDim=32):
    """Build a convolutional variational autoencoder for images shaped like ``x_train``.

    Args:
        encodingDim: dimensionality of the latent space.

    Returns:
        ``(encoder, decoder, autoencoder, sample, loss)`` where

        * ``encoder`` maps an image batch to ``[mean, logstdev]`` of the latent
          Gaussian (``logstdev`` is the log of the standard deviation);
        * ``sample`` maps ``[mean, logstdev]`` to ``mean + exp(logstdev) * eps``
          with ``eps ~ N(0, I)``;
        * ``decoder`` maps a latent vector back to an image;
        * ``autoencoder`` chains encoder -> sample -> decoder;
        * ``loss`` is the per-sample negative ELBO to pass to ``compile``.

    The input shape is read from the global ``x_train`` (channel-less images,
    e.g. ``(28, 28)``).
    """
    # this is our input placeholder
    inputImg = Input(shape=x_train.shape[1:])
    # add a trailing channel dimension for the Conv2D layers
    x = Reshape((*inputImg.shape.as_list()[1:],1))(inputImg)
    # encoder: two 2x2 max-poolings reduce each spatial side by a factor of 4
    x = Conv2D(16, kernel_size=(5,5), activation='relu', padding="same")(x)
    x = Conv2D(32, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = MaxPooling2D(pool_size=(2,2))(x)
    x = BatchNormalization()(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(32, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = MaxPooling2D(pool_size=(2,2))(x)
    x = BatchNormalization()(x)
    x = Conv2D(128, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(128, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = BatchNormalization()(x)
    x = Conv2D(32, kernel_size=(1,1), activation='relu', padding="same")(x)
    # remembered so the decoder can reshape its dense output back to this map
    lastConvShape = x.shape.as_list()[1:]

    x = Flatten()(x)
    x = Dense(encodingDim*4, activation='relu')(x)
    x = Dense(encodingDim*2, activation='relu')(x)
    x = Dense(encodingDim, activation='relu')(x)

    # parameters of the latent Gaussian: mean and log standard deviation
    mean = Dense(encodingDim)(x)
    logstdev = Dense(encodingDim)(x)

    encoder = Model(inputImg, [mean, logstdev], name="encoder")
    encoder.summary()

    def sampling(args):
        """Reparameterisation trick: return mean + exp(logstdev) * eps, eps ~ N(0, I)."""
        mean, logstdev = args
        # Take the noise shape symbolically from `mean` instead of building a
        # tuple that mixes a tensor (batch size) with a Python int; this keeps
        # the shape in sync with the inputs. NOTE(review): this form is also
        # reported to avoid the "Duplicate node name ... random_normal/shape"
        # graph error with standalone Keras on TF2 -- confirm in the target env.
        eps = K.random_normal(shape=K.shape(mean))
        return mean + K.exp(logstdev) * eps

    meanS = Input(shape=(encodingDim,))
    logstdevS = Input(shape=(encodingDim,))
    x = Lambda(sampling)([meanS, logstdevS])
    sample = Model([meanS, logstdevS], x, name="sample")
    sample.summary()

    # this is our latent space placeholder
    inputLat = Input(shape=(encodingDim,))
    #decoder
    x = Dense(encodingDim*4, activation='relu')(inputLat)
    x = Dense(np.prod(lastConvShape), activation='relu')(x)

    x = Reshape(lastConvShape)(x) # back to the last encoder feature-map shape
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = BatchNormalization()(x)
    x = Conv2D(64, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(128, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(128, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = UpSampling2D(size=(2,2))(x)
    x = BatchNormalization()(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(32, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = UpSampling2D(size=(2,2))(x)
    x = BatchNormalization()(x)
    x = Conv2D(32, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = BatchNormalization()(x)
    # sigmoid output in (0, 1), matching the [0, 1]-scaled input pixels
    x = Conv2D(1, kernel_size=(5,5), activation='sigmoid', padding="same")(x)
    x = Reshape(x.shape.as_list()[1:3])(x) # remove channel dimension

    decoder = Model(inputLat, x, name="decoder")
    decoder.summary()

    # this model maps an input to its reconstruction
    autoencoder = Model(inputImg, decoder(sample(encoder(inputImg))), name="vae")
    autoencoder.summary()

    def loss(x, output):
        """Per-sample negative ELBO: pixel-summed BCE + KL(q(z|x) || N(0, I)).

        Both terms have shape (batch,), so Keras' averaging over the batch
        weights them consistently. (Previously the reconstruction term was
        summed over the whole batch and broadcast onto every sample, which
        inflated it by roughly the batch size relative to the KL term.)
        """
        # Element-wise BCE, summed over all pixels of each image -> (batch,).
        recon_loss = K.sum(K.batch_flatten(K.binary_crossentropy(x, output)), axis=-1)
        # This is quite dirty: it uses the encoder's symbolic `mean` and
        # `logstdev` tensors from the enclosing scope. It would be better to
        # expose them as extra network outputs, but Keras does not pass
        # multiple outputs into a single loss function. This only works in
        # graph mode, where those tensors belong to the autoencoder's graph.
        # KL divergence of N(mean, exp(logstdev)^2) from N(0, 1), summed over
        # the latent dimensions -> (batch,).
        kl_loss = - 0.5 * K.sum(1. + 2.*logstdev - K.square(mean) - K.exp(2.*logstdev), axis=-1)
        return recon_loss + kl_loss

    return encoder, decoder, autoencoder, sample, loss

In [4]:
encoder, decoder, autoencoder, sample, loss = makeVAE(16)
# `learning_rate` is Adam's current argument name; `lr` is a deprecated alias
# that recent Keras releases no longer accept.
opt = keras.optimizers.Adam(learning_rate=0.001)
autoencoder.compile(optimizer=opt, loss=loss)

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28)       0                                            
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 28, 28, 1)    0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 16)   416         reshape_1[0][0]                  
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 28, 28, 32)   4640        conv2d_1[0][0]                   
____________________________________________________________________________________________

ValueError: Duplicate node name in graph: 'lambda_1/random_normal/shape'

Here we passed our custom loss function when compiling the model. Next, we will fit.

In [None]:
autoencoder.fit(
    x_train,
    x_train,  # an autoencoder's target is its own input
    validation_data=(x_test, x_test),
    shuffle=True,
    batch_size=256,
    epochs=40,
)

Here we define a function to compare original and reconstructed images, which we will use later.

In [None]:
def showImages(ae, data):
    """Plot each image in ``data`` (top row) above its reconstruction by ``ae`` (bottom row).

    One column is drawn per image, so ``data.shape[0]`` controls the width.
    """
    reconstructions = ae.predict(data)

    count = data.shape[0]  # number of images shown in each row
    width = 20
    plt.figure(figsize=(width, width / count * 2))
    for col in range(count):
        # row 0: original, row 1: reconstruction (same column)
        for row, images in enumerate((data, reconstructions)):
            ax = plt.subplot(2, count, row * count + col + 1)
            plt.imshow(images[col])
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
showImages(autoencoder, x_test[:20])  # first 20 test images vs. their reconstructions

We can also generate new clothing images: the latent vectors should approximately follow a unit Gaussian, so we sample latent vectors from that distribution and decode them.

In [None]:
def showImagesGen(decoder, sample=sample, n=20):
    """Generate and plot ``n`` new images from latent vectors drawn from N(0, I).

    ``sample`` is the sampling model from ``makeVAE``: it takes
    ``[mean, logstdev]`` and returns ``mean + exp(logstdev) * eps``. Its second
    input is the *log* standard deviation, so a unit Gaussian corresponds to
    mean 0 and logstdev 0. Passing 1 there (as before) sampled with standard
    deviation e ~ 2.72 instead of 1.
    """
    latentDim = int(sample.inputs[0].shape[-1])
    mean = np.zeros((n, latentDim))
    logstdev = np.zeros((n, latentDim))  # log(1) = 0 -> unit standard deviation
    decoded_imgs = decoder.predict(sample.predict([mean, logstdev]))

    height = 20
    plt.figure(figsize=(height, height/n))
    for i in range(n):
        # display generated image
        ax = plt.subplot(1, n, i + 1)
        plt.imshow(decoded_imgs[i])
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
showImagesGen(decoder, n=20)  # 20 images decoded from unit-Gaussian latent samples

In [None]:
import matplotlib.pyplot as plt


def plot_latent(encoder, decoder, n=30, digit_size=28, scale=2.0, figsize=15):
    """Decode an ``n`` x ``n`` grid of latent points and show the images as one mosaic.

    The grid spans ``[-scale, scale]`` along the first two latent dimensions
    (the top row of the mosaic is the largest z[1]). Any further latent
    dimensions are held at 0, the prior mean, so this also works for decoders
    with more than two latent dimensions (e.g. ``makeVAE(16)``).

    Args:
        encoder: accepted for API compatibility; not used.
        decoder: model mapping latent vectors to images of
            ``digit_size`` x ``digit_size`` pixels (with or without a
            trailing channel axis of size 1).
        n: number of grid points per axis.
        digit_size: side length in pixels of each decoded image.
        scale: half-width of the latent range that is covered.
        figsize: matplotlib figure size (inches, square).
    """
    latent_dim = int(decoder.inputs[0].shape[-1])
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # All n*n latent vectors at once: row i <-> grid_y[i], column j <-> grid_x[j];
    # one batched predict() instead of n*n single-sample calls.
    z_grid = np.zeros((n * n, latent_dim))
    z_grid[:, 0] = np.tile(grid_x, n)
    z_grid[:, 1] = np.repeat(grid_y, n)
    decoded = decoder.predict(z_grid).reshape(n, n, digit_size, digit_size)
    # (row, col, h, w) -> (row*h, col*w) mosaic.
    figure = decoded.transpose(0, 2, 1, 3).reshape(digit_size * n, digit_size * n)

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    # Exclusive end: yields exactly n tick positions, one per tile centre and
    # one per label (an extra "+ 1" here produced n + 1 ticks for n labels).
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent(encoder, decoder)

### PLEASE RUN THIS COMMAND IF YOU FINISHED THE NOTEBOOK

In [None]:
import os
# Hard-kill this notebook's own kernel process: `!` runs a shell command from
# IPython, `$temp` expands to the Python variable, and `-9` sends SIGKILL.
# Presumably meant to release the memory / GPU held by the trained models
# once the notebook is finished.
temp=os.getpid()
!kill -9 $temp