# In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt

def create_models(input_dim, encoding_dim):
    """Build a one-hidden-layer dense autoencoder and its two halves.

    Args:
        input_dim: Length of each flattened input vector.
        encoding_dim: Width of the bottleneck (code) layer.

    Returns:
        ``(encoder, decoder, autoencoder)``:
        - encoder: ``input_dim -> encoding_dim`` dense layer with ReLU.
        - decoder: ``encoding_dim -> input_dim`` dense layer with sigmoid.
        - autoencoder: a model on its own input that calls ``encoder`` and
          then ``decoder``, so it shares their layers and weights.
    """
    keras = tf.keras

    # Encoder half: input vector -> ReLU code.
    enc_in = keras.Input(shape=(input_dim,))
    code = layers.Dense(encoding_dim, activation='relu')(enc_in)
    encoder = keras.Model(enc_in, code)

    # Decoder half: code -> sigmoid reconstruction.
    dec_in = keras.Input(shape=(encoding_dim,))
    recon = layers.Dense(input_dim, activation='sigmoid')(dec_in)
    decoder = keras.Model(dec_in, recon)

    # End-to-end model; the sub-models are invoked as layers, so training
    # this model trains the encoder and decoder as well.
    ae_in = keras.Input(shape=(input_dim,))
    ae_out = decoder(encoder(ae_in))
    autoencoder = keras.Model(ae_in, ae_out)

    return encoder, decoder, autoencoder

def plot_images(original, reconstructed, image_shape=(28, 28)):
    """Show originals in the top row and their reconstructions below.

    Args:
        original: Sequence of flattened images; one column is drawn per item.
        reconstructed: Sequence of flattened reconstructions, indexed in step
            with ``original`` (needs at least ``len(original)`` items).
        image_shape: Shape each flattened image is reshaped to for display.
            Defaults to ``(28, 28)``, matching the MNIST digits used in this
            file.
    """
    n = len(original)
    if n == 0:
        # Nothing to draw; don't pop up an empty figure.
        return

    # squeeze=False keeps ``axes`` two-dimensional even when n == 1.
    _, axes = plt.subplots(2, n, figsize=(20, 4), squeeze=False)

    for i in range(n):
        for row, images in enumerate((original, reconstructed)):
            ax = axes[row, i]
            # Colormap is set per image rather than via plt.gray(), which
            # would change the global default colormap for later plots.
            ax.imshow(np.reshape(images[i], image_shape), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    plt.show()

def train_autoencoder(autoencoder, x_train, x_test, epochs=10, batch_size=128,
                      optimizer='adam', loss='binary_crossentropy'):
    """Compile and fit ``autoencoder`` to reconstruct its own inputs.

    Args:
        autoencoder: Keras model mapping inputs to reconstructions.
        x_train: Training inputs; also used as the targets.
        x_test: Validation inputs; also used as the validation targets.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.
        optimizer: Optimizer passed to ``compile`` (default ``'adam'``).
        loss: Loss passed to ``compile`` (default
            ``'binary_crossentropy'``).

    Returns:
        Whatever ``autoencoder.fit`` returns (a Keras ``History`` for a real
        model), so callers can inspect training/validation loss curves.
    """
    autoencoder.compile(optimizer=optimizer, loss=loss)

    # Reconstruction objective: each input is its own target.
    history = autoencoder.fit(x_train, x_train,
                              epochs=epochs,
                              batch_size=batch_size,
                              shuffle=True,
                              validation_data=(x_test, x_test))
    return history

def _flatten_and_scale(images):
    """Scale pixel values by 1/255 and flatten each image to a row vector.

    Assumes 8-bit pixel values (0-255), so the result lies in [0, 1].
    """
    scaled = images.astype('float32') / 255.
    return scaled.reshape((len(scaled), int(np.prod(scaled.shape[1:]))))


def run_example(encoding_dim=32, epochs=10, batch_size=128, n_display=10):
    """Train a dense autoencoder on MNIST and plot sample reconstructions.

    Args:
        encoding_dim: Width of the bottleneck layer.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for training.
        n_display: How many test digits to reconstruct and display.
    """
    # Labels are not needed for reconstruction.
    (x_train, _), (x_test, _) = mnist.load_data()
    x_train = _flatten_and_scale(x_train)
    x_test = _flatten_and_scale(x_test)

    input_dim = x_train.shape[1]
    encoder, decoder, autoencoder = create_models(input_dim, encoding_dim)

    train_autoencoder(autoencoder, x_train, x_test,
                      epochs=epochs, batch_size=batch_size)

    # Only the digits that will be displayed go through the network;
    # encoding and decoding the whole test set and then slicing wastes work.
    samples = x_test[:n_display]
    encoded_imgs = encoder.predict(samples)
    decoded_imgs = decoder.predict(encoded_imgs)

    plot_images(samples, decoded_imgs)

if __name__ == "__main__":
    # Script entry point: train on MNIST and show sample reconstructions.
    run_example()
