<a href="https://colab.research.google.com/github/wiso/TutorialML-AtlasItalia2022/blob/main/notebooks/3.1-AutoEncoder_denoise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Autoencoder denoise
Let's use an autoencoder to remove noise from the input images.

In [None]:
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
import seaborn as sns

## Download and preprocess input images

In [None]:
# Load the Fashion-MNIST train/test splits (images and labels).
fashion_mnist = tf.keras.datasets.fashion_mnist
train_split, test_split = fashion_mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

# Rescale pixel intensities from [0, 255] to [0, 1] (true division -> float arrays).
PIXEL_MAX = 255.
train_images, test_images = [split / PIXEL_MAX for split in (train_images, test_images)]

## Artificially introduce noise
Add Gaussian (normal) noise to each pixel, then clip the result back to the valid [0, 1] range.

In [None]:
SEED = 42
noise_factor = 0.2  # standard deviation of the added noise (pixels are scaled to [0, 1])
# A seeded generator makes the noisy dataset reproducible across Restart & Run All.
rng = np.random.default_rng(SEED)


def add_gaussian_noise(images, noise_factor, rng):
    """Return a copy of ``images`` with Gaussian noise added, clipped to [0, 1].

    Args:
        images: array of pixel intensities, expected to be in [0, 1].
        noise_factor: standard deviation of the zero-mean normal noise.
        rng: a ``numpy.random.Generator`` used to draw the noise.

    Returns:
        A new array of the same shape as ``images``.
    """
    noisy = images + noise_factor * rng.normal(size=images.shape)
    # Clip so noisy pixels stay in the same range as the clean targets.
    return np.clip(noisy, 0., 1.)


train_images_noisy = add_gaussian_noise(train_images, noise_factor, rng)
test_images_noisy = add_gaussian_noise(test_images, noise_factor, rng)

In [None]:
# Show a few training images: clean on the top row, noisy version on the bottom row.
n_examples = 10
fig, axs = plt.subplots(2, n_examples, figsize=(15, 3))
# Iterate over columns of the axes grid: each column is a (top, bottom) pair.
for img, img_noisy, (ax_top, ax_bottom) in zip(train_images, train_images_noisy, axs.T):
    ax_top.imshow(img, cmap='gray')
    ax_bottom.imshow(img_noisy, cmap='gray')
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
axs[0, 0].set_ylabel("clean")
axs[1, 0].set_ylabel("noisy")
plt.show()

## Define the autoencoder

In [None]:
class Autoencoder(tf.keras.models.Model):
    """Fully connected autoencoder for single-channel images.

    The encoder flattens each image and projects it to a ``latent_dim``-sized
    ReLU code. The decoder maps that code back to one value per pixel with a
    sigmoid (so outputs lie in (0, 1)) and reshapes the result to ``image_shape``.

    Args:
        latent_dim: size of the bottleneck (latent code).
        image_shape: shape of a single image, excluding the batch axis.
            Defaults to (28, 28), the Fashion-MNIST image size used here.
    """

    def __init__(self, latent_dim, image_shape=(28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        # Stored as ``image_shape`` (not ``input_shape``) to avoid clashing with
        # the Keras ``Layer.input_shape`` property.
        self.image_shape = tuple(image_shape)
        n_pixels = int(np.prod(self.image_shape))
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(latent_dim, activation='relu'),
        ])
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Dense(n_pixels, activation='sigmoid'),
            tf.keras.layers.Reshape(self.image_shape),
        ])

    def call(self, x):
        """Encode ``x`` to the latent space and decode it back to image shape."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

latent_dim = 64
autoencoder = Autoencoder(latent_dim)


## Train the autoencoder
Use the noisy images as the input and the original (clean) images as the target.

In [None]:
# Train with Adam to minimize the pixel-wise mean squared error between the
# reconstruction of a noisy image and its clean counterpart.
mse_loss = tf.keras.losses.MeanSquaredError()
autoencoder.compile(loss=mse_loss, optimizer='adam')

validation_pairs = (test_images_noisy, test_images)
autoencoder.fit(
    x=train_images_noisy,
    y=train_images,
    validation_data=validation_pairs,
    shuffle=True,
    epochs=10,
)

## Test

In [None]:
# Run the noisy test images through the encoder, then decode the latent codes.
latent_codes = autoencoder.encoder(test_images_noisy)
encoded_imgs = latent_codes.numpy()
decoded_imgs = autoencoder.decoder(latent_codes).numpy()

In [None]:
# Compare noisy inputs, denoised reconstructions and clean originals for the
# first ``n`` test images (one column per image, one row per kind).
n = 10
n_shown = min(n, len(test_images))
rows = (
    ("original + noise", test_images_noisy),
    ("reconstructed", decoded_imgs),
    ("original", test_images),
)
# squeeze=False keeps ``axs`` 2-D even when n_shown == 1.
fig, axs = plt.subplots(len(rows), n_shown, figsize=(20, 5), squeeze=False)
for row_axes, (title, images) in zip(axs, rows):
    for ax, img in zip(row_axes, images[:n_shown]):
        ax.set_title(title)
        ax.imshow(img, cmap='gray')

for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')

plt.show()