<a href="https://colab.research.google.com/github/vatsal9876/image_restoration/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model, backend as K

In [None]:
# Dimensionality of the latent vector z (encoder output size, decoder input size).
latent_dim = 128
# (height, width, channels) of the input images. The decoder built in
# vae_structure() must produce images of exactly this shape, so keep the two
# consistent when the dataset's image size changes.
img_shape = 64, 64, 3

In [None]:
def latent_sampling(z_args):
    """Sample z = z_mean + exp(0.5 * z_log_var) * eps with eps ~ N(0, I).

    This is the VAE reparameterization trick: the randomness lives in ``eps``,
    so gradients flow through ``z_mean`` and ``z_log_var``.

    Args:
        z_args: pair ``(z_mean, z_log_var)`` of tensors with the same shape
            (presumably ``(batch, latent_dim)`` — TODO confirm for new callers).

    Returns:
        A tensor with the same shape and dtype as ``z_mean``.
    """
    z_mean, z_log_var = z_args
    # Take the noise shape and dtype from z_mean itself rather than from the
    # global latent_dim / float32 default, so this works for any latent size
    # and under reduced-precision (e.g. float16) layer policies.
    epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [None]:
def vae_structure():
    """Build the convolutional VAE described by ``img_shape`` and ``latent_dim``.

    Encoder: two stride-2 ``Conv2D`` layers, a 256-unit ``Dense`` layer, then
    two ``Dense(latent_dim)`` heads producing ``z_mean`` and ``z_log_var``.
    Decoder: ``Dense`` + ``Reshape`` to an (H/4, W/4, 64) grid, two stride-2
    ``Conv2DTranspose`` layers back to (H, W), and a sigmoid ``Conv2D`` with
    ``img_shape[2]`` output channels.

    Returns:
        A ``VAEModel`` wrapping the encoder tensors and the decoder model.

    Raises:
        ValueError: if the image height or width is not divisible by 4.
    """
    height, width, channels = img_shape
    # Two stride-2 layers on each side change resolution by a factor of 4, so
    # the decoder must start from an (H/4, W/4) grid to reproduce (H, W).
    if height % 4 or width % 4:
        raise ValueError(
            f"img_shape height and width must be divisible by 4, got {img_shape}"
        )
    seed_h, seed_w = height // 4, width // 4

    # Encoder
    inputs = layers.Input(shape=img_shape)
    x = layers.Conv2D(32, 3, activation='relu', strides=2, padding='same')(inputs)
    x = layers.Conv2D(64, 3, activation='relu', strides=2, padding='same')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(256, activation='relu')(x)

    z_mean = layers.Dense(latent_dim)(x)
    z_log_var = layers.Dense(latent_dim)(x)
    z_args = [z_mean, z_log_var]

    # Reparameterized latent sample, used to wire up the end-to-end graph.
    z = layers.Lambda(latent_sampling)(z_args)

    # Decoder
    decoder_input = layers.Input(shape=(latent_dim,))
    x = layers.Dense(seed_h * seed_w * 64, activation='relu')(decoder_input)
    x = layers.Reshape((seed_h, seed_w, 64))(x)
    x = layers.Conv2DTranspose(64, 3, strides=2, activation='relu', padding='same')(x)
    x = layers.Conv2DTranspose(32, 3, strides=2, activation='relu', padding='same')(x)
    # Sigmoid keeps pixel values in (0, 1), as binary cross-entropy expects.
    decoded_output = layers.Conv2D(channels, 3, activation='sigmoid', padding='same')(x)

    decoder = Model(decoder_input, decoded_output, name='decoder')

    outputs = decoder(z)
    vae = VAEModel(inputs, outputs, z_mean, z_log_var, decoder)
    return vae

In [None]:
class VAEModel(Model):
    """Variational autoencoder trained to map noisy images to clean ones.

    ``train_step`` expects each batch to be a ``(noisy_images, clean_images)``
    pair. The noisy batch is encoded, a latent sample is drawn and decoded,
    and the reconstruction is scored against the clean batch with binary
    cross-entropy plus the KL divergence to a standard normal prior.

    Args:
        encoder_inputs: Keras ``Input`` tensor feeding the encoder.
        decoded_outputs: symbolic output of the full encoder -> sample ->
            decoder graph (stored for reference; not used by training).
        z_mean: symbolic encoder output for the latent mean.
        z_log_var: symbolic encoder output for the latent log-variance.
        decoder: ``Model`` mapping a latent vector to an image.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, encoder_inputs, decoded_outputs, z_mean, z_log_var, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder_inputs = encoder_inputs
        self.decoded_outputs = decoded_outputs
        self.z_mean = z_mean
        self.z_log_var = z_log_var
        self.decoder = decoder
        # Build the encoder sub-model ONCE and store it as an attribute. Keras
        # tracks models/layers assigned to attributes, so the encoder weights
        # now appear in self.trainable_weights and receive gradients. (Creating
        # a throwaway Model inside encode() left them untracked and untrained.)
        self.encoder = Model(encoder_inputs, [z_mean, z_log_var], name='encoder')
        # Running means of the per-batch losses reported by train_step. Created
        # here so the `metrics` property works even before compile().
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    def compile(self, optimizer, **kwargs):
        """Configure training with ``optimizer``; extra kwargs go to Keras."""
        # Hand the optimizer to Keras directly instead of compiling with the
        # default optimizer and overwriting self.optimizer afterwards.
        super().compile(optimizer=optimizer, **kwargs)

    @property
    def metrics(self):
        """Trackers Keras resets at the start of each epoch."""
        return [self.loss_tracker, self.reconstruction_loss_tracker, self.kl_loss_tracker]

    def train_step(self, data):
        """Run one gradient step on a ``(noisy_images, clean_images)`` batch.

        Returns:
            dict with the running means of ``"loss"``, ``"reconstruction_loss"``
            and ``"kl_loss"``.
        """
        noisy_images, clean_images = data
        with tf.GradientTape() as tape:
            z_mean, z_log_var = self.encode(noisy_images)
            z = latent_sampling([z_mean, z_log_var])
            reconstruction = self.decoder(z)

            # binary_crossentropy averages over the last (channel) axis; scaling
            # the overall mean by H*W*C gives a per-image sum averaged over the
            # batch (assumes clean_images are shaped like img_shape per sample).
            reconstruction_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(clean_images, reconstruction)
            ) * img_shape[0] * img_shape[1] * img_shape[2]

            # KL(q(z|x) || N(0, I)): sum over latent dimensions, mean over the
            # batch, so it is on the same per-image scale as the reconstruction
            # term. Averaging over latent dims too would shrink it by latent_dim.
            kl_per_dim = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=-1))

            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def encode(self, x):
        """Return ``[z_mean, z_log_var]`` for the input batch ``x``."""
        return self.encoder(x)

    def call(self, inputs):
        """Encode ``inputs``, sample a latent vector and decode it."""
        z_mean, z_log_var = self.encode(inputs)
        z = latent_sampling([z_mean, z_log_var])
        return self.decoder(z)

In [None]:
# Assemble the encoder/decoder graph, then attach an Adam optimizer with its
# default hyper-parameters (passed positionally to VAEModel.compile).
vae = vae_structure()
vae.compile(tf.keras.optimizers.Adam())