# Imports

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from IPython import display

# Parameters

In [2]:
# Define global constants
BATCH_SIZE = 32
LATENT_DIM = 2

# Prepare the Dataset

In [3]:
def map_image(image,label):
  """Turn one (image, label) dataset sample into a normalized image tensor.

  Args:
    image: image tensor; cast to float32 and reshaped to (28, 28, 1).
    label: label paired with the image; it is not used and not returned.

  Returns:
    A float32 tensor of shape (28, 28, 1) holding the input divided by 255.
  """
  scaled = tf.cast(image,tf.float32) / 255.0
  return tf.reshape(scaled,shape = (28,28,1,))

def get_dataset(map_fn,is_validation = False):
  """Load MNIST through TFDS, apply `map_fn`, and batch the result.

  Args:
    map_fn: function applied to every (image, label) pair via `Dataset.map`.
    is_validation: when truthy, read the 'test' split and do not shuffle;
      otherwise read the 'train' split and shuffle with a 1024-element buffer.

  Returns:
    The mapped dataset, batched into groups of BATCH_SIZE.
  """
  split_name = 'test' if is_validation else 'train'

  samples = tfds.load('mnist',as_supervised = True,split = split_name)
  samples = samples.map(map_fn)

  # Only the training split is shuffled before batching
  if not is_validation:
    samples = samples.shuffle(1024)

  return samples.batch(BATCH_SIZE)

In [4]:
# Build the training dataset: 'train' split, mapped with map_image,
# shuffled and batched (is_validation defaults to False)
train_dataset = get_dataset(map_image)

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


# Build the Model

<img src="https://drive.google.com/uc?export=view&id=1YAZAeMGEJ1KgieYk1ju-S9DoshpMREeC" width="60%" height="60%"/>

### Sampling Class

In [22]:
class Sampling(tf.keras.layers.Layer):
  """Draws a latent vector z using the reparameterization trick.

  The second input (`sigma`) is interpreted as the log-variance of the latent
  distribution. This matches `kl_reconstruction_loss`, which applies
  `exp(sigma)` to it, and keeps the effective standard deviation positive
  even though the encoder's Dense output is unconstrained.
  """

  def call(self,inputs):
    """Sample z = mu + exp(0.5 * sigma) * epsilon with epsilon ~ N(0, 1).

    Args:
      inputs: a (mu, sigma) pair of rank-2 tensors with matching shapes.

    Returns:
      A tensor with the same shape as `mu`.
    """
    # Unpack the output of the encoder
    mu,sigma = inputs

    # Get the size and dimensions of the batch
    batch = tf.shape(mu)[0]
    dim = tf.shape(mu)[1]

    # Generate a random tensor
    epsilon = tf.keras.backend.random_normal(shape = (batch,dim))

    # exp(0.5 * log-variance) is the standard deviation; scaling the noise by
    # the raw `sigma` would allow a negative "standard deviation"
    return mu + tf.exp(0.5 * sigma) * epsilon

### Encoder

<img src="https://drive.google.com/uc?export=view&id=1eoxFK_UVSHd3a_5EHcCU8F8QDZlPiXfW" width="60%" height="60%"/>

In [6]:
def encoder_layer(inputs,latent_dim):
  """Stack the encoder layers on top of `inputs`.

  Args:
    inputs: input tensor fed to the first convolution.
    latent_dim: number of units in each of the two output Dense layers.

  Returns:
    A (mu, sigma, conv_shape) tuple, where `conv_shape` is the shape of the
    batch-normalized output of the second convolution.
  """
  # Two stride-2 convolutions, each followed by batch normalization
  features = tf.keras.layers.Conv2D(
      32,kernel_size = 3,strides = 2,padding = 'same',
      activation = 'relu',name = 'encode_conv1')(inputs)
  features = tf.keras.layers.BatchNormalization()(features)
  features = tf.keras.layers.Conv2D(
      64,kernel_size = 3,strides = 2,padding = 'same',
      activation = 'relu',name = 'encode_conv2')(features)

  # Kept under its own name so its shape can be returned to the caller
  conv_features = tf.keras.layers.BatchNormalization()(features)

  # Flatten the feature map and pass it through a small Dense layer
  # (20 units is an arbitrary choice that can be changed)
  hidden = tf.keras.layers.Flatten(name = 'encode_flatten')(conv_features)
  hidden = tf.keras.layers.Dense(20,activation = 'relu',name = 'encode_dense')(hidden)
  hidden = tf.keras.layers.BatchNormalization()(hidden)

  # Two parallel Dense heads, each with `latent_dim` units
  mu = tf.keras.layers.Dense(latent_dim,name = 'latent_mu')(hidden)
  sigma = tf.keras.layers.Dense(latent_dim,name = 'latent_sigma')(hidden)

  return mu,sigma,conv_features.shape

In [19]:
def encoder_model(latent_dim,input_shape):
  """Build the encoder as a Keras model.

  Args:
    latent_dim: number of units in the `mu` and `sigma` outputs.
    input_shape: shape of a single input sample (without the batch axis).

  Returns:
    A (model, conv_shape) tuple. `model` maps an input to [mu, sigma, z];
    `conv_shape` is the shape value returned by `encoder_layer`.
  """
  # Declare the inputs tensor with the given shape
  inputs = tf.keras.layers.Input(shape = input_shape)

  # Use the caller-supplied latent_dim; the global LATENT_DIM was previously
  # passed here, which silently ignored this function's argument
  mu,sigma,conv_shape = encoder_layer(inputs,latent_dim = latent_dim)

  # Feed mu and sigma to the sampling layer
  z = Sampling()((mu,sigma))

  # Build the whole encoder model
  model = tf.keras.Model(inputs,outputs = [mu,sigma,z])

  return model,conv_shape

### Decoder

In [8]:
def decoder_layer(inputs,conv_shape):
  """Stack the decoder layers on top of `inputs`.

  Args:
    inputs: latent tensor fed to the first Dense layer.
    conv_shape: shape whose entries 1, 2 and 3 give the height, width and
      channel count used for the Dense layer size and the Reshape target.

  Returns:
    A single-channel tensor with sigmoid activation whose height and width
    are four times `conv_shape[1]` and `conv_shape[2]`.
  """
  units = conv_shape[1] * conv_shape[2] * conv_shape[3]
  x = tf.keras.layers.Dense(units,activation = 'relu',name = 'decode_dense1')(inputs)
  x = tf.keras.layers.BatchNormalization()(x)

  # Reshape the output using the conv_shape dimensions
  x = tf.keras.layers.Reshape((conv_shape[1],conv_shape[2],conv_shape[3]),name = 'decode_reshape')(x)

  # Upsample the features back to the original dimensions: two stride-2
  # transposed convolutions undo the encoder's two stride-2 convolutions
  x = tf.keras.layers.Conv2DTranspose(64,kernel_size = 3,strides = 2,padding = 'same',activation = 'relu',name = 'decode_conv2d_2')(x)
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.Conv2DTranspose(32,kernel_size = 3,strides = 2,padding = 'same',activation = 'relu',name = 'decode_conv2d_3')(x)
  x = tf.keras.layers.BatchNormalization()(x)

  # The final layer only maps to one channel; it must use stride 1, because a
  # third stride-2 upsampling would double the output size again
  x = tf.keras.layers.Conv2DTranspose(1,kernel_size = 3,strides = 1,padding = 'same', activation = 'sigmoid',name = 'decode_final')(x)

  return x

In [24]:
def decoder_model(latend_dim,conv_shape):
  """Wrap the decoder layers in a standalone Keras model.

  Args:
    latend_dim: length of the latent vector the model accepts (the
      parameter name is misspelled but is part of the interface).
    conv_shape: shape value forwarded unchanged to `decoder_layer`.

  Returns:
    A `tf.keras.Model` mapping a latent vector to the `decoder_layer` output.
  """
  # The model input is a vector in the latent space
  latent_inputs = tf.keras.layers.Input(shape = (latend_dim,))

  # Run the latent vector through the decoder layers
  decoded = decoder_layer(latent_inputs,conv_shape)

  return tf.keras.Model(latent_inputs,decoded)

# Kullback-Leibler Divergence

In [11]:
def kl_reconstruction_loss(mu,sigma):
  """Compute the KL-divergence loss term from the encoder outputs.

  Evaluates -0.5 * mean(1 + sigma - mu^2 - exp(sigma)) over all elements.
  Because `sigma` enters through exp(sigma), it acts as a log-variance here.

  Args:
    mu: tensor of latent means.
    sigma: tensor with the same shape as `mu`.

  Returns:
    A scalar tensor.
  """
  per_element = 1 + sigma - tf.square(mu) - tf.math.exp(sigma)
  return tf.reduce_mean(per_element) * -0.5

# VAE Model

In [14]:
def vae_model(encoder,decoder,input_shape):
  """Chain an encoder and a decoder into one model with a KL loss attached.

  Args:
    encoder: callable returning (mu, sigma, z) for an input tensor.
    decoder: callable mapping z to a reconstruction.
    input_shape: shape of a single input sample (without the batch axis).

  Returns:
    A `tf.keras.Model` from the input to the reconstruction, with the
    result of `kl_reconstruction_loss(mu, sigma)` registered via `add_loss`.
  """
  vae_inputs = tf.keras.layers.Input(shape = input_shape)

  # Encode, then decode the sampled latent vector
  mu,sigma,z = encoder(vae_inputs)
  decoded = decoder(z)

  model = tf.keras.Model(vae_inputs,outputs = decoded)

  # Register the KL term so it is available through the model's losses
  model.add_loss(kl_reconstruction_loss(mu,sigma))

  return model

In [20]:
def get_models(input_shape,latent_dim):
  """Build the encoder, the decoder, and the VAE that combines them.

  Args:
    input_shape: shape of a single input sample (without the batch axis).
    latent_dim: size of the latent space shared by encoder and decoder.

  Returns:
    An (encoder, decoder, vae) tuple of Keras models.
  """
  # The encoder also reports the feature-map shape the decoder starts from
  encoder,feature_shape = encoder_model(latent_dim,input_shape)
  decoder = decoder_model(latent_dim,feature_shape)
  vae = vae_model(encoder,decoder,input_shape)

  return encoder,decoder,vae

In [25]:
# Build the encoder, the decoder and the combined 'master' model (VAE)
# for (28, 28, 1) inputs with a LATENT_DIM-sized latent space
encoder,decoder,vae = get_models(input_shape = (28,28,1,),latent_dim = LATENT_DIM)

# Train the Model

In [26]:
# Define the optimizer, the running-mean loss metric and the loss function
optimizer = tf.keras.optimizers.Adam()
# This cell originally bound the optimizer to the misspelled name `ootimizer`;
# keep that name pointing at the same object so code using either spelling works
ootimizer = optimizer
loss_metric = tf.keras.metrics.Mean()
bce_loss = tf.keras.losses.BinaryCrossentropy()

In [None]:
def generate_and_save_images(model,epoch,step,test_input):
  """Plot the model's predictions on a 4x4 grid, save the figure, and show it.

  Args:
    model: object whose `predict(test_input)` returns an array indexed as
      [sample, row, column, channel]; only channel 0 of each sample is drawn.
    epoch: epoch number used in the figure title and the file name.
    step: step number used in the figure title and the file name.
    test_input: input passed to `model.predict`. The grid has 16 cells, so
      this is expected to produce at most 16 images.
  """
  # Generate images from the test input
  images = model.predict(test_input)

  # Draw each prediction in its own cell of the grid, without axes
  canvas = plt.figure(figsize = (4,4))

  for cell,image in enumerate(images,start = 1):
    plt.subplot(4,4,cell)
    plt.imshow(image[:,:,0],cmap = 'gray')
    plt.axis('off')

  # Title the figure, write it to a PNG named after the epoch and step,
  # then display it
  title = "epoch: {}, step: {}".format(epoch,step)
  file_name = 'image_at_epoch_{:04d}_step{:04d}.png'.format(epoch,step)
  canvas.suptitle(title)
  plt.savefig(file_name)
  plt.show()