# Variational Autoencoder with MNIST and FashionMNIST

We will use Mnist and the Zalando FashionMNIST again, to train a variational autoencoder with will also be able to generate new cloth.

As a source for this notebook, see [https://blog.keras.io/building-autoencoders-in-keras.html]. A another accurate example can be found here: [https://towardsdatascience.com/teaching-a-variational-autoencoder-vae-to-draw-mnist-characters-978675c95776] .

To begin, we need to load some python modules including common layers from keras.

In [None]:
'''
  Variational Autoencoder (VAE) with the Keras Functional API.
'''

import keras
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape
from keras.layers import BatchNormalization
from keras.models import Model
from keras.datasets import mnist
from keras.losses import binary_crossentropy
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt

# Load MNIST dataset
(input_train, target_train), (input_test, target_test) = mnist.load_data()

# Data & model configuration
img_width, img_height = input_train.shape[1], input_train.shape[2]
batch_size = 128
no_epochs = 100
validation_split = 0.2
verbosity = 1
latent_dim = 2
num_channels = 1

# Reshape data
input_train = input_train.reshape(input_train.shape[0], img_height, img_width, num_channels)
input_test = input_test.reshape(input_test.shape[0], img_height, img_width, num_channels)
input_shape = (img_height, img_width, num_channels)

# Parse numbers as floats
input_train = input_train.astype('float32')
input_test = input_test.astype('float32')

# Normalize data
input_train = input_train / 255
input_test = input_test / 255

# # =================
# # Encoder
# # =================

# Definition
i       = Input(shape=input_shape, name='encoder_input')
cx      = Conv2D(filters=8, kernel_size=3, strides=2, padding='same', activation='relu')(i)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
x       = Flatten()(cx)
x       = Dense(20, activation='relu')(x)
x       = BatchNormalization()(x)
mu      = Dense(latent_dim, name='latent_mu')(x)
sigma   = Dense(latent_dim, name='latent_sigma')(x)

# Get Conv2D shape for Conv2DTranspose operation in decoder
conv_shape = K.int_shape(cx)

# Define sampling with reparameterization trick
def sample_z(args):
  mu, sigma = args
  batch     = K.shape(mu)[0]
  dim       = K.int_shape(mu)[1]
  eps       = K.random_normal(shape=(batch, dim))
  return mu + K.exp(sigma / 2) * eps

# Use reparameterization trick to ....??
z       = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([mu, sigma])

# Instantiate encoder
encoder = Model(i, [mu, sigma, z], name='encoder')
encoder.summary()

# =================
# Decoder
# =================

# Definition
d_i   = Input(shape=(latent_dim, ), name='decoder_input')
x     = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation='relu')(d_i)
x     = BatchNormalization()(x)
x     = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
cx    = Conv2DTranspose(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(x)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=8, kernel_size=3, strides=2, padding='same',  activation='relu')(cx)
cx    = BatchNormalization()(cx)
o     = Conv2DTranspose(filters=num_channels, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(cx)

# Instantiate decoder
decoder = Model(d_i, o, name='decoder')
decoder.summary()

# =================
# VAE as a whole
# =================

# Instantiate VAE
vae_outputs = decoder(encoder(i)[2])
vae         = Model(i, vae_outputs, name='vae')
vae.summary()

# Define loss
def kl_reconstruction_loss(true, pred):
  # Reconstruction loss
  reconstruction_loss = binary_crossentropy(K.flatten(true), K.flatten(pred)) * img_width * img_height
  # KL divergence loss
  kl_loss = 1 + sigma - K.square(mu) - K.exp(sigma)
  kl_loss = K.sum(kl_loss, axis=-1)
  kl_loss *= -0.5
  # Total loss = 50% rec + 50% KL divergence loss
  return K.mean(reconstruction_loss + kl_loss)

# Compile VAE
vae.compile(optimizer='adam', loss=kl_reconstruction_loss)

# Train autoencoder
vae.fit(input_train, input_train, epochs = no_epochs, batch_size = batch_size, validation_split = validation_split)

# =================
# Results visualization
# Credits for original visualization code: https://keras.io/examples/variational_autoencoder_deconv/
# (François Chollet).
# Adapted to accomodate this VAE.
# =================
def viz_latent_space(encoder, data):
  input_data, target_data = data
  mu, _, _ = encoder.predict(input_data)
  plt.figure(figsize=(8, 10))
  plt.scatter(mu[:, 0], mu[:, 1], c=target_data)
  plt.xlabel('z - dim 1')
  plt.ylabel('z - dim 2')
  plt.colorbar()
  plt.show()

def viz_decoded(encoder, decoder, data):
  num_samples = 15
  figure = np.zeros((img_width * num_samples, img_height * num_samples, num_channels))
  grid_x = np.linspace(-4, 4, num_samples)
  grid_y = np.linspace(-4, 4, num_samples)[::-1]
  for i, yi in enumerate(grid_y):
      for j, xi in enumerate(grid_x):
          z_sample = np.array([[xi, yi]])
          x_decoded = decoder.predict(z_sample)
          digit = x_decoded[0].reshape(img_width, img_height, num_channels)
          figure[i * img_width: (i + 1) * img_width,
                  j * img_height: (j + 1) * img_height] = digit
  plt.figure(figsize=(10, 10))
  start_range = img_width // 2
  end_range = num_samples * img_width + start_range + 1
  pixel_range = np.arange(start_range, end_range, img_width)
  sample_range_x = np.round(grid_x, 1)
  sample_range_y = np.round(grid_y, 1)
  plt.xticks(pixel_range, sample_range_x)
  plt.yticks(pixel_range, sample_range_y)
  plt.xlabel('z - dim 1')
  plt.ylabel('z - dim 2')
  # matplotlib.pyplot.imshow() needs a 2D array, or a 3D array with the third dimension being of shape 3 or 4!
  # So reshape if necessary
  fig_shape = np.shape(figure)
  if fig_shape[2] == 1:
    figure = figure.reshape((fig_shape[0], fig_shape[1]))
  # Show image
  plt.imshow(figure)
  plt.show()

# Plot results
data = (input_test, target_test)
viz_latent_space(encoder, data)
viz_decoded(encoder, decoder, data)

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 14, 14, 8)    80          encoder_input[0][0]              
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 14, 14, 8)    32          conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 7, 7, 16)     1168        batch_normalization_1[0][0]      
____________________________________________________________________________________________

In [2]:
# MNIST dataset
from keras.datasets import mnist
import tensorflow as tf

# numpy and pyplot
import numpy as np
import matplotlib.pyplot as plt

# keras
import keras
from keras.layers import Input, Dense, Flatten, Reshape, Conv2D, MaxPooling2D, UpSampling2D, Dropout, BatchNormalization
from keras.layers import Multiply, Add, GaussianNoise, Lambda
from keras.models import Model
from keras.losses import binary_crossentropy
import keras.backend as K

Using TensorFlow backend.


We prepare the data by normalizing it.

Sincle we are doing unsupervised learning here, we will not need the labels provided by the dataset for now. We keep them however, as they will help with visualizing the results later.

There are 60k training and 10k test examples.

In [3]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

fashion_mnist = keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()


x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
print(x_train.shape)
print(x_test.shape)

(60000, 28, 28)
(10000, 28, 28)


In priciple, an autoencoder consists of two models: the encoder and the decoder. To represent this, we are using keras' functional API where we can easily define models from component models.

We start by defining the encoder, whose output will no be the latent vector, but the mean and log of the standard deviation of the latent representations.

The next submodel is on that sample from the distribution generated by the encoder.

Then we define the decoder, which takes the sampled latent vector as input and produces full-size images again.

Finally, we concatenate everything to get our variational autoencoder.

In [4]:
def makeVAE(encodingDim=32):
    # this is our input placeholder
    inputImg = Input(shape=x_train.shape[1:])
    x = Reshape((*inputImg.shape.as_list()[1:],1))(inputImg)
    # encoder
    x = Conv2D(16, kernel_size=(5,5), activation='relu', padding="same")(x)
    x = Conv2D(32, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = MaxPooling2D(pool_size=(2,2))(x)
    x = BatchNormalization()(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(32, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = MaxPooling2D(pool_size=(2,2))(x)
    x = BatchNormalization()(x)
    x = Conv2D(128, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(128, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = BatchNormalization()(x)
    x = Conv2D(32, kernel_size=(1,1), activation='relu', padding="same")(x)
    lastConvShape = x.shape.as_list()[1:]
    
    x = Flatten()(x)    
    x = Dense(encodingDim*4, activation='relu')(x)
    x = Dense(encodingDim*2, activation='relu')(x)
    x = Dense(encodingDim, activation='relu')(x)
    
    mean = Dense(encodingDim)(x)
    logstdev = Dense(encodingDim)(x)
    
    encoder = Model(inputImg, [mean, logstdev], name="encoder")
    encoder.summary()

    def sampling(args):
        mean, logstdev = args
        eps = K.random_normal(shape=(K.shape(logstdev)[0], encodingDim))
        return mean + K.exp(logstdev) * eps

    meanS = Input(shape=(encodingDim,))
    logstdevS = Input(shape=(encodingDim,))
    x = Lambda(sampling)([meanS, logstdevS])
    sample = Model([meanS, logstdevS], x, name="sample")
    sample.summary()
    
    # this is our latent space placeholder
    inputLat = Input(shape=(encodingDim,))
    #decoder
    x = Dense(encodingDim*4, activation='relu')(inputLat)
    x = Dense(np.prod(lastConvShape), activation='relu')(x)
    
    x = Reshape(lastConvShape)(x) # remove channel dimension
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = BatchNormalization()(x)
    x = Conv2D(64, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(128, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(128, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = UpSampling2D(size=(2,2))(x)
    x = BatchNormalization()(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = Conv2D(32, kernel_size=(1,1), activation='relu', padding="same")(x)
    x = Conv2D(64, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = UpSampling2D(size=(2,2))(x)
    x = BatchNormalization()(x)
    x = Conv2D(32, kernel_size=(3,3), activation='relu', padding="same")(x)
    x = BatchNormalization()(x)
    x = Conv2D(1, kernel_size=(5,5), activation='sigmoid', padding="same")(x)
    x = Reshape(x.shape.as_list()[1:3])(x) # remove channel dimension

    
    decoder = Model(inputLat, x, name="decoder")
    decoder.summary()

    # this model maps an input to its reconstruction
    autoencoder = Model(inputImg, decoder(sample(encoder(inputImg))), name="vae")
    autoencoder.summary()
    
    def loss(x, output):
        recon_loss = K.sum(binary_crossentropy(x, output))
        """
        This is quite dirty: using the layer handles from the definition of the encoder to calculate the loss.
        It would be better to have these values as additional outputs of the network,
        but keras does not allow passing multiple outputs into a single loss function.
        """
        kl_loss = - 0.5 * K.mean(1. + 2.*logstdev - K.square(mean) - K.exp(2.*logstdev), axis=-1)
        return recon_loss + kl_loss
        #return kl_loss
    
    return encoder, decoder, autoencoder, sample, loss

In [5]:
encoder, decoder, autoencoder, sample, loss = makeVAE(16)
opt = keras.optimizers.Adam(lr=0.001)
autoencoder.compile(optimizer=opt, loss=loss)

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 28, 28)       0                                            
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 28, 28, 1)    0           input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 16)   416         reshape_1[0][0]                  
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 28, 28, 32)   4640        conv2d_1[0][0]                   
____________________________________________________________________________________________

ValueError: Duplicate node name in graph: 'lambda_1/random_normal/shape'

Here we passed our custom loss function when compiling the model. Next, we will fit.

In [None]:
autoencoder.fit(x_train, x_train,
                epochs=40,
                batch_size=256,
                shuffle=True,
                validation_data=(x_test, x_test))

Here we define a function to compare original and reconstructed images, which we will use later.

In [None]:
def showImages(ae, data):
    decoded_imgs = ae.predict(data)

    n = data.shape[0]  # how many cloth we will display
    height = 20
    plt.figure(figsize=(height, height/n*2))
    for i in range(n):
        # display original
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(data[i])
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display reconstruction
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(decoded_imgs[i])
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
showImages(autoencoder, x_test[0:20])

We can also sample the latent vectors, which should follow a unit gaussian to generate new cloth.

In [None]:
def showImagesGen(decoder, sample=sample, n=20):
    latentDim = sample.inputs[0].shape.as_list()[-1]
    mean = np.array([0.]*latentDim*n).reshape([n,latentDim])
    stdev = np.array([1.]*latentDim*n).reshape([n,latentDim])
    decoded_imgs = decoder.predict(sample.predict([mean, stdev]))

    height = 20
    plt.figure(figsize=(height, height/n))
    for i in range(n):
        # display reconstruction
        ax = plt.subplot(1, n, i + 1)
        plt.imshow(decoded_imgs[i])
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

In [None]:
showImagesGen(decoder)

In [None]:
import matplotlib.pyplot as plt


def plot_latent(encoder, decoder):
    # display a n*n 2D manifold of digits
    n = 30
    digit_size = 28
    scale = 2.0
    figsize = 15
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent(encoder, decoder)

### PLEASE RUN THIS COMMAND IF YOU FINISHED THE NOTEBOOK

In [None]:
import os
temp=os.getpid()
!kill -9 $temp