In [1]:
# Line magic that selects the TensorFlow 2.x runtime for this notebook session.
# NOTE(review): presumably Colab-specific — confirm before running in another Jupyter environment.
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [0]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


In [0]:
class Sampling(tf.keras.layers.Layer):
  """Draws a latent sample z from a (z_mean, z_log_var) pair using normal noise."""

  def call(self, inputs):
    """Return z_mean + exp(0.5 * z_log_var) * epsilon.

    Args:
      inputs: a pair (z_mean, z_log_var). The first two dimensions of
        z_mean determine the shape of the noise tensor.

    Returns:
      The sampled tensor z.
    """
    mean, log_var = inputs
    mean_shape = tf.shape(mean)
    # Noise is drawn with the same (batch, dim) shape as the mean.
    noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
    # exp(0.5 * log_var) converts a log-variance into a standard deviation.
    std_dev = tf.exp(0.5 * log_var)
    return mean + std_dev * noise

In [0]:
class Encoder(tf.keras.layers.Layer):
  """Maps an input batch to the triplet (z_mean, z_log_var, z).

  The input goes through one ReLU Dense projection, then two separate
  linear Dense heads produce the latent mean and log-variance, and a
  Sampling layer draws z from them.

  Args:
    latent_dim: number of units in the mean / log-variance heads.
    intermediate_dim: number of units in the hidden ReLU projection.
    name: layer name, forwarded to the base Layer.
    **kwargs: extra keyword arguments forwarded to the base Layer.
  """

  def __init__(self, latent_dim = 32, intermediate_dim=64, name='encoder', **kwargs):
    # Forward `name` to the base class; it used to be accepted and then
    # dropped, so a caller-supplied name had no effect.
    super(Encoder, self).__init__(name=name, **kwargs)
    self.dense_proj = tf.keras.layers.Dense(intermediate_dim, activation='relu')
    self.dense_mean = tf.keras.layers.Dense(latent_dim)
    self.dense_log_var = tf.keras.layers.Dense(latent_dim)
    self.sampling = Sampling()

  def call(self, inputs):
    """Encode `inputs` and return (z_mean, z_log_var, z)."""
    x = self.dense_proj(inputs)
    z_mean = self.dense_mean(x)
    z_log_var = self.dense_log_var(x)
    # z is a stochastic sample; z_mean / z_log_var are returned as well so
    # the caller can build a regularization term from them.
    z = self.sampling((z_mean, z_log_var))
    return z_mean, z_log_var, z

In [0]:
class Decoder(tf.keras.layers.Layer):
  """Maps a latent vector z back to a vector of size `original_dim`.

  Args:
    original_dim: number of units in the sigmoid output layer.
    intermediate_dim: number of units in the hidden ReLU projection.
    name: layer name, forwarded to the base Layer.
    **kwargs: extra keyword arguments forwarded to the base Layer.
  """

  def __init__(self, original_dim, intermediate_dim=64, name='decoder', **kwargs):
    # Forward `name` to the base class; it used to be accepted and then
    # dropped, so a caller-supplied name had no effect.
    super(Decoder, self).__init__(name=name, **kwargs)
    self.dense_proj = tf.keras.layers.Dense(intermediate_dim, activation='relu')
    self.dense_out = tf.keras.layers.Dense(original_dim, activation='sigmoid')

  def call(self, inputs):
    """Decode `inputs` through the ReLU projection and the sigmoid output layer."""
    x = self.dense_proj(inputs)
    return self.dense_out(x)

In [0]:
class VariationalAutoEncoder(tf.keras.Model):
  """End-to-end model that chains an Encoder and a Decoder.

  Each forward pass also registers a KL regularization term through
  `add_loss`, so it is available in `self.losses` after the call.

  Args:
    original_dim: size of the decoder's output layer.
    intermediate_dim: hidden size shared by the encoder and the decoder.
    latent_dim: size of the latent mean / log-variance vectors.
    name: model name, forwarded to the base Model.
    **kwargs: extra keyword arguments forwarded to the base Model.
  """

  def __init__(self, original_dim, intermediate_dim=64, latent_dim=32, name='autoencoder', **kwargs):
    # Forward `name` to the base class; it used to be accepted and then
    # dropped, so the model received an auto-generated name instead.
    super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
    self.original_dim = original_dim
    self.intermediate_dim = intermediate_dim
    # Stored alongside the other two dimensions for consistency.
    self.latent_dim = latent_dim
    self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
    self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

  def call(self, inputs):
    """Encode, sample, decode, and register the KL term.

    Returns:
      The decoder output for the sampled latent z.
    """
    z_mean, z_log_var, z = self.encoder(inputs)
    reconstructed = self.decoder(z)

    # KL term: -0.5 * mean(log_var - mean^2 - exp(log_var) + 1), averaged
    # over every element of the latent tensors.
    kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
    self.add_loss(kl_loss)
    return reconstructed


In [8]:
# Keep only the MNIST training images; the labels and the second split are discarded.
(train_images, _), _ = tf.keras.datasets.mnist.load_data()

# Flatten every image to a 784-element float32 vector and divide by 255.
x_train = train_images.reshape(-1, 784).astype('float32') / 255.

# Shuffle with a 1024-element buffer, then group the samples into batches of 64.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(buffer_size=1024)
    .batch(64)
)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [12]:
# Build the model: 784 inputs, 64 hidden units, 32 latent dimensions.
original_dim = 784
vae = VariationalAutoEncoder(original_dim, 64, 32)

optimizer = tf.keras.optimizers.Adam()

# Reconstruction loss, plus a running mean used only for progress logging.
mse_loss = tf.keras.losses.MeanSquaredError()
loss_metrics = tf.keras.metrics.Mean()


EPOCH = 3
for epoch in range(EPOCH):
  print('Start of Epoch %d' % (epoch,))
  for step, x_batch_train in enumerate(train_dataset):
    with tf.GradientTape() as tape:
      result = vae(x_batch_train)

      # Total loss = reconstruction error + the losses the model registered
      # during this forward pass (the KL term added via add_loss).
      loss = mse_loss(x_batch_train, result)
      loss += sum(vae.losses)
    grads = tape.gradient(loss, vae.trainable_variables)
    optimizer.apply_gradients(zip(grads, vae.trainable_variables))

    # NOTE: loss_metrics is never reset, so the logged value is the mean
    # over every step seen so far, across epoch boundaries.
    loss_metrics(loss)

    if step % 100 == 0:
      # Convert the scalar tensor to a Python float so the log shows a plain
      # number instead of the tf.Tensor(...) repr.
      print('step %d: mean loss = %.4f' % (step, float(loss_metrics.result())))
      

Start of Epoch 0
step 0: mean loss = tf.Tensor(0.35644764, shape=(), dtype=float32)
step 100: mean loss = tf.Tensor(0.12586358, shape=(), dtype=float32)
step 200: mean loss = tf.Tensor(0.09929453, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.089234985, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.08425876, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.08089639, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.07876585, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.07717861, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.07598517, shape=(), dtype=float32)
step 900: mean loss = tf.Tensor(0.074986786, shape=(), dtype=float32)
Start of Epoch 1
step 0: mean loss = tf.Tensor(0.07468259, shape=(), dtype=float32)
step 100: mean loss = tf.Tensor(0.07402489, shape=(), dtype=float32)
step 200: mean loss = tf.Tensor(0.07351677, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.07304671, shape=(), dtype=float32)
st