# Simple autoencoder training and inference example

**Description**: in this notebook, we showcase the training process and inference capabilities of a simple autoencoder model.

## Imports, definitions and setup

In [None]:
!git clone https://github.com/peiva-git/deep_learning_project.git
%cd deep_learning_project
!pip install -e .

In [3]:
import dlproject as dlp
import matplotlib.pyplot as plt

import os.path

## Load the MNIST dataset

In [None]:
dataset_builder = dlp.data.MNISTDatasetBuilder()
dataset_builder.preprocess_dataset()
train_data, test_data = dataset_builder.train_data, dataset_builder.test_data
noisy_train_data, noisy_test_data = dataset_builder.noisy_train_data, dataset_builder.noisy_test_data
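The notebook does not show what `MNISTDatasetBuilder.preprocess_dataset()` does internally. As a hedged sketch, assume it follows a common recipe, such as the one in the Keras "Convolutional autoencoder for image denoising" example. That recipe scales pixels to [0, 1], adds a channel axis, and builds the noisy copies by adding Gaussian noise and clipping back to [0, 1]. The helper names `preprocess` and `add_gaussian_noise`, and the `noise_factor=0.4` default, are illustrative assumptions, not the `dlproject` API.

```python
from typing import Optional

import numpy as np


def preprocess(images: np.ndarray) -> np.ndarray:
    """Scale uint8 images to float32 in [0, 1] and add a trailing channel axis."""
    images = images.astype("float32") / 255.0
    return images.reshape((len(images), 28, 28, 1))


def add_gaussian_noise(images: np.ndarray, noise_factor: float = 0.4,
                       seed: Optional[int] = None) -> np.ndarray:
    """Add zero-mean Gaussian noise, then clip back to the valid [0, 1] range."""
    rng = np.random.default_rng(seed)
    noisy = images + noise_factor * rng.normal(loc=0.0, scale=1.0, size=images.shape)
    return np.clip(noisy, 0.0, 1.0).astype("float32")


# Random stand-in for five 28x28 uint8 MNIST images (hypothetical data, not the real dataset)
fake_mnist = np.random.default_rng(0).integers(0, 256, size=(5, 28, 28)).astype(np.uint8)
clean = preprocess(fake_mnist)
noisy = add_gaussian_noise(clean, noise_factor=0.4, seed=1)
print(clean.shape, noisy.shape)
```

Clipping keeps the noisy inputs in the same [0, 1] range as the clean targets. This matters if the model is trained with a loss that expects values in that range.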

In [None]:
n = 10
plt.figure(figsize=(20, 2))
for i in range(1, n + 1):
    ax = plt.subplot(1, n, i)
    plt.imshow(noisy_train_data[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

## Instantiate the model

In [5]:
autoencoder = dlp.models.SimpleAutoencoder(input_shape=(28, 28, 1))
autoencoder.model.compile(optimizer='adam', loss='binary_crossentropy')
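The model is compiled with `binary_crossentropy`, which treats each pixel intensity in [0, 1] as a target probability. The NumPy sketch below spells out the per-pixel formula. It is our own illustration, not the Keras implementation: the `eps` clipping mirrors Keras' default `epsilon()` of 1e-7, and we assume pixels were scaled to [0, 1] during preprocessing.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean per-pixel binary cross-entropy for targets and predictions in [0, 1]."""
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip predictions away from 0 and 1 so the logarithms stay finite
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    per_pixel = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(per_pixel.mean())


target = np.array([[0.0, 1.0, 0.5]])
print(binary_crossentropy(target, target))                       # close to the lowest achievable value
print(binary_crossentropy(target, np.array([[0.9, 0.1, 0.5]])))  # much larger
print(binary_crossentropy([0.5], [0.5]))                         # ≈ 0.693 (= ln 2)
```

For grey target pixels the loss never reaches zero; for a 0.5 target its floor is ln 2. It is still minimized when the prediction equals the target, which is why binary cross-entropy remains usable as a reconstruction loss for greyscale images.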

## Train the model

### Reconstruct the input images

First, we train the model to reconstruct its own input: the clean images serve both as input (`x`) and as target (`y`). The reconstructed images should be similar to the originals, but not identical.
We also save the trained model for later use.

In [None]:
autoencoder.model.fit(
    x=train_data,
    y=train_data,
    epochs=50,
    batch_size=128,
    shuffle=True,
    validation_data=(test_data, test_data)
)

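# Save the trained model, creating the output directory first if it does not exist yet
model_dir = os.path.join('output', 'models')
os.makedirs(model_dir, exist_ok=True)
autoencoder.model.save(os.path.join(model_dir, autoencoder.model.name + '.keras'))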
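The `SimpleAutoencoder` architecture is not shown in this notebook, so here is a deliberately simplified stand-in that makes the reconstruction objective concrete. It is a toy *linear* autoencoder in NumPy, trained by full-batch gradient descent on mean squared error rather than binary cross-entropy. The function `train_linear_autoencoder` and all hyperparameters are hypothetical. The same data is both input and target, and the code is narrower than the data's rank, so the reconstruction error drops during training but cannot reach zero.

```python
import numpy as np


def train_linear_autoencoder(X, code_dim, lr=0.1, steps=3000, seed=0):
    """Fit X ≈ (X @ W_enc) @ W_dec by full-batch gradient descent on the mean squared error."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W_enc = 0.1 * rng.normal(size=(d, code_dim))  # encoder weights: d -> code_dim
    W_dec = 0.1 * rng.normal(size=(code_dim, d))  # decoder weights: code_dim -> d
    losses = []
    for _ in range(steps):
        Z = X @ W_enc                  # latent codes
        residual = Z @ W_dec - X       # reconstruction minus target (the input itself)
        losses.append(float(np.mean(residual ** 2)))
        grad_out = 2.0 * residual / residual.size  # dLoss/dReconstruction
        grad_dec = Z.T @ grad_out                  # dLoss/dW_dec
        grad_enc = X.T @ (grad_out @ W_dec.T)      # dLoss/dW_enc (uses W_dec before the update)
        W_enc -= lr * grad_enc
        W_dec -= lr * grad_dec
    return W_enc, W_dec, losses


rng = np.random.default_rng(42)
X = rng.normal(size=(200, 8)) @ rng.normal(size=(8, 16))  # 16-dimensional data of rank 8
X = (X - X.mean(axis=0)) / X.std(axis=0)                  # standardize each feature
W_enc, W_dec, losses = train_linear_autoencoder(X, code_dim=4)
print(f"initial MSE {losses[0]:.3f}, final MSE {losses[-1]:.3f}")
```

Squeezing 8 independent directions through a 4-dimensional code forces the model to keep only the most informative ones. This is the toy analogue of "similar, but not exactly the same".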

Display the original test images (top row) and their reconstructions (bottom row).

In [None]:
decoded_imgs = autoencoder.model.predict(test_data)

n = 10
plt.figure(figsize=(20, 4))
for i in range(1, n + 1):
    # Display original
    ax = plt.subplot(2, n, i)
    plt.imshow(test_data[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Display reconstruction
    ax = plt.subplot(2, n, i + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
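Visual inspection covers only a handful of digits. A per-image mean squared error gives a numeric view of reconstruction quality across the whole test set and points to the worst cases. The helpers `per_image_mse` and `worst_reconstructions` below are hypothetical and run on synthetic stand-in arrays shaped like the notebook's `(N, 28, 28, 1)` data.

```python
import numpy as np


def per_image_mse(originals, reconstructions):
    """Mean squared error of each image, averaged over all of its pixels and channels."""
    originals = np.asarray(originals, dtype=np.float64)
    reconstructions = np.asarray(reconstructions, dtype=np.float64)
    if originals.shape != reconstructions.shape:
        raise ValueError(f"shape mismatch: {originals.shape} vs {reconstructions.shape}")
    axes = tuple(range(1, originals.ndim))  # every axis except the batch axis
    return np.mean((originals - reconstructions) ** 2, axis=axes)


def worst_reconstructions(originals, reconstructions, k=10):
    """Indices of the k images with the highest reconstruction error, worst first."""
    errors = per_image_mse(originals, reconstructions)
    return np.argsort(errors)[::-1][:k]


# Synthetic stand-ins: two images are deliberately degraded
rng = np.random.default_rng(0)
originals = rng.random((6, 28, 28, 1))
reconstructions = originals.copy()
reconstructions[3] += 0.2   # large error
reconstructions[1] += 0.05  # small error
print(worst_reconstructions(originals, reconstructions, k=2))  # [3 1]
```

On the notebook's data this would be `per_image_mse(test_data, decoded_imgs)`, assuming both are NumPy arrays of the same shape.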

### Denoise images

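Next, we continue training the same model, now feeding it the noisy images as input while keeping the clean images as targets, so that it learns to remove the noise. Because `fit` is called again on the already trained model, training resumes from the current weights rather than starting from scratch.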

In [None]:
autoencoder.model.fit(
    x=noisy_train_data,
    y=train_data,
    epochs=100,
    batch_size=128,
    shuffle=True,
    validation_data=(noisy_test_data, test_data)
)

Let's take a look at the results on ten randomly chosen test digits. The top row shows the clean ground-truth digits, the middle row the noisy versions fed to the network, and the bottom row the network's reconstructions. Visually, the denoising seems to work well.

In [None]:
import random

def display_random_images(test_data, noisy_test_data, autoencoder_model, num_images=10):
    # Randomly select `num_images` indices from the test dataset
    random_indices = random.sample(range(len(test_data)), num_images)

    # Denoise all selected images in one batched call instead of one predict() per image
    predicted_outputs = autoencoder_model.predict(
        noisy_test_data[random_indices].reshape(-1, 28, 28, 1), verbose=0
    )

    plt.figure(figsize=(15, 4))

    for i, idx in enumerate(random_indices):
        # Top row: original clean image
        plt.subplot(3, num_images, i + 1)
        plt.imshow(test_data[idx].reshape(28, 28), cmap='gray')
        plt.axis('off')

        # Middle row: noisy input
        plt.subplot(3, num_images, num_images + i + 1)
        plt.imshow(noisy_test_data[idx].reshape(28, 28), cmap='gray')
        plt.axis('off')

        # Bottom row: denoised output from the autoencoder
        plt.subplot(3, num_images, 2 * num_images + i + 1)
        plt.imshow(predicted_outputs[i].reshape(28, 28), cmap='gray')
        plt.axis('off')

    plt.show()

display_random_images(test_data, noisy_test_data, autoencoder.model)
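"Seems to work well" can be made measurable with the peak signal-to-noise ratio (PSNR), a standard image-quality metric: PSNR = 10·log10(MAX² / MSE), with MAX = 1 for images scaled to [0, 1]. The `psnr` helper below is our own sketch. The `denoised` array is a fake output that simply halves the noise; it is not a model prediction. It is there only to show how a denoiser's gain would be read off.

```python
import numpy as np


def psnr(clean, other, max_value=1.0):
    """Per-image peak signal-to-noise ratio in dB; identical images give +inf."""
    clean = np.asarray(clean, dtype=np.float64)
    other = np.asarray(other, dtype=np.float64)
    axes = tuple(range(1, clean.ndim))  # average over every axis except the batch axis
    mse = np.mean((clean - other) ** 2, axis=axes)
    with np.errstate(divide="ignore"):  # zero MSE yields inf instead of a warning
        return 10.0 * np.log10(max_value ** 2 / mse)


# Synthetic stand-ins: 'denoised' halves the noise, so its MSE is a quarter of the noisy MSE
rng = np.random.default_rng(0)
clean = rng.random((8, 28, 28, 1))
noisy = np.clip(clean + 0.4 * rng.normal(size=clean.shape), 0.0, 1.0)
denoised = 0.5 * noisy + 0.5 * clean
print(f"noisy input: {psnr(clean, noisy).mean():.2f} dB")
print(f"fake output: {psnr(clean, denoised).mean():.2f} dB")  # about 6.02 dB higher
```

On the notebook's data, one could compare `psnr(test_data, noisy_test_data).mean()` with `psnr(test_data, autoencoder.model.predict(noisy_test_data)).mean()`, assuming both arrays are NumPy arrays in [0, 1] with matching shapes. A positive difference means the autoencoder output is closer to the clean digits than its noisy input was.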
