<a href="https://colab.research.google.com/github/wiso/TutorialML-AtlasItalia2022/blob/main/notebooks/3.0-AutoEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Autoencoder

A simple autoencoder trained on Fashion-MNIST using a dense neural network.

In [None]:
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
import seaborn as sns

In [None]:
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# preprocessing
train_images = train_images / 255.
test_images = test_images / 255.

## Define the model

The model is made of two connected parts. The encoder transforms the inputs (the pixel values) into the latent space. The decoder transforms the latent representation back into the output image.

The loss measures how different the output image is from the input image.
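
As a rough illustration of what such a loss measures, the sketch below computes the mean squared error between an image and a reconstruction with NumPy. The helper `mse` and the arrays `x` and `x_hat` are made-up stand-ins for illustration, not outputs of the model.

```python
import numpy as np


def mse(x, x_hat):
    """Mean squared error averaged over all pixels (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    return float(np.mean((x - x_hat) ** 2))


# toy 2x2 "images" with pixel values in [0, 1] (invented data)
x = np.array([[0.0, 1.0], [0.5, 0.5]])
x_hat = np.array([[0.0, 0.5], [0.5, 1.0]])

loss_value = mse(x, x_hat)
print(loss_value)
```

A perfect reconstruction gives a loss of zero; the larger the pixel-wise differences, the larger the loss.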

In [None]:
class Autoencoder(tf.keras.models.Model):
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(latent_dim, activation='relu'),
        ])
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Dense(28 * 28, activation='sigmoid'),
            tf.keras.layers.Reshape((28, 28))
        ])

    def call(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

latent_dim = 64
autoencoder = Autoencoder(latent_dim)
autoencoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())

In [None]:
autoencoder.fit?

In [None]:
history = autoencoder.fit(train_images, train_images,
                epochs=30,
                batch_size=512,
                shuffle=True,
                validation_data=(test_images, test_images))

## Try the autoencoder on the test images

In [None]:
encoded_imgs = autoencoder.encoder(test_images).numpy()
decoded_imgs = autoencoder.decoder(encoded_imgs).numpy()

Plot the latent space for the first 20 images. Some values are always zero, which means the model is not using the full latent space. We should use some regularization to avoid this.

In [None]:
fig, ax = plt.subplots(figsize=(15, 4))
sns.heatmap(encoded_imgs[:20, :], ax=ax, square=True)
ax.set_ylabel('image index')
ax.set_xlabel('latent space')
plt.show()
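
To put a number on the unused part of the latent space, one option is to count the latent units that are zero for every image. The sketch below does this on a small invented array; `dead_units` and `toy_encoded` are hypothetical names introduced only for this example.

```python
import numpy as np


def dead_units(encoded, tol=0.0):
    """Indices of latent units whose activation is <= tol for every image."""
    encoded = np.asarray(encoded)
    return np.flatnonzero(np.all(encoded <= tol, axis=0))


# invented ReLU-like activations: 4 images, 5 latent units
toy_encoded = np.array([
    [0.0, 1.2, 0.0, 0.3, 0.0],
    [0.0, 0.7, 0.0, 0.0, 0.0],
    [0.0, 2.1, 0.5, 0.9, 0.0],
    [0.0, 0.0, 0.0, 1.1, 0.0],
])

dead = dead_units(toy_encoded)
fraction_dead = len(dead) / toy_encoded.shape[1]
print(dead, fraction_dead)
```

In the notebook, the same function could be applied to `encoded_imgs` to see how many of the `latent_dim` units are never activated on the test images.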

The distribution of the values of the latent space, for all the test images

In [None]:
fig, ax = plt.subplots()
ax.hist(encoded_imgs.flat, bins=100)
plt.show()

Try to estimate the pdf of the latent space using the test images. Evaluate the means and covariance.

In [None]:
means = np.mean(encoded_imgs, axis=0)
cov = np.cov(encoded_imgs.T)
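
A small sanity check of this estimate, on invented data: the sketch below builds toy latent codes with one unit that is always zero and shows that the resulting covariance matrix has the expected shape but is rank deficient. The names `toy_latent`, `toy_means` and `toy_cov` are assumptions made for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# invented latent codes: 200 samples, 3 units, with unit 2 always zero
toy_latent = rng.random((200, 3))
toy_latent[:, 2] = 0.0

toy_means = np.mean(toy_latent, axis=0)      # shape (3,)
toy_cov = np.cov(toy_latent, rowvar=False)   # shape (3, 3), same as np.cov(toy_latent.T)

# the always-zero unit has zero variance, so the covariance is rank deficient
rank = np.linalg.matrix_rank(toy_cov)
print(toy_means.shape, toy_cov.shape, rank)
```

A Gaussian with such a covariance has no spread along the unused units, so samples drawn from it stay at zero in those directions.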

Compare the input and the output images

In [None]:
n = 10
fig = plt.figure(figsize=(20, 4))
for i in range(n):
    ax = fig.add_subplot(2, n, i + 1)
    ax.imshow(test_images[i], cmap='binary')
    ax.set_title("original")
    
    ax = fig.add_subplot(2, n, i + 1 + n)
    ax.imshow(decoded_imgs[i], cmap='binary')
    ax.set_title("reconstructed")
    
for ax in fig.get_axes():
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

plt.show()

## Generate new images
We can generate some noise, treat it as a point in the latent space, and apply the decoder to generate new images. The problem is that we don't know the distribution of the latent space. Let us assume it is a multivariate normal distribution, with the mean and the covariance we computed.

In [None]:
fig, axs = plt.subplots(1, 10, figsize=(15, 3))
for ax in axs.flat:
    noise = np.random.multivariate_normal(means, cov, size=(1,))
    decoded_img = autoencoder.decoder(noise).numpy()[0]
    ax.imshow(decoded_img, cmap='binary')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

The result is not very satisfactory, for several reasons. For example, we don't know the true distribution of the latent space.

## Interpolate between two images
Compute the latent representation of two input images and linearly interpolate between them. Then apply the decoder.

In [None]:
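nsteps = 10
# interpolation weights, as a column so they broadcast over the latent dimension
l = np.linspace(0, 1, nsteps)[:, np.newaxis]
# interpolate in the latent space (not in pixel space): shape (nsteps, latent_dim)
latent_interp = l * encoded_imgs[0] + (1 - l) * encoded_imgs[1]
# decode every interpolated latent vector back to an image
interp_imgs = autoencoder.decoder(latent_interp).numpy()
fig, axs = plt.subplots(1, nsteps, figsize=(15, 3))
for ax, img in zip(axs.flat, interp_imgs):
    ax.imshow(img, cmap='binary')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
plt.show()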