<a href="https://colab.research.google.com/github/wiso/TutorialML-AtlasItalia2022/blob/main/notebooks/3.2-VariationalAutoEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from matplotlib import pyplot as plt
import seaborn as sns
from scipy import stats

## Download the dataset and preprocess

In [None]:
# Fashion-MNIST as shipped with Keras: (images, labels) pairs for the train and test splits.
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# number of distinct labels in the dataset
nclasses = 10

# preprocessing: rescale the pixel values by 1/255 (presumably from 0-255 to [0, 1])
train_images, test_images = [images / 255. for images in (train_images, test_images)]

## Define the model
The model is quite similar to the previous one. The difference is that the latent space is represented by random variables. We choose to distribute these random variables as independent normal distributions.

In the loss we add a regularization term to impose that the latent space is distributed as the prior (independent normal distributions with mean 0 and width 1).

In [None]:
class Autoencoder(tf.keras.models.Model):
    """Variational autoencoder with a Gaussian latent space.

    The encoder is a small convolutional network ending in an
    independent-normal distribution layer over ``latent_dim`` variables. A
    KL-divergence activity regularizer on that layer pulls the encoded
    distribution towards ``self.prior``. The decoder maps a latent vector
    back to an image of shape ``(28, 28)``.

    Args:
        latent_dim: number of latent variables.
        kl_weight: weight given to the KL-divergence regularization term.
    """

    def __init__(self, latent_dim, kl_weight=0.01):
        super().__init__()
        self.latent_dim = latent_dim

        # prior: one independent normal (mean 0, width 1) per latent variable
        tfd = tfp.distributions
        self.prior = tfd.Independent(
            tfd.Normal(loc=tf.zeros(latent_dim), scale=1),
            reinterpreted_batch_ndims=1,
        )

        # number of parameters describing the latent space
        # (e.g. 2 * latent space dim, since we have the mean and the width)
        n_latent_params = tfp.layers.IndependentNormal.params_size(latent_dim)
        # regularization: KL divergence between the encoded distribution and the prior
        kl_regularizer = tfp.layers.KLDivergenceRegularizer(self.prior, weight=kl_weight)

        encoder_layers = [
            tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.Flatten(),
            # this layer encodes the means and the widths of the distributions
            tf.keras.layers.Dense(n_latent_params, activation=None),
            tfp.layers.IndependentNormal(latent_dim, activity_regularizer=kl_regularizer),
        ]
        self.encoder = tf.keras.Sequential(encoder_layers)

        def upconv(filters, stride, activation):
            # 3x3 transposed convolution with SAME padding and a square stride
            return tf.keras.layers.Conv2DTranspose(
                filters,
                kernel_size=(3, 3),
                strides=(stride, stride),
                padding='SAME',
                activation=activation,
            )

        decoder_layers = [
            tf.keras.layers.InputLayer(input_shape=[latent_dim]),
            tf.keras.layers.Dense(7 * 7 * 64, activation='relu'),
            tf.keras.layers.Reshape((7, 7, 64)),
            # two stride-2 steps take the 7x7 map up to the 28x28 needed by the final reshape
            upconv(64, 2, 'relu'),
            upconv(32, 2, 'relu'),
            upconv(1, 1, 'sigmoid'),
            # drop the single channel axis
            tf.keras.layers.Reshape((28, 28)),
        ]
        self.decoder = tf.keras.Sequential(decoder_layers)

    def call(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        return self.decoder(self.encoder(x))


latent_dim = 16
kl_weight = 0.001  # with higher latent dim use smaller weight, e.g. 2 -> 0.01, 16 -> 0.0005
autoencoder = Autoencoder(latent_dim, kl_weight)
# The optimized loss is the Huber loss plus the regularization terms, so the
# plain Huber loss is also tracked as a metric to follow the reconstruction
# quality on its own. `metrics` is documented as a list, hence the explicit list.
autoencoder.compile(
    optimizer='adam',
    loss=tf.keras.losses.Huber(),
    metrics=[tf.keras.losses.Huber()],
)

## Train

In [None]:
# The target is the input itself (reconstruction); the test images are used
# as validation data.
history = autoencoder.fit(
    x=train_images,
    y=train_images,
    batch_size=512,
    epochs=10,
    validation_data=(test_images, test_images),
)

## Evaluate the model
For each image in the test dataset, compute the encoded version (a sampling from the learned latent space distribution) and the decoded version

In [None]:
# encoder output for every test image
encoded_imgs = autoencoder.encoder(test_images)
# decoder output for the encoded images, converted to a numpy array
decoded_tensor = autoencoder.decoder(encoded_imgs)
decoded_imgs = decoded_tensor.numpy()

Visualize the encoded version of the first 20 images

In [None]:
fig, ax = plt.subplots(figsize=(15, 4))
# transpose so that each column is one of the first 20 images and each row
# one latent variable
first_encoded = encoded_imgs[:20, :].numpy().T
sns.heatmap(first_encoded, ax=ax, square=True)
ax.set_xlabel('image index')
ax.set_ylabel('latent space')
plt.show()

Check their distributions

In [None]:
from scipy import stats

fig, ax = plt.subplots()

# one fresh sample per test image from the encoded distributions;
# only the first 10 latent variables are histogrammed
latent_samples = encoded_imgs.sample().numpy()
ax.hist(
    latent_samples[:, :10],
    bins=50,
    density=True,
    stacked=False,
    histtype='step',
    linewidth=1.5,
)

# reference curve: density of the unit normal
xspace = np.linspace(-5, 5, 100)
y = stats.norm.pdf(xspace, loc=0, scale=1)
ax.fill_between(xspace, y, color='0.7', zorder=-1, label='N[0,1]')
ax.legend()
plt.show()

In [None]:
n = 10
fig = plt.figure(figsize=(20, 4))

# top row: the test images; bottom row: their decoded versions
rows = (('original', test_images), ('reconstructed', decoded_imgs))
for i in range(n):
    for irow, (row_label, images) in enumerate(rows):
        ax = fig.add_subplot(2, n, irow * n + i + 1)
        ax.imshow(images[i], cmap='binary')
        # label each row once, on its leftmost panel
        if i == 0:
            ax.set_ylabel(row_label)

# strip the ticks and keep the pixels square on every panel
for ax in fig.get_axes():
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

plt.show()

In [None]:
# draw 10 latent vectors from the prior and decode them into images
noise = autoencoder.prior.sample(10)
decoded_imgs = autoencoder.decoder(noise)

fig, axs = plt.subplots(1, 10, figsize=(20, 2))
for ax, decoded_img in zip(axs.flat, decoded_imgs):
    ax.imshow(decoded_img, cmap='binary')
    # no ticks, square pixels
    ax.set(xticks=[], yticks=[], aspect='equal')

In [None]:
n = 20

fig, axs = plt.subplots(n, ncols=n, figsize=(15, 15))
# one random latent vector; its first two components are scanned below while
# the remaining ones stay fixed at their random values
noise = np.random.multivariate_normal(np.zeros(latent_dim), np.eye(latent_dim))

# quantiles of the unit normal at n equally spaced probabilities, computed
# once instead of inside the double loop
quantiles = stats.norm(0, 1).ppf(np.linspace(0.05, 0.95, n))
# 'ij' indexing: component 0 varies along the rows, component 1 along the columns
scan0, scan1 = np.meshgrid(quantiles, quantiles, indexing='ij')

# one latent vector per panel, in row-major order (same order as axs.flat)
latent_grid = np.tile(noise, (n * n, 1))
latent_grid[:, 0] = scan0.ravel()
latent_grid[:, 1] = scan1.ravel()

# decode the whole grid with a single batched call
decoded_grid = autoencoder.decoder(latent_grid).numpy()

for ax, decoded_img in zip(axs.flat, decoded_grid):
    ax.imshow(decoded_img, cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

In [None]:
def interpolate_and_show(test_image1, test_image2, decoder, encoder, n_steps=10):
    """Plot decoded images along a straight line between two encoded images.

    Both images are encoded together, the means of their encoded
    distributions are linearly interpolated in ``n_steps`` steps (both
    endpoints included) and every intermediate latent vector is decoded.
    The figure shows, in one row, the first image, the decoded steps and
    the second image.

    Args:
        test_image1: first image, drawn in the leftmost panel.
        test_image2: second image, drawn in the rightmost panel.
        decoder: callable mapping a batch of latent vectors to a batch of
            images; its result must provide ``.numpy()``.
        encoder: callable mapping a batch of images to an object whose
            ``.mean().numpy()`` yields one latent vector per image.
        n_steps: number of interpolation steps to decode and draw.
    """
    encoded1, encoded2 = encoder(np.stack([test_image1, test_image2])).mean().numpy()
    # column vector of weights so that it broadcasts against the latent vectors
    t = np.expand_dims(np.linspace(0, 1, n_steps), axis=-1)
    encoded_steps = encoded1 * (1 - t) + encoded2 * t
    # decode all the steps with a single batched call
    decoded_steps = decoder(encoded_steps).numpy()

    # two extra panels for the original images at both ends
    fig, axs = plt.subplots(1, n_steps + 2, figsize=(15, 5))

    axs[0].imshow(test_image1, cmap='binary')
    axs[-1].imshow(test_image2, cmap='binary')
    axs[0].set_title('first')
    axs[-1].set_title('second')

    for img, ax in zip(decoded_steps, axs[1:-1]):
        ax.imshow(img, cmap='binary')

    for ax in axs:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('equal')

# for every class, interpolate between its first two test images
for ilabel in range(nclasses):
    same_class_images = test_images[test_labels == ilabel]
    test_image1 = same_class_images[0]
    test_image2 = same_class_images[1]

    interpolate_and_show(test_image1, test_image2, autoencoder.decoder, autoencoder.encoder)