<a href="https://colab.research.google.com/github/tmohammad78/deep-learning-projects/blob/main/variational-autoencoder/mnist-vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [17]:
class Sampling(layers.Layer):
  """Draw a latent sample z ~ N(z_mean, exp(z_log_v)) via the reparameterization trick.

  Called with a pair ``[z_mean, z_log_v]`` of same-shaped tensors, where
  ``z_log_v`` is the log-variance (so the standard deviation is
  ``exp(0.5 * z_log_v)``). Returns ``z_mean + exp(0.5 * z_log_v) * epsilon``,
  with ``epsilon`` standard-normal noise of the same shape and dtype as
  ``z_mean``. Writing the sample this way keeps it differentiable with
  respect to ``z_mean`` and ``z_log_v``.
  """

  def call(self, input):
    # The parameter keeps its original name `input` (it shadows the builtin)
    # so that any keyword caller `call(input=...)` keeps working.
    z_mean, z_log_v = input
    # Draw noise with z_mean's own shape and dtype instead of a hard-coded
    # (batch, dim) shape and the backend's default float type. This works
    # for any latent rank, and it avoids a float16/float32 mismatch under
    # mixed-precision policies.
    epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
    # exp(0.5 * log_var) == sqrt(var) == standard deviation.
    sigma = tf.exp(0.5 * z_log_v)
    return z_mean + sigma * epsilon


In [18]:
# Encoder: two stride-2 convolutions (28x28 -> 14x14 -> 7x7), a small dense
# bottleneck, then two heads producing the latent mean and log-variance,
# from which `Sampling` draws the latent code z.
latent_dim = 2

encoder_inputs = keras.Input((28, 28, 1))
x = encoder_inputs
for n_filters in (32, 64):
  x = layers.Conv2D(
      n_filters, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)

# Two parallel heads over the same 16-unit features.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_v = layers.Dense(latent_dim, name="z_log_v")(x)

z = Sampling()([z_mean, z_log_v])

encoder = keras.Model(
    encoder_inputs, [z_mean, z_log_v, z], name="encoder")
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 14, 14, 32)   320         input_7[0][0]                    
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 7, 7, 64)     18496       conv2d_8[0][0]                   
__________________________________________________________________________________________________
flatten_4 (Flatten)             (None, 3136)         0           conv2d_9[0][0]                   
____________________________________________________________________________________________

In [19]:
# Decoder: project the latent vector to a 7x7x64 feature map, upsample it
# twice with stride-2 transposed convolutions (7x7 -> 14x14 -> 28x28), then
# collapse to a single channel with a sigmoid so every pixel is in [0, 1].
decoder_input = keras.Input(shape=(latent_dim,))

x = layers.Dense(7 * 7 * 64, activation="relu")(decoder_input)
x = layers.Reshape((7, 7, 64))(x)
for n_filters in (64, 32):
  x = layers.Conv2DTranspose(
      n_filters, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(
    1, 3, activation="sigmoid", padding="same")(x)

decoder = keras.Model(decoder_input, decoder_outputs, name="decoder")
decoder.summary()

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_8 (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
dense_6 (Dense)              (None, 3136)              9408      
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 28, 28, 32)        18464     
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 28, 28, 1)         289       
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_______________________________________________________