In [1]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, \
    Lambda, Reshape, Conv2DTranspose, Layer, InputLayer, Activation
from tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.losses import binary_crossentropy, kl_divergence
import tensorflow.keras.backend as K
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import random
import yaml
import numpy as np

# Run-level settings. `batch_size` drives the tf.data pipelines; all three
# values are also embedded in the title of the latent-grid figure.
# NOTE(review): `latent_dim` is independent of params["z_dim"], which is what
# the model actually uses -- confirm the two agree.
latent_dim = 64
batch_size = 64
epochs = 100

# Load hyperparameters from the YAML file. The context manager closes the
# file handle deterministically instead of leaving it to the garbage collector.
with open("config.yaml", 'r') as config_file:
    config = yaml.safe_load(config_file)
params = config["params"]
encoder_params = config["encoder"]
decoder_params = config["decoder"]

def show_random_data(data_list, label_list, pick=None):
    """Display one image from ``data_list`` titled with its label.

    Args:
        data_list: Indexable collection of images accepted by ``plt.imshow``.
        label_list: Indexable collection of labels aligned with ``data_list``.
        pick: Index of the item to display. When ``None`` (the default) a
            random valid index is drawn.

    Returns:
        The index that was displayed.

    Raises:
        ValueError: If ``pick`` is ``None`` and ``data_list`` is empty.
    """
    # Compare against None so an explicit index 0 is honoured, and use
    # randrange (exclusive upper bound) so len(data_list) is never drawn.
    if pick is None:
        pick = random.randrange(len(data_list))
    plt.title(f"Label : {label_list[pick]}")
    plt.imshow(data_list[pick], 'gray')
    return pick

def sample(inputs):
    """Draw a latent vector from a ``(mean, log_variance)`` pair.

    A single noise row of shape ``(1, params["z_dim"])`` is drawn and
    broadcast against the inputs, so the same noise values are shared by
    everything the inputs broadcast over.

    Args:
        inputs: Two-element sequence ``(z_mean, z_log_var)``.

    Returns:
        ``z_mean + exp(0.5 * z_log_var) * noise``.
    """
    mean, log_var = inputs
    noise = tf.random.normal(shape=(1, params["z_dim"]))
    std_dev = tf.exp(0.5 * log_var)
    return mean + std_dev * noise



In [12]:
class Sampling(Layer):
    """Reparameterisation step: turns (mu, log_var) into a sampled z."""

    def call(self, inputs):
        """Return ``mu + exp(0.5 * log_var) * noise`` for the given pair.

        The noise tensor is built from the first two dimensions of ``mu``,
        i.e. one independent draw per entry of a 2-D input.
        """
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        noise = tf.keras.backend.random_normal(
            shape=(mean_shape[0], mean_shape[1])
        )
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise


class Encoder(Layer):
    """Maps MNIST digits to a triplet (mu, log_var, z).

    The input goes through a stack of ReLU-activated ``Conv2D`` layers, is
    flattened, and is projected by two ``Dense`` heads to the latent mean and
    log-variance; ``Sampling`` then draws ``z`` from them.

    Args:
        params: Mapping with equally long lists ``"conv_filters"``,
            ``"conv_kernels"`` and ``"conv_strides"`` (one entry per conv
            layer) and the latent size ``"z_dim"``.
        encoder_input_shape: Accepted for interface compatibility; not used
            by this layer.
        name: Layer name.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, params, encoder_input_shape, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.conv = []
        for n_layer in range(len(params["conv_filters"])):
            stride = params["conv_strides"][n_layer]
            self.conv.append(
                Conv2D(
                    params["conv_filters"][n_layer],
                    params["conv_kernels"][n_layer],
                    (stride, stride),
                    padding='same'
                )
            )
        self.relu = Activation("relu")
        self.flatten = Flatten()
        self.mu = Dense(params["z_dim"])
        self.log_var = Dense(params["z_dim"])
        self.sampling = Sampling()

    def call(self, inputs):
        """Encode ``inputs`` and return ``(mu, log_var, z)``."""
        x = inputs
        # Iterate over the layers built in __init__ rather than consulting a
        # module-level ``params`` dict, so the forward pass always matches
        # the configuration this instance was constructed with.
        for conv in self.conv:
            x = conv(x)
            x = self.relu(x)
        x = self.flatten(x)
        mu = self.mu(x)
        log_var = self.log_var(x)
        z = self.sampling((mu, log_var))
        return mu, log_var, z

class Decoder(Layer):
    """Converts z, the encoded digit vector, back into a readable digit.

    ``z`` is projected by a ReLU ``Dense`` layer to ``14 * 14 * 32`` units,
    reshaped to ``(14, 14, 32)`` and passed through a stack of
    ``Conv2DTranspose`` layers. Every transposed convolution is followed by
    ReLU except the last one, which is followed by a sigmoid.

    Args:
        params: Mapping with equally long lists ``"conv_t_filters"``,
            ``"conv_t_kernels"`` and ``"conv_t_strides"`` (one entry per
            transposed-conv layer).
        decoder_input_shape: Accepted for interface compatibility; not used
            by this layer.
        name: Layer name.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, params, decoder_input_shape, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        # NOTE(review): the (14, 14, 32) feature-map shape is hard-coded;
        # it has to stay consistent with the strides configured in params.
        self.linear = Dense(np.prod((14, 14, 32)), activation="relu")
        self.reshape = Reshape(target_shape=(14, 14, 32))
        self.conv_t = []
        for n_layer in range(len(params["conv_t_filters"])):
            stride = params["conv_t_strides"][n_layer]
            self.conv_t.append(
                Conv2DTranspose(
                    params["conv_t_filters"][n_layer],
                    params["conv_t_kernels"][n_layer],
                    (stride, stride),
                    padding='same'
                )
            )
        self.relu = Activation("relu")
        self.sigmoid = Activation("sigmoid")

    def call(self, inputs):
        """Decode latent vectors ``inputs`` into sigmoid-activated images."""
        x = self.linear(inputs)
        x = self.reshape(x)
        # Walk the transposed-conv layers this instance actually owns. The
        # loop bound previously came from the encoder's "conv_filters" entry
        # of a module-level dict, which could skip layers (and the final
        # sigmoid) or index past the end of self.conv_t.
        last_index = len(self.conv_t) - 1
        for index, conv_t in enumerate(self.conv_t):
            x = conv_t(x)
            if index < last_index:
                x = self.relu(x)
            else:
                x = self.sigmoid(x)
        return x

class VariationalAutoEncoder(tf.keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training.

    Args:
        encoder_input_shape: Shape of one input sample (without the batch
            dimension); forwarded to the encoder and used to derive
            ``original_dim``.
        params: Hyperparameter mapping consumed by ``Encoder`` and
            ``Decoder``; must contain ``"z_dim"``.
        name: Model name.
        **kwargs: Forwarded to ``tf.keras.Model.__init__``.
    """

    def __init__(
        self,
        encoder_input_shape,
        params,
        name="autoencoder",
        **kwargs
    ):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        # Derive the flattened input size from the constructor argument
        # instead of reading a notebook-level global, so the class does not
        # depend on cell execution order.
        self.original_dim = np.prod(encoder_input_shape)
        self.encoder = Encoder(params, encoder_input_shape)
        # Trailing comma makes this a real 1-tuple shape.
        self.decoder = Decoder(params, (params["z_dim"],))

    def call(self, inputs):
        """Reconstruct ``inputs`` and register the KL term via ``add_loss``.

        Returns:
            The decoder output for the sampled latent vectors.
        """
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss (averaged over all elements).
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed

def plot_latent_images(name, n=20, digit_size=28, additional_msg=''):
    """Plots n x n digit images decoded from the latent space.

    Relies on module-level state: the trained model ``vae``, the ``sample``
    helper, and ``epochs`` / ``latent_dim`` / ``batch_size`` for the title.

    Args:
        name: Base file name; the figure is saved as ``out/<name>.png``.
        n: Number of grid points per axis.
        digit_size: Side length, in pixels, of each decoded digit.
        additional_msg: Text appended to the figure title.
    """
    import os  # local import: only needed to create the output directory

    norm = tfp.distributions.Normal(0, 1)
    # Both axes use the same standard-normal quantiles, so compute them once.
    grid = norm.quantile(np.linspace(0.05, 0.95, n))
    image_width = digit_size*n
    image_height = image_width
    image = np.zeros((image_height, image_width))

    for i, yi in enumerate(grid):
        for j, xi in enumerate(grid):
            # NOTE(review): sample() unpacks this as (mean, log_var), so xi
            # acts as the mean and yi as the log-variance of the drawn
            # latent vector rather than as a 2-D latent coordinate --
            # confirm this is intended.
            z = np.array([[xi], [yi]])
            z = sample(z)
            x_decoded = vae.decoder(z)
            digit = tf.reshape(x_decoded[0], (digit_size, digit_size))
            image[i * digit_size: (i + 1) * digit_size,
                    j * digit_size: (j + 1) * digit_size] = digit.numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(image, cmap='Greys_r')
    plt.title(f'{epochs} epochs, {latent_dim} Latent Spaces, {batch_size} Batch Size. '+additional_msg)
    # Make sure the target directory exists; savefig does not create it.
    os.makedirs('out', exist_ok=True)
    plt.savefig('out/'+name+'.png')

In [13]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

train_size, test_size = x_train.shape[0], x_test.shape[0]

# Cast to float32 and divide by 255.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Append a trailing channel axis.
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)

# Per-sample shape: (height, width, channel).
input_shape = x_train.shape[1:]

# Number of values in one flattened sample.
original_dim = np.prod(input_shape)


# Shuffled (buffer of 1024) and batched input pipelines.
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(1024).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices(x_test)
test_dataset = test_dataset.shuffle(1024).batch(batch_size)

In [14]:
# Display the dataset repr to check the batched element shape and dtype.
train_dataset

<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

In [15]:
# Fold the encoder- and decoder-specific settings into the shared params
# dict (decoder entries win on key clashes, as with sequential updates).
params.update({**encoder_params, **decoder_params})

vae = VariationalAutoEncoder(encoder_input_shape=input_shape, params=params)

# Training utilities: running mean of the total loss, per-batch
# reconstruction loss, and the optimizer.
loss_metric = tf.keras.metrics.Mean()
bc_loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam()


In [17]:
# Names reported on the progress bar as-is (not averaged over steps).
metrics = ["r_loss", "kl_loss", "loss"]

# NOTE: this overrides the `epochs` value set in the configuration cell.
epochs = 5
for epoch in range(epochs):

    tf.print(f"Epoch {epoch}/{epochs}")

    progbar = tf.keras.utils.Progbar(len(train_dataset), stateful_metrics=metrics)

    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = vae(x_batch_train)
            # Reconstruction loss between the input batch and its decoding.
            r_loss = bc_loss_fn(x_batch_train, reconstructed)
            # KL regularization term registered by the model via add_loss;
            # summed once and reused for both the objective and the logging.
            kl_loss = sum(vae.losses)
            total_loss = r_loss + kl_loss

        grads = tape.gradient(total_loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))

        # Accumulate the running mean of the total loss.
        loss_metric(total_loss)

        progress_values = [("r_loss", r_loss), ("kl_loss", kl_loss), ("loss", total_loss)]
        # Progbar.update takes the number of completed steps; `step` is
        # zero-based, so pass step + 1 to let the bar reach its target.
        progbar.update(step + 1, values=progress_values)

    tf.print("\n")

Epoch 0/5

Epoch 1/5

Epoch 2/5

Epoch 3/5

Epoch 4/5



In [None]:
# Render the decoded latent grid; the base file name passed here is built
# from latent_dim, batch_size and epochs.
plot_latent_images(f'{latent_dim}-{batch_size}-{epochs}')