In [None]:
"""all necessary imports"""

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import Tensor
import cv2
import glob
import os

In [None]:
"""selects 1000 images randomly"""
"""set the path to where the training set is saved on your system"""
import os
# Gather every JPEG in the training directory, then keep a random subset of
# at most 1000 paths (fewer if the directory holds fewer files).
filenames = glob.glob("path/to/images/*.jpg")
np.random.shuffle(filenames)
filenames = filenames[:1000]

# cv2.imread reports an unreadable or corrupt file by returning None instead
# of raising, which would make cv2.cvtColor fail with an opaque error.
# Skip such files and keep `filenames` aligned one-to-one with `images`.
loaded_filenames = []
images = []
for image_path in filenames:
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        continue
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    images.append(cv2.resize(rgb_image, (400, 400)))
    loaded_filenames.append(image_path)

skipped_count = len(filenames) - len(loaded_filenames)
if skipped_count:
    print(f"skipped {skipped_count} unreadable image file(s)")
filenames = loaded_filenames

# Build the float32 array directly and scale to [0, 1] in place; this avoids
# holding an extra full-size uint8 copy plus a second float32 copy in memory.
images = np.array(images, dtype='float32')
images /= 255

In [None]:
"""custom layer which produces the distribution which is sampled in the VAE"""

class Latent_features(tf.keras.layers.Layer):
    """Sampling layer implementing the VAE reparameterization step.

    Expects ``inputs`` to be the pair ``[dist_mean, dist_log_var]`` and
    returns ``dist_mean + exp(0.5 * dist_log_var) * epsilon`` where
    ``epsilon`` is drawn from a standard normal distribution with the same
    shape and dtype as ``dist_mean``.
    """

    def call(self, inputs) -> Tensor:
        dist_mean, dist_log_var = inputs
        # Use the core TensorFlow sampler rather than the legacy
        # keras.backend.random_normal helper, and take shape and dtype from
        # the mean tensor so the noise always matches it (any rank, and any
        # float dtype such as float16 under mixed precision).
        epsilon = tf.random.normal(shape=tf.shape(dist_mean), dtype=dist_mean.dtype)
        # exp(0.5 * log_var) is the standard deviation.
        return dist_mean + tf.exp(0.5 * dist_log_var) * epsilon

In [None]:
class RVAE(tf.keras.Model):
    """Residual variational autoencoder for 400x400 RGB images.

    Builds an encoder and a decoder sub-network out of residual blocks.
    The encoder maps an image to the mean and log-variance of a
    16-dimensional latent distribution plus one sample drawn from it; the
    decoder maps a latent sample back to an image with values in (0, 1).
    Training uses a custom ``train_step`` that minimises reconstruction loss
    plus KL divergence, so ``compile`` only needs an optimizer.
    """

    def __init__(self):
        super(RVAE, self).__init__()
        # Order matters: build_encoder records reshape_dims / flatten_dims,
        # which build_decoder reads to mirror the encoder's bottleneck.
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

        # End-to-end functional model (image -> sampled latent -> image)
        # that reuses the encoder and decoder built above.
        autoencoder_input = keras.Input(shape=(400,400,3))
        encoded_img = self.encoder(autoencoder_input)[2]
        decoded_img = self.decoder(encoded_img)
        self.autoencoder = keras.Model(autoencoder_input, decoded_img)

    def call(self, image):
        """Forward pass required for a subclassed Keras model.

        Encodes ``image``, takes the sampled latent vector (third encoder
        output) and returns its decoded reconstruction.
        """
        encoded = self.encoder(image)[2]
        decoded = self.decoder(encoded)
        return decoded

    def activation_block(self, inputs) -> Tensor:
        """Apply batch normalization followed by LeakyReLU to ``inputs``."""
        BN = tf.keras.layers.BatchNormalization(momentum=0.99) (inputs)
        activated = tf.keras.layers.LeakyReLU()(BN)
        return activated

    def residual_block(self, inputs: Tensor, scale_change: bool, polarity, kernel_size = 2, filters = 64, stride_size = 2) -> Tensor:
        """Residual (skip-connection) block used by both sub-networks.

        Args:
            inputs: Feature map entering the block.
            scale_change: When true, the first convolution uses
                ``stride_size`` and the shortcut is passed through a strided
                convolution so both branches have matching shapes.
            polarity: 0 selects ``Conv2D`` (encoder, downscaling); any other
                value selects ``Conv2DTranspose`` (decoder, upscaling).
            kernel_size: Side length of the square convolution kernel.
            filters: Number of output channels of every convolution.
            stride_size: Stride applied when ``scale_change`` is true.

        Returns:
            The activated sum of the shortcut and the convolutional branch.
        """
        # The two polarities differ only in the convolution layer type.
        if polarity == 0:
            conv_layer = tf.keras.layers.Conv2D
        else:
            conv_layer = tf.keras.layers.Conv2DTranspose
        kernel = (kernel_size, kernel_size)
        first_stride = stride_size if scale_change else 1

        # Main branch: conv -> BN/LeakyReLU -> conv.
        output_1 = conv_layer(filters, kernel, strides = first_stride, padding = 'same') (inputs)
        output_1 = self.activation_block(output_1)
        output_1 = conv_layer(filters, kernel, strides = 1, padding = 'same') (output_1)

        # Shortcut branch: project only when the main branch changed scale.
        # Truthiness is used here (not `== True`) so the shortcut always
        # agrees with the stride decision made for the main branch above.
        if scale_change:
            inputs = conv_layer(filters, kernel, strides = stride_size, padding = 'same') (inputs)

        output_2 = tf.keras.layers.Add()([inputs, output_1])
        output_2 = self.activation_block(output_2)

        return output_2

    def build_encoder(self):
        """Return the encoder model, a sub-network of the autoencoder.

        The encoder takes a (400, 400, 3) image and outputs the list
        ``[mean, log_variance, latent_sample]``, each of width 16. As a side
        effect it stores ``self.reshape_dims`` (feature-map shape before
        flattening) and ``self.flatten_dims`` (shape after flattening) for
        ``build_decoder``.
        """
        E_input = tf.keras.layers.Input(shape = (400,400,3), name = 'original_image')
        E = tf.keras.layers.Conv2D(32, (1,1), strides=(1), padding = 'same') (E_input)
        E = self.activation_block(E)

        block_depths = [2,5,5,2]
        filters = [32,64,64,128]

        # The first block of every stage except stage 0 changes scale
        # (stride 2); all remaining blocks keep the resolution.
        for i in range(0,len(block_depths)):
            for j in range(0,block_depths[i]):
                E = self.residual_block(E, (j==0 and i!=0), 0, 2, filters[i], 2)

        self.reshape_dims = E.shape

        E = tf.keras.layers.Flatten() (E)

        self.flatten_dims = E.shape

        E = (tf.keras.layers.LeakyReLU()) (E)

        distribution_mean = tf.keras.layers.Dense(16, name='mean')(E)
        # This head is consumed as a log-variance by Latent_features and by
        # the KL term in train_step.
        distribution_variance = tf.keras.layers.Dense(16, name='log_variance')(E)
        latent_encoding = Latent_features()([distribution_mean, distribution_variance])

        encoder = keras.Model(E_input, [distribution_mean, distribution_variance, latent_encoding], name="encoder")

        return encoder

    def build_decoder(self):
        """Return the decoder model, a sub-network of the autoencoder.

        The decoder takes a 16-dimensional latent vector, expands it to the
        encoder's pre-flatten feature-map shape, runs the mirrored residual
        stages with transposed convolutions and finishes with a sigmoid.
        Must be called after ``build_encoder`` because it reads
        ``self.flatten_dims`` and ``self.reshape_dims``.
        """
        decoder_input = keras.Input(shape=(16,))
        D = tf.keras.layers.LeakyReLU() (decoder_input)
        D = (tf.keras.layers.Dense(self.flatten_dims[1])) (D)
        D = (tf.keras.layers.LeakyReLU()) (D)
        D = (tf.keras.layers.Reshape((self.reshape_dims[1], self.reshape_dims[2], self.reshape_dims[3]))) (D)

        block_depths = [2,5,5,2]
        filters = [128,64,64,32]

        # Mirror of the encoder: the first block of every stage except
        # stage 0 upscales with stride 2.
        for i in range(0,len(block_depths)):
            for j in range(0,block_depths[i]):
                D = self.residual_block(D, (j==0 and i!=0), 1, 2, filters[i], 2)

        D = tf.keras.layers.Conv2DTranspose(3, (1,1), strides = (1), padding = 'same') (D)
        D = self.activation_block(D)

        decoded = (tf.keras.layers.Activation('sigmoid')) (D)

        decoder = tf.keras.Model(inputs = decoder_input, outputs = decoded)

        return decoder

    @property
    def metrics(self):
        """Loss trackers reported by ``fit`` for the custom loss function."""
        return [self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker]

    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: A batch of images, or a tuple/list whose first element is
                the batch of images (as delivered when ``fit`` is given
                targets); any further elements are ignored.

        Returns:
            Dict with the running means of ``loss``,
            ``reconstruction_loss`` and ``kl_loss``.
        """
        # fit(x, y) hands train_step a tuple; only the images are needed.
        if isinstance(data, (tuple, list)):
            data = data[0]

        with tf.GradientTape() as tape:
            # training=True must be passed explicitly: the sub-models are
            # invoked directly from train_step, so without the flag their
            # BatchNormalization layers would run in inference mode and
            # never use or update batch statistics.
            mean, log_var, latent = self.encoder(data, training=True)
            reconstruction = self.decoder(latent, training=True)
            # Per-pixel cross-entropy summed over the two spatial axes,
            # then averaged over the batch.
            reconstruction_loss = tf.reduce_mean(tf.reduce_sum(keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)))
            # KL divergence between N(mean, exp(log_var)) and N(0, 1),
            # summed over latent dimensions and averaged over the batch.
            kl_loss = -0.5 * (1 + log_var - tf.square(mean) - tf.exp(log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {"loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),}


In [None]:
"""compiling the RVAE with RMSprop"""
# Instantiate the model, then attach an RMSprop optimizer. No loss is passed
# to compile() because RVAE.train_step computes its own loss.
rvae = RVAE()
opt = tf.keras.optimizers.RMSprop(learning_rate=1e-4)
rvae.compile(optimizer=opt)