In [4]:
from io import BytesIO
from tensorflow.keras import optimizers, losses

import datetime
import tensorflow_datasets as tfds

import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers, metrics
from tensorflow.keras import backend as K

import matplotlib.pyplot as plt
import numpy as np

In [5]:
# Hyperparameters for the notebook. The trailing `#@param` comments are kept
# verbatim on each line.

# Number of images per batch; used when batching the train/val/test datasets.
BATCH_SIZE = 8 #@param {type:"integer"}
# Passed as `epochs` to `vae.fit`.
EPOCHS = 20 #@param {type:"integer"}
# Adam learning rate; one hundredth of it is the ReduceLROnPlateau `min_lr`.
LEARNING_RATE = 0.0001 #@param {type:"number"}
# Passed to `VAE` as `latent_dim`.
LATENT_DIM = 128 #@param {type:"integer"}
# Side length, in pixels, that `preprocess` resizes every image to.
IMAGE_SIZE = 64 #@param {type:"integer"}
# (height, width, channels) shape passed to `VAE` as `input_shape`.
IMAGE_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)

In [6]:
def Sampling():
    """Build the reparameterisation step used to draw latent vectors.

    Returns:
        A callable taking a ``[z_mean, z_log_var]`` pair and returning
        ``z_mean + exp(z_log_var / 2) * epsilon``, where ``epsilon`` is drawn
        with ``K.random_normal`` and has one row per batch element.
    """

    def apply(args):
        mean, log_var = args
        # Batch size is read dynamically; the latent width is the static one.
        noise_shape = (K.shape(mean)[0], K.int_shape(mean)[1])
        noise = K.random_normal(shape=noise_shape)
        std = tf.exp(log_var / 2)
        return mean + std * noise

    return apply


def Encoder(input_shape, latent_dim, layer_widths=(128, 128, 128, 128)):
    """Build the convolutional encoder as a functional Keras model.

    Args:
        input_shape: Shape of a single input image, forwarded to
            ``layers.Input``.
        latent_dim: Width of the ``z_mean`` and ``z_log_var`` Dense heads.
        layer_widths: Filter count for each convolution stage; one stage is
            added per entry. Any iterable of ints is accepted (the default is
            a tuple so the default value cannot be mutated between calls).

    Returns:
        A ``models.Model`` mapping an image batch to the list
        ``[z_mean, z_log_var, z]``, where ``z`` comes from ``Sampling()``.
    """
    # Named `image_in` rather than `input` so the builtin is not shadowed.
    image_in = layers.Input(shape=input_shape)
    x = image_in
    for width in layer_widths:
        # strides=2 with 'same' padding downsamples the feature map each stage.
        x = layers.Conv2D(width, kernel_size=3, padding='same', strides=2)(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU()(x)

    # Dense bottleneck shared by both latent heads.
    x = layers.Flatten()(x)
    x = layers.Dense(1024)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Dropout(0.3)(x)

    z_mean = layers.Dense(latent_dim)(x)
    z_log_var = layers.Dense(latent_dim)(x)
    z = Sampling()([z_mean, z_log_var])

    return models.Model(image_in, [z_mean, z_log_var, z])


def Decoder(input_shape, latent_dim, layer_widths=(128, 128, 128, 128)):
    """Build the upsampling decoder as a functional Keras model.

    The decoder starts from a small feature map and doubles its height and
    width once per entry in ``layer_widths``. The starting size and the number
    of output channels are derived from ``input_shape`` so the reconstruction
    has the same shape as the image it is compared against. For the default
    64x64x3 input and four stages this gives a 4x4x128 starting map and a
    3-channel output.

    Args:
        input_shape: ``(height, width, channels)`` of the images to
            reconstruct.
        latent_dim: Size of the latent vector fed to the decoder.
        layer_widths: Filter count for each upsample + convolution stage; one
            stage is added per entry.

    Returns:
        A ``models.Model`` mapping a latent batch to an image batch with
        sigmoid-activated values.

    Raises:
        ValueError: If height or width is not divisible by
            ``2 ** len(layer_widths)``; repeated 2x upsampling could not
            reproduce that size exactly.
    """
    layer_widths = tuple(layer_widths)
    height, width, channels = input_shape

    # Validate before building any layer so a bad shape fails immediately.
    scale = 2 ** len(layer_widths)
    if height % scale or width % scale:
        raise ValueError(
            "input_shape height and width must be divisible by "
            "%d (2 ** len(layer_widths)); got %r" % (scale, (height, width))
        )

    seed_channels = 128
    seed_shape = (height // scale, width // scale, seed_channels)
    # Plain Python int: Dense expects an integer unit count, not a tensor.
    seed_units = seed_shape[0] * seed_shape[1] * seed_shape[2]

    # Named `latent_in` rather than `input` so the builtin is not shadowed.
    latent_in = layers.Input(shape=(latent_dim,))
    x = layers.Dense(1024)(latent_in)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(seed_units, activation='relu')(x)
    x = layers.Reshape(seed_shape)(x)
    for filters in layer_widths:
        # UpSampling2D with default arguments doubles height and width.
        x = layers.UpSampling2D()(x)
        x = layers.Conv2D(filters, kernel_size=3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU()(x)

    # Sigmoid keeps outputs in [0, 1].
    x = layers.Conv2D(channels, kernel_size=3, padding='same', activation='sigmoid')(x)

    return models.Model(latent_in, x)


class VAE(models.Model):
    """Variational autoencoder that wires an ``Encoder`` to a ``Decoder``.

    The training objective is ``beta * MSE(input, reconstruction) + KL``, where
    the KL term is summed over the latent dimensions and both terms are
    averaged over the batch. Note that ``beta`` scales the reconstruction
    term, not the KL term.

    ``train_step`` and ``test_step`` report running means of ``loss``,
    ``reconstruction_loss`` and ``kl_loss``, so epoch-level values such as
    ``val_loss`` average over every batch instead of reflecting only the last
    one.
    """

    def __init__(self, input_shape, latent_dim, beta= 2000):
        """Create the encoder/decoder pair and the loss trackers.

        Args:
            input_shape: Image shape forwarded to ``Encoder`` and ``Decoder``.
            latent_dim: Latent size forwarded to ``Encoder`` and ``Decoder``.
            beta: Weight multiplying the reconstruction (MSE) loss.
        """
        super(VAE, self).__init__()
        self.encoder = Encoder(input_shape, latent_dim)
        self.decoder = Decoder(input_shape, latent_dim)
        self.beta = beta

        # Running means of the logged quantities; Keras resets everything
        # listed in `metrics` at the start of each epoch and each evaluation.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Trackers that Keras should reset automatically."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs, training=None):
        """Encode ``inputs``, sample a latent vector and decode it.

        Returns:
            The tuple ``(z_mean, z_log_var, reconstruction)``.
        """
        z_mean, z_log_var, z = self.encoder(inputs, training=training)
        reconstruction = self.decoder(z, training=training)
        return z_mean, z_log_var, reconstruction

    def _compute_losses(self, data, training):
        """Run a forward pass and return ``(total, reconstruction, kl)`` losses.

        ``training`` is forwarded explicitly to both sub-models so that their
        BatchNormalization and Dropout layers run in the intended mode.
        """
        z_mean, z_log_var, z = self.encoder(data, training=training)
        reconstruction = self.decoder(z, training=training)

        reconstruction_loss = tf.reduce_mean(
            self.beta * losses.mean_squared_error(data, reconstruction)
        )
        kl_loss = tf.reduce_mean(
            tf.reduce_sum(-0.5*(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)), axis=1)
        )
        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

    def _update_trackers(self, total_loss, reconstruction_loss, kl_loss):
        """Fold one batch into the running means and return the log dict."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        """Run one optimisation step on a batch and return the logged losses."""
        # Datasets may yield (inputs, ...) tuples; only the inputs are used.
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._compute_losses(
                data, training=True
            )

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return self._update_trackers(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, data):
        """Evaluate one batch in inference mode and return the logged losses."""
        if isinstance(data, tuple):
            data = data[0]
        total_loss, reconstruction_loss, kl_loss = self._compute_losses(
            data, training=False
        )
        return self._update_trackers(total_loss, reconstruction_loss, kl_loss)

    def compile(self, optimizer, **kwargs):
        """Compile the model with ``optimizer``; extra kwargs go to Keras.

        The parent sets ``self.optimizer`` itself, so it is not reassigned
        here; reassigning the raw argument would undo the parent's resolution
        of string identifiers into optimizer objects.
        """
        super(VAE, self).compile(optimizer, **kwargs)

    def save(self, filepath, overwrite=True, include_optimizer=True, save_format=None, signatures=None, options=None):
        """Save the encoder and decoder as two separate models.

        They are written to ``filepath + '_encoder'`` and
        ``filepath + '_decoder'``. Options are passed by keyword so they do
        not depend on the positional order of ``Model.save``.
        """
        for suffix, model in (('_encoder', self.encoder), ('_decoder', self.decoder)):
            model.save(
                filepath + suffix,
                overwrite=overwrite,
                include_optimizer=include_optimizer,
                save_format=save_format,
                signatures=signatures,
                options=options,
            )

    def load(self, filepath, compile=True, options=None):
        """Replace the encoder and decoder with models loaded from disk.

        Reads ``filepath + '_encoder'`` and ``filepath + '_decoder'``.
        Keyword arguments are required here: the second positional parameter
        of ``load_model`` is ``custom_objects``, not ``compile``.
        """
        self.encoder = models.load_model(filepath + '_encoder', compile=compile, options=options)
        self.decoder = models.load_model(filepath + '_decoder', compile=compile, options=options)

    def summary(self, line_length=None, positions=None, print_fn=None):
        """Print the encoder summary followed by the decoder summary."""
        for model in (self.encoder, self.decoder):
            model.summary(line_length=line_length, positions=positions, print_fn=print_fn)


In [7]:

class CustomTensorboard(tf.keras.callbacks.TensorBoard):
    """TensorBoard callback that also logs a grid of reconstructions.

    At the end of every epoch the first batch of ``dataset`` is passed through
    the model, and each input is drawn side by side with its reconstruction.
    The resulting figure is written to ``log_dir`` as an image summary named
    ``"Reconstruction"``.
    """

    def __init__(self, log_dir, dataset, max_images=25):
        """Set up the callback.

        Args:
            log_dir: Directory for TensorBoard logs; also forwarded to the
                parent callback.
            dataset: Iterable yielding image batches; only its first batch is
                used each epoch.
            max_images: Upper bound on the number of image pairs drawn. Fewer
                are drawn when the batch is smaller.
        """
        super().__init__(log_dir=log_dir)
        self.dataset = dataset
        self.max_images = max_images
        # Created on first use and reused afterwards, so a new event file is
        # not opened on every epoch.
        self._image_writer = None

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        """Run the parent logging, then log the reconstruction grid."""
        super().on_epoch_end(epoch, logs, **kwargs)

        images = next(iter(self.dataset))
        _, _, reconstructions = self.model(images)
        images = np.asarray(images)
        reconstructions = np.asarray(reconstructions)

        # Never index past the end of the batch: it may hold fewer images
        # than `max_images`.
        count = min(self.max_images, len(images))
        if count <= 0:
            return
        cols = int(np.ceil(np.sqrt(count)))
        rows = int(np.ceil(count / cols))

        fig = plt.figure(figsize=(10, 10))
        try:
            for i in range(count):
                ax = fig.add_subplot(rows, cols, i + 1)
                # Input on the left, reconstruction on the right.
                ax.imshow(np.hstack((images[i], reconstructions[i])))
                ax.axis('off')
            buf = BytesIO()
            fig.savefig(buf, format='png')
        finally:
            # Release the figure even if plotting fails.
            plt.close(fig)

        image = tf.image.decode_png(buf.getvalue(), channels=4)
        image = tf.expand_dims(image, 0)  # summary images need a batch axis
        if self._image_writer is None:
            self._image_writer = tf.summary.create_file_writer(self.log_dir)
        with self._image_writer.as_default():
            tf.summary.image("Reconstruction", image, step=epoch)
        self._image_writer.flush()

    def on_train_end(self, logs=None):
        """Close the image writer after the parent has finished."""
        super().on_train_end(logs)
        if self._image_writer is not None:
            self._image_writer.close()
            self._image_writer = None

In [8]:
def preprocess(x):
    """Turn one dataset record into a resized float image.

    Reads the ``'image'`` entry of ``x``, resizes it to
    ``IMAGE_SIZE`` x ``IMAGE_SIZE``, casts it to float32 and divides by 255.
    """
    target_size = (IMAGE_SIZE, IMAGE_SIZE)
    resized = tf.image.resize(x['image'], target_size)
    return tf.cast(resized, tf.float32) / 255.0

# Training data: first 80% of the 'train' split, preprocessed, randomly
# flipped left/right, shuffled, then batched.
dataset = (
    tfds.load('oxford_flowers102', split='train[:80%]', shuffle_files=True)
    .map(preprocess)
    .map(tf.image.random_flip_left_right)
    .shuffle(1024, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE, drop_remainder=True)
)

# Validation data: remaining 20% of the 'train' split, no augmentation.
val_dataset = (
    tfds.load('oxford_flowers102', split='train[80%:]', shuffle_files=False)
    .map(preprocess)
    .batch(BATCH_SIZE, drop_remainder=True)
)

# Test data: the 'test' split, no augmentation.
test_dataset = (
    tfds.load('oxford_flowers102', split='test', shuffle_files=False)
    .map(preprocess)
    .batch(BATCH_SIZE, drop_remainder=True)
)

In [9]:
# Model under training.
vae = VAE(IMAGE_SHAPE, LATENT_DIM)

# One TensorBoard run directory per launch, keyed by a wall-clock timestamp.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/" + run_stamp

# Callbacks: reconstruction logging, LR decay on plateau, early stopping.
tensorboard_callback = CustomTensorboard(log_dir, val_dataset)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=5,
    min_lr=LEARNING_RATE / 100,
)
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)

vae.compile(optimizers.Adam(learning_rate=LEARNING_RATE))
vae.fit(
    dataset,
    epochs=EPOCHS,
    validation_data=val_dataset,
    callbacks=[tensorboard_callback, reduce_lr, early_stopping],
)

Epoch 1/20
 20/102 [====>.........................] - ETA: 28s - loss: 150.8556 - reconstruction_loss: 150.8128 - kl_loss: 0.0428

KeyboardInterrupt: 