<a href="https://colab.research.google.com/github/wiso/TutorialML-AtlasItalia2022/blob/main/notebooks/3.0-AutoEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
import seaborn as sns

In [None]:
# Fashion-MNIST through the Keras dataset helper; load_data() hands back the
# (images, labels) pair of the training split followed by that of the test
# split.
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = (
    fashion_mnist.load_data()
)

# Preprocessing: divide every pixel by 255. The Keras documentation describes
# the images as uint8 grayscale values in 0..255, so this should map them to
# floats in [0, 1] -- the same range the decoder's sigmoid output covers.
train_images, test_images = train_images / 255., test_images / 255.

In [None]:
# Wrap the training arrays in a tf.data.Dataset that yields one
# (image, label) pair per element. NOTE: the fit() call further down is fed
# the numpy arrays directly, not this dataset.
train_ds = tf.data.Dataset.from_tensor_slices(
    tensors=(train_images, train_labels),
)

In [None]:
class Autoencoder(tf.keras.models.Model):
    """Dense autoencoder with a single latent layer.

    The encoder flattens each sample and projects it onto ``latent_dim``
    ReLU units. The decoder maps that code onto ``prod(image_shape)``
    sigmoid units and reshapes them back to ``image_shape``.

    Args:
        latent_dim: Number of units in the latent (bottleneck) layer.
        image_shape: Shape of one sample without the batch axis. Defaults
            to ``(28, 28)``, which reproduces the previously hard-coded
            784-unit output layer.

    Attributes:
        latent_dim: The value passed to the constructor.
        encoder: ``tf.keras.Sequential`` mapping samples to latent codes.
        decoder: ``tf.keras.Sequential`` mapping latent codes to samples.
    """

    def __init__(self, latent_dim, image_shape=(28, 28)):
        super().__init__()
        self.latent_dim = latent_dim

        # Derive the output layer width from the sample shape instead of
        # hard-coding 784, so decoder width and Reshape target always agree.
        image_shape = tuple(image_shape)
        n_outputs = int(np.prod(image_shape))

        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(latent_dim, activation='relu'),
        ])
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Dense(n_outputs, activation='sigmoid'),
            tf.keras.layers.Reshape(image_shape),
        ])

    def call(self, x):
        """Encode ``x`` and return its reconstruction from the latent code."""
        encoded = self.encoder(x)
        return self.decoder(encoded)

# Build the model with a 64-unit latent layer and train it with Adam on the
# mean squared error between the input and its reconstruction.
latent_dim = 64
autoencoder = Autoencoder(latent_dim=latent_dim)
autoencoder.compile(
    loss=tf.keras.losses.MeanSquaredError(),
    optimizer='adam',
)

In [None]:
# The images are passed as both input and target, so the network is trained
# to reproduce its own input; the test images serve as validation data in
# the same input == target fashion.
history = autoencoder.fit(
    x=train_images,
    y=train_images,
    batch_size=512,
    epochs=10,
    shuffle=True,
    validation_data=(test_images, test_images),
)

In [None]:
# Run the two halves of the trained model separately on the test images so
# that both the latent codes and the reconstructions are available as numpy
# arrays for the plots below.
encoded_imgs = autoencoder.encoder(
    test_images
).numpy()
decoded_imgs = autoencoder.decoder(
    encoded_imgs
).numpy()

In [None]:
# Heatmap of the latent codes of the first 20 test images: one row per image,
# one column per latent unit.
fig, ax = plt.subplots(figsize=(15, 4))
sns.heatmap(data=encoded_imgs[:20], square=True, ax=ax)
plt.show()

In [None]:
# Distribution of all latent activations, pooled over every test image and
# every latent unit, in 100 bins.
fig, ax = plt.subplots()
ax.hist(np.ravel(encoded_imgs), bins=100)
plt.show()

In [None]:
# Per-unit mean and unit-by-unit covariance of the latent codes. Rows of
# encoded_imgs are observations (images) and columns are variables (latent
# units), hence rowvar=False.
means = encoded_imgs.mean(axis=0)
cov = np.cov(encoded_imgs, rowvar=False)

In [None]:
# Compare the first n test images (top row) with their reconstructions
# (bottom row), column by column.
n = 10
fig = plt.figure(figsize=(20, 4))
rows = (("original", test_images), ("reconstructed", decoded_imgs))
for col in range(n):
    for row, (title, images) in enumerate(rows):
        # add_subplot positions are 1-based and counted row by row.
        ax = fig.add_subplot(2, n, row * n + col + 1)
        ax.imshow(images[col], cmap='binary')
        ax.set_title(title)

# Strip the ticks and keep the pixels square on every panel.
for panel in fig.get_axes():
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_aspect('equal')

plt.show()

In [None]:
# Generate new images: draw latent codes from a multivariate Gaussian with
# the empirical mean and covariance of the test-set codes, then decode them.
# NOTE: the encoder ends in a ReLU, so real codes are non-negative, whereas
# Gaussian samples can be negative; the decoder is therefore also evaluated
# on codes it never saw during training.
n_samples = 10
# squeeze=False keeps axs a 2-D array even when n_samples == 1.
fig, axs = plt.subplots(1, n_samples, figsize=(15, 3), squeeze=False)

# Draw all codes at once and decode them in a single batch instead of calling
# the decoder once per image.
noise = np.random.multivariate_normal(means, cov, size=n_samples)
generated_imgs = autoencoder.decoder(noise).numpy()

for ax, decoded_img in zip(axs.flat, generated_imgs):
    ax.imshow(decoded_img, cmap='binary')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

plt.show()

In [None]:
# Blend two reconstructions pixel by pixel. The weight runs from 0 to 1, so
# the first panel shows decoded_imgs[1] and the last panel decoded_imgs[0].
# NOTE: this mixes the decoded images themselves, not their latent codes.
nsteps = 10
# Tie the number of weights to nsteps so that changing nsteps changes the
# weights and the panels together.
weights = np.linspace(0, 1, nsteps)

# A trailing axis on each image lets the weight vector broadcast against it:
# blended[:, :, k] is the mix for weights[k].
img_a = decoded_imgs[0][..., np.newaxis]
img_b = decoded_imgs[1][..., np.newaxis]
blended = weights * img_a + (1 - weights) * img_b

# squeeze=False keeps axs a 2-D array even when nsteps == 1.
fig, axs = plt.subplots(1, nsteps, figsize=(15, 3), squeeze=False)
for step, ax in enumerate(axs.flat):
    ax.imshow(blended[:, :, step], cmap='binary')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

plt.show()