# Example 3: Autoencoder

Approx. execution time (T4 GPU): 2 minutes.

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from time import time

In [None]:
# Load and prepare the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Define parameters
input_dim = 28 * 28  # 784: each 28x28 image is flattened into one vector
latent_vec_dim = 2   # size of the bottleneck (2 so the codes can be plotted directly)

# Flatten each image into an input_dim-long vector and scale pixels from
# [0, 255] to [0, 1]. Using -1 lets NumPy infer the number of samples instead
# of hard-coding the split sizes (60000 train / 10000 test).
train_images = train_images.reshape((-1, input_dim)).astype('float32') / 255
test_images = test_images.reshape((-1, input_dim)).astype('float32') / 255

In [None]:
# Define the autoencoder architecture: a symmetric stack of fully connected
# layers that compresses each input_dim-long vector down to latent_vec_dim
# values and then expands it back to input_dim values.
hidden_sizes = (500, 300, 100)

input_layer = Input(shape=(input_dim,))

# Encoder: sigmoid hidden layers narrowing down, then a tanh bottleneck
# (so every latent coordinate lies in (-1, 1)).
hidden = input_layer
for units in hidden_sizes:
    hidden = Dense(units, activation='sigmoid')(hidden)
encoder = Dense(latent_vec_dim, activation='tanh')(hidden)

# Decoder: mirror the encoder's hidden sizes in reverse, ending in a sigmoid
# layer so each reconstructed pixel lies in (0, 1), as binary cross-entropy expects.
hidden = encoder
for units in reversed(hidden_sizes):
    hidden = Dense(units, activation='sigmoid')(hidden)
decoder = Dense(input_dim, activation='sigmoid')(hidden)

# Full model: image -> reconstruction.
autoencoder = Model(input_layer, decoder, name="Deep_Autoencoder")

# Bottleneck model: image -> latent code (shares weights with the autoencoder).
latent_model = Model(input_layer, encoder)

autoencoder.summary()

# Pixels are in [0, 1], so they can be treated as per-pixel probabilities.
autoencoder.compile(loss='binary_crossentropy', optimizer='adam')

In [None]:
# Train the autoencoder to reproduce its own input (the inputs double as the
# targets); the test split is used only to report validation loss per epoch.
t0 = time()
history = autoencoder.fit(train_images, train_images, epochs=25, batch_size=128,
                          shuffle=True, validation_data=(test_images, test_images))
t1 = time()
# Fixed-point formatting: "%.2g" switches to scientific notation for any value
# >= 100 (e.g. "1.2e+02 sec"), which is unreadable for a ~2-minute run.
print(f"Autoencoder training time: {t1 - t0:.2f} sec")

In [None]:
# Plot the per-epoch loss curves recorded by fit(): history key -> legend label.
# Dict order fixes the draw order (Train first, then Validation).
loss_curves = {'loss': 'Train', 'val_loss': 'Validation'}

plt.figure(figsize=(10, 6))
for metric_key, curve_label in loss_curves.items():
    plt.plot(history.history[metric_key], label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Binary Cross Entropy Loss')
plt.title('Autoencoder Reconstruction Loss', pad=13)
plt.legend(loc='upper right')
plt.show()

In [None]:
# Reconstruct the test images with the trained autoencoder.
reconstructed_images = autoencoder.predict(test_images)

# Show the first n test digits (top row) above their reconstructions (bottom row).
n = 5
image_rows = ((test_images, "Original"), (reconstructed_images, "Reconstructed"))

# squeeze=False keeps `axes` 2-D, so axes[row, col] indexing works for any n.
fig, axes = plt.subplots(2, n, figsize=(20, 4), squeeze=False)
for col in range(n):
    for row, (images, caption) in enumerate(image_rows):
        cell = axes[row, col]
        cell.imshow(images[col].reshape(28, 28), cmap="gray")
        cell.set_title(caption)
        cell.axis('off')
plt.show()

# Encode the test images into the 2-D latent space and plot each point,
# colored by its true digit label.
latent_representation = latent_model.predict(test_images)
first_dim = latent_representation[:, 0]
second_dim = latent_representation[:, 1]

plt.figure(figsize=(12, 10))
# seaborn draws on the current axes of the new figure and returns that Axes.
scatter = sns.scatterplot(x=first_dim, y=second_dim, hue=test_labels,
                          palette='tab10', legend="full")

scatter.set_title("2D Latent Space Representation of MNIST Digits")
scatter.set_xlabel("First Latent Dimension")
scatter.set_ylabel("Second Latent Dimension")

# Rebuild the legend with a clearer title and anchor it outside the plot area
# so it does not hide any points.
scatter.legend(title="Digit", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()