# Update

- Decoder reconstruction of Monet-like images
- Train VAE Around 1K Epochs

In [None]:
import matplotlib.pyplot as plt
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model

from tqdm import tqdm
from numpy import asarray
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array

# Create Sampling Layer

In [None]:
class Sampling(layers.Layer):
    """Reparameterization layer: draws z from (z_mean, z_log_var).

    The layer expects a pair ``(z_mean, z_log_var)`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is
    standard-normal noise with one value per (row, column) of ``z_mean``.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        mean_shape = tf.shape(z_mean)
        # Noise has the same (batch, dim) layout as the mean tensor.
        noise = tf.keras.backend.random_normal(
            shape=(mean_shape[0], mean_shape[1])
        )
        # exp(0.5 * log_var) turns the log-variance into a standard deviation.
        std_dev = tf.exp(0.5 * z_log_var)
        return z_mean + std_dev * noise

### Utility Convolutional Blocks

In [None]:
# Densely connected block
def dense_block(x, f, k, s, d=5):
    """Stack ``d`` convolutions, each fed with everything produced so far.

    Args:
        x: Input tensor.
        f: Number of filters of every convolution.
        k: Kernel size of every convolution.
        s: Stride of every convolution.
        d: Number of convolution/concatenation steps.

    Returns:
        The input concatenated with the output of all ``d`` convolutions.
    """
    # NOTE(review): Concatenate needs matching non-channel dimensions, so
    # this presumably expects a stride that keeps the spatial size (s=1).
    features = x
    for _ in range(d):
        new_maps = layers.Conv2D(
            f, kernel_size=k, strides=s, padding='SAME'
        )(features)
        features = layers.Concatenate()([features, new_maps])
    return features

# ResNeXt block
def resnext_block(x, f=32, r=2, c=4):
    """Sum ``c`` parallel bottleneck branches and add them to the input.

    Every branch is 1x1 reduce -> 3x3 -> 1x1 expand. The branch outputs are
    summed and then added to ``x`` as a residual connection.

    Args:
        x: Input tensor. Because of the residual addition its channel count
            must equal ``f``.
        f: Number of output filters of every branch.
        r: Reduction ratio of the bottleneck.
        c: Cardinality, i.e. the number of parallel branches.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``c`` is smaller than 1.
    """
    if c < 1:
        raise ValueError("resnext_block needs at least one branch (c >= 1)")

    # Keep at least one filter in the bottleneck when f is small.
    bottleneck = max(f // (c * r), 1)
    branches = []
    for _ in range(c):
        m = layers.Conv2D(bottleneck, kernel_size=1)(x)
        # 'same' padding keeps the spatial size so the additions below match.
        m = layers.Conv2D(bottleneck, kernel_size=3, padding='same')(m)
        m = layers.Conv2D(f, kernel_size=1)(m)
        branches.append(m)

    # layers.add needs at least two inputs; a single branch is used directly.
    merged = branches[0] if len(branches) == 1 else layers.add(branches)
    return layers.add([x, merged])

# Residual (bottleneck) block
def residual_block(x, f=32, r=4):
    """Apply a 1x1 -> 3x3 -> 1x1 bottleneck and add the result to the input.

    Args:
        x: Input tensor. Because of the residual addition its channel count
            must equal ``f``.
        f: Number of output filters.
        r: Reduction ratio of the bottleneck.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Keep at least one filter in the bottleneck when f is small.
    bottleneck = max(f // r, 1)
    m = layers.Conv2D(bottleneck, kernel_size=1)(x)
    # 'same' padding keeps the spatial size so the addition below matches.
    m = layers.Conv2D(bottleneck, kernel_size=3, padding='same')(m)
    m = layers.Conv2D(f, kernel_size=1)(m)
    return layers.add([x, m])

# Inception module
def inception_module(x, f=32, r=4):
    """Run four parallel branches over ``x`` and concatenate their outputs.

    Branches: 1x1 conv; 1x1 -> 3x3 conv; 1x1 -> 5x5 conv; 3x3 max-pool -> 1x1
    conv. Every branch ends with ``f`` filters, so the result has ``4 * f``
    channels.

    Args:
        x: Input tensor.
        f: Number of output filters of every branch.
        r: Reduction ratio used by the 1x1 conv in front of the 5x5 conv.

    Returns:
        Channel-wise concatenation of the four branch outputs.
    """
    # All branches use stride 1 and 'same' padding so that their spatial
    # sizes agree at the final concatenation.
    a = layers.Conv2D(f, kernel_size=1)(x)

    b = layers.Conv2D(max(f // 3, 1), kernel_size=1)(x)
    b = layers.Conv2D(f, kernel_size=3, padding='same')(b)

    c = layers.Conv2D(max(f // r, 1), kernel_size=1)(x)
    c = layers.Conv2D(f, kernel_size=5, padding='same')(c)

    d = layers.MaxPooling2D(pool_size=3, strides=1, padding='same')(x)
    d = layers.Conv2D(f, kernel_size=1)(d)

    return layers.concatenate([a, b, c, d])

# Squeeze-and-excitation block
def se_block(x, f, rate=16):
    """Rescale the channels of ``x`` with a learned per-channel gate.

    The feature map is globally average-pooled, passed through a two-layer
    dense bottleneck ending in a sigmoid, and the resulting gate is
    multiplied back onto ``x``.

    Args:
        x: Input feature map; its channel count must equal ``f``.
        f: Number of channels of ``x``.
        rate: Reduction ratio of the dense bottleneck.

    Returns:
        Tensor with the same shape as ``x``.
    """
    m = layers.GlobalAveragePooling2D()(x)
    # ReLU on the squeeze layer follows the usual squeeze-and-excitation
    # design; max(..., 1) avoids a zero-unit layer when f < rate.
    m = layers.Dense(max(f // rate, 1), activation='relu')(m)
    m = layers.Dense(f, activation='sigmoid')(m)
    # Make the gate (1, 1, f) so it broadcasts over the spatial dimensions.
    m = layers.Reshape((1, 1, f))(m)
    return layers.multiply([x, m])

In [None]:
def plot_latent(decoder):
    """Decode an n*n grid of 2-D latent points and show the images as a mosaic.

    The grid spans ``[-scale, scale]`` on both latent axes. Every decoded
    image is placed in the tile that corresponds to its latent coordinates,
    with z[0] along the horizontal axis and z[1] along the vertical axis.

    Args:
        decoder: Model exposing ``predict``. It must accept an array of shape
            ``(n * n, 2)`` and return one ``sample_size x sample_size`` RGB
            image per row.
    """
    n = 5
    sample_size = 256
    scale = 2.0
    figsize = 15

    # Linearly spaced latent coordinates; y is reversed so that the largest
    # z[1] value ends up in the top row of the mosaic.
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # Decode the whole grid with a single predict call (row-major order)
    # instead of one call per tile.
    z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])
    decoded = decoder.predict(z_grid)

    figure = np.zeros((sample_size * n, sample_size * n, 3))
    for idx, tile in enumerate(decoded):
        row, col = divmod(idx, n)
        figure[
            row * sample_size : (row + 1) * sample_size,
            col * sample_size : (col + 1) * sample_size,
        ] = tile.reshape(sample_size, sample_size, 3)

    plt.figure(figsize=(figsize, figsize))
    # One tick at the centre of every tile: exactly n positions, matching the
    # n labels below (the previous range produced n + 1 positions).
    pixel_range = sample_size // 2 + sample_size * np.arange(n)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    # No cmap: the mosaic is RGB, for which imshow ignores colormaps.
    plt.imshow(figure)
    plt.show()

# Encoder

In [None]:
# Hyper-parameters of the model and of the training run.
latent_dim = 2
total_epoch = 10000
batch_size = 512
img_dim = (256, 256)

encoder_inputs = keras.Input(shape=(*img_dim, 3))

# Five stride-2 stages with growing filter count and kernel size. Each stage
# is Conv2D -> BatchNormalization -> LeakyReLU(0.2) and halves the spatial
# resolution (256 -> 8 after the last stage).
x = encoder_inputs
for n_filters, k in ((16, 3), (32, 5), (64, 7), (128, 9), (256, 11)):
    x = layers.Conv2D(
        n_filters, kernel_size=(k, k), strides=(2, 2), padding='SAME'
    )(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

# Collapse the spatial dimensions and project to a 256-unit feature vector.
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation="relu")(x)

# Two linear heads parameterize the latent Gaussian; z is sampled from them.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = keras.Model(
    encoder_inputs, [z_mean, z_log_var, z], name="encoder"
)

encoder.summary()

## Encoder Plot

In [None]:
# Draw the encoder graph, including tensor shapes and layer names.
# NOTE(review): the decoder plot further down writes to the same
# 'model_plot.png', so the file on disk ends up holding the decoder graph.
plot_model(encoder, to_file='model_plot.png', show_shapes=True, show_layer_names=True)

# Decoder

In [None]:
latent_inputs = keras.Input(shape=(latent_dim,))

# Expand the latent vector into an 8x8 feature map with 256 channels.
x = layers.Dense(8 * 8 * 256, activation="relu")(latent_inputs)
x = layers.Reshape((8, 8, 256))(x)

# Five stride-2 transposed-convolution stages with shrinking filter count and
# kernel size. Each stage is Conv2DTranspose -> BatchNormalization ->
# LeakyReLU(0.2) and doubles the spatial resolution (8 -> 256).
for n_filters, k in ((256, 11), (128, 9), (64, 7), (32, 5), (16, 3)):
    x = layers.Conv2DTranspose(
        n_filters, kernel_size=(k, k), strides=(2, 2), padding='SAME'
    )(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

# Final 3-channel layer; the sigmoid keeps every output value in [0, 1].
decoder_outputs = layers.Conv2DTranspose(
    3, 3, activation="sigmoid", padding="same"
)(x)

decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

## Decoder Plots

In [None]:
# Draw the decoder graph, including tensor shapes and layer names.
# NOTE(review): this reuses the file name 'model_plot.png' that the encoder
# plot also writes to, so the encoder image on disk is overwritten here.
plot_model(decoder, to_file='model_plot.png', show_shapes=True, show_layer_names=True)

# Modeling

In [None]:
class VAE(keras.Model):
    """Variational autoencoder that trains an encoder/decoder pair jointly.

    Args:
        encoder: Model mapping an image batch to ``(z_mean, z_log_var, z)``.
        decoder: Model mapping ``z`` back to an image batch.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: Image batch, or a tuple whose first element is the image
                batch (any further elements are ignored).

        Returns:
            Dict with the scalar ``loss``, ``reconstruction_loss`` and
            ``kl_loss`` of this batch.
        """
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            # Use the sub-models owned by this instance, not module globals,
            # so the class works with whatever encoder/decoder it was given.
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # binary_crossentropy averages over channels; summing over the
            # two spatial axes and averaging over the batch equals the former
            # "mean * 256 * 256" for 256x256 inputs without hard-coding the
            # image size.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            # NOTE(review): this averages over batch AND latent dimensions;
            # the textbook KL term sums over the latent dimensions. Kept
            # unchanged so the loss weighting of existing runs is preserved.
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

## Training Data

In [None]:
# Directory holding the Monet training JPEGs.
monet_dir = '../input/gan-getting-started/monet_jpg'
trn_images = os.listdir(monet_dir)

# Load every file as a 256x256 image, convert it to float32 and scale the
# pixel values from [0, 255] to [0, 1]. Despite its name, trn_sizes holds the
# image data itself, stacked into a single array.
trn_sizes = asarray([
    img_to_array(
        load_img(os.path.join(monet_dir, file_name), target_size=(256, 256))
    ).astype("float32") / 255.0
    for file_name in tqdm(trn_images)
])

# GAN Monitoring

In [None]:
# Monitoring GAN
class GANMonitor(tf.keras.callbacks.Callback):
    """A callback to generate and save images after each epoch.

    NOTE(review): this callback reads a module-level ``test_horses`` dataset
    and calls ``self.model.gen_G``. Neither is defined in the visible part of
    this file, and the VAE class here has no ``gen_G`` attribute, so it looks
    carried over from an image-translation example. Confirm both exist before
    attaching it to a model.

    Args:
        num_img: Number of input images to translate, plot and save per epoch.
    """

    def __init__(self, num_img=4):
        # Let the base class set up its own state (e.g. the model reference).
        super().__init__()
        self.num_img = num_img

    def on_epoch_end(self, epoch, logs=None):
        """Plot input/translated pairs and save each translation as a PNG."""
        # squeeze=False keeps `ax` two-dimensional even when num_img == 1.
        _, ax = plt.subplots(self.num_img, 2, figsize=(12, 12), squeeze=False)
        for i, img in enumerate(test_horses.take(self.num_img)):
            prediction = self.model.gen_G(img)[0].numpy()
            # Map values from [-1, 1] back to the [0, 255] pixel range.
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img)
            ax[i, 1].imshow(prediction)
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            prediction = keras.preprocessing.image.array_to_img(prediction)
            prediction.save(
                "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
            )
        plt.show()
        plt.close()
        
        
# Visualize reconstructed samples via the decoder (VAE)
# at a fixed epoch interval (200 by default).
class TrainingPlot(tf.keras.callbacks.Callback):
    """Callback that plots the decoder's latent grid every ``interval`` epochs.

    The plot is produced by calling the module-level ``plot_latent`` on the
    module-level ``decoder``.

    Args:
        interval: Number of epochs between two plots.
    """

    def __init__(self, interval=200):
        super().__init__()
        self.interval = interval
        self.epochs = 0

    def on_train_begin(self, logs=None):
        # Restart the epoch counter at the beginning of every fit() call.
        self.epochs = 0

    def on_epoch_end(self, batch, logs=None):
        # `batch` receives the epoch index passed by Keras; the callback keeps
        # its own 1-based counter instead of using it.
        self.epochs += 1
        if self.epochs % self.interval == 0:
            print('[Reconstructed] Viz Results of Decoder @ Epoch : ', self.epochs)
            plot_latent(decoder)

# Training

In [None]:
# NOTE(review): `plotter` is created but never handed to fit(); only the
# TrainingPlot callback is active during training.
plotter = GANMonitor()
plots = TrainingPlot()
vae = VAE(encoder, decoder)

# Optimizer: Adam with default settings; the loss is computed inside
# VAE.train_step, so compile() gets no loss argument.
vae.compile(optimizer=keras.optimizers.Adam(), run_eagerly=False)

# `callbacks` is documented as a list of Callback instances, so the single
# callback is wrapped in a list. verbose=0 silences the per-epoch progress
# output; progress is visible through the periodic TrainingPlot figures.
vae.fit(
    trn_sizes,
    epochs=total_epoch,
    batch_size=batch_size,
    callbacks=[plots],
    verbose=0,
)

# Plot: Reconstruction Decoder

In [None]:
# Show the decoder's latent-grid mosaic once more after training has finished.
plot_latent(decoder)