<a href="https://colab.research.google.com/github/tmohammad78/deep-learning-projects/blob/feature%2Fvariational-autoencoder-mnist/variational-autoencoder/mode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [211]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [163]:
def encoder_layers(inputs,latent_dim):
  x = tf.keras.layers.Conv2D(filters=32,padding="same",name="encoder_conv1" , activation="relu" , kernel_size=3 , strides=2)(inputs)
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.Conv2D(filters=64,padding="same",name="encoder_conv2" , activation="relu" , kernel_size=3 , strides=2)(x)
  batch_2 = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.Flatten(name="encoder_flatten")(batch_2)
  x = tf.keras.layers.Dense(20,activation="relu",name="encode_dense")(x)
  x = tf.keras.layers.BatchNormalization()(x)
  
  mu = tf.keras.layers.Dense(latent_dim,name="latent_mu")(x)
  sigma = tf.keras.layers.Dense(latent_dim,name="latent_sigma")(x)
  return mu , sigma , batch_2.shape

In [164]:
class Sampling(tf.keras.layers.Layer):
  def call(self,inputs):
    mu , sigma = inputs
    batch = tf.shape(mu)[0]
    dim = tf.shape(mu)[1]
    epsilon = tf.keras.backend.random_normal(shape=(batch,dim))
    return mu + tf.exp(0.5 * sigma) * epsilon
    

In [242]:
def encoder_model(LATENT_DIM,input_shape):
  inputs = tf.keras.layers.Input(shape=input_shape)
  print(inputs.shape)
  mu , sigma , conv_shape = encoder_layers(inputs , latent_dim=LATENT_DIM)
  z = Sampling()((mu,sigma))
  model = tf.keras.Model(inputs,outputs=[mu,sigma,z])
  return model,conv_shape

In [166]:
def decoder_layers(inputs,conv_shape):
  units = conv_shape[1] * conv_shape[2] * conv_shape[3]
  x = tf.keras.layers.Dense(units,activation="relu",name="decoder_dens1")(inputs)
  x = tf.keras.layers.BatchNormalization()(x)

  x = tf.keras.layers.Reshape((conv_shape[1],conv_shape[2],conv_shape[3]),name="decoder_shape")(x)
  x = tf.keras.layers.Conv2DTranspose(filters=64,kernel_size=3,strides=2,padding="same",activation="relu" , name="decoder_conv2d_2")(x)
  x = tf.keras.layers.BatchNormalization()(x)

  x = tf.keras.layers.Conv2DTranspose(filters=32,kernel_size=3,strides=2,padding="same",activation="relu" , name="decoder_conv2d3")(x)
  x = tf.keras.layers.BatchNormalization()(x)

  x = tf.keras.layers.Conv2DTranspose(filters=1,kernel_size=3,strides=1,padding="same",activation="sigmoid" , name="decoder_final")(x)
  return x

In [167]:
def decoder_model(latent_dim,conv_shape):
   inputs = tf.keras.layers.Input(shape=(latent_dim,))
   outputs = decoder_layers(inputs,conv_shape)
   model = tf.keras.Model(inputs,outputs)
   return model

In [168]:
def k1_reconstrcution_loss(output,mu , sigma):
  kl_loss = 1 + sigma - tf.square(mu) - tf.math.exp(sigma)
  return tf.reduce_mean(kl_loss) * -0.5


In [179]:
def vae_model(encoder, decoder , input_shape):
  inputs = tf.keras.layers.Input(shape=(input_shape))
  mu = encoder(inputs)[0]
  sigma = encoder(inputs)[1]
  z = encoder(inputs)[2]
  reconstructed = decoder(z)
  model = tf.keras.Model(inputs=inputs,outputs = reconstructed)
  loss = k1_reconstrcution_loss(z , mu , sigma)
  model.add_loss(loss)
  return model

In [232]:
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
latent_dim = 32
x_train.shape

(60000, 784)

In [246]:

encoder_out , conv_shape = encoder_model(LATENT_DIM = 32,input_shape = (28,28,1))
decoder_out = decoder_model(32,conv_shape)
vae = vae_model(encoder_out, decoder_out, (28,28,1))
epochs = 2
for epoch in range(epochs):
  for step , x_batch_train in enumerate(train_dataset):
    with tf.GradientTape() as tape:
      reconstrcuted = vae(x_batch_train)
      flattened_input  = tf.reshape(x_batch_train,shape=[-1])
      flattened_output  = tf.reshape(reconstrcuted,shape=[-1])
      loss = bce_loss(flattened_input,flattened_output) * 764
      loss += sum(vae.losses)
    grads = tape.gradient(loss,vae.trainable_weights)
    optimizer.apply_gradients(zip(grads,vae.trainable_weights))
    loss_metric(loss)

(None, 28, 28, 1)


ValueError: ignored