# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

def create_autoencoder(input_dim, encoding_dim, activation='tanh',
                       output_activation='tanh'):
    """Build a symmetric fully-connected autoencoder and its encoder half.

    The encoder stacks one ``Dense`` layer per entry of *encoding_dim*; its
    last layer is the bottleneck. The decoder mirrors the remaining hidden
    widths in reverse order and ends with a ``Dense(input_dim)``
    reconstruction layer.

    Args:
        input_dim: Length of the flat input vector.
        encoding_dim: Hidden layer widths from outermost to bottleneck,
            e.g. ``[256, 64]``. A single ``int`` is treated as
            ``[encoding_dim]``.
        activation: Activation used by every hidden ``Dense`` layer
            (default ``'tanh'``).
        output_activation: Activation of the reconstruction layer (default
            ``'tanh'``). For inputs scaled to ``[0, 1]``, ``'sigmoid'`` is
            usually a better fit because ``tanh`` can output negative values.

    Returns:
        A tuple ``(autoencoder, encoder)`` of ``keras.Model`` objects.
        Both are built from the same input and encoder tensors, so they share
        the encoder layers: training *autoencoder* also updates *encoder*.

    Raises:
        ValueError: If *encoding_dim* is empty (there would be no encoder).
    """
    if isinstance(encoding_dim, int):
        encoding_dim = [encoding_dim]
    encoding_dim = list(encoding_dim)
    if not encoding_dim:
        raise ValueError("encoding_dim must contain at least one layer width")

    encoder_inputs = layers.Input(shape=(input_dim,))

    # Encoder: one Dense layer per requested width, ending at the bottleneck.
    x = encoder_inputs
    for dim in encoding_dim:
        x = layers.Dense(dim, activation=activation)(x)
    encoder_outputs = x

    # Decoder: mirror the hidden widths in reverse, skipping the bottleneck
    # width itself (it is already the encoder's final layer).
    y = encoder_outputs
    for dim in reversed(encoding_dim[:-1]):
        y = layers.Dense(dim, activation=activation)(y)
    decoder_outputs = layers.Dense(input_dim, activation=output_activation)(y)

    autoencoder = models.Model(encoder_inputs, decoder_outputs)
    encoder = models.Model(encoder_inputs, encoder_outputs)
    return autoencoder, encoder

def _load_flat_mnist():
    """Load MNIST and return ``(x_train, x_test)`` as float32 rows in [0, 1]."""
    (x_train, _), (x_test, _) = mnist.load_data()
    # Scale uint8 pixels to [0, 1] and flatten each image into one row.
    x_train = x_train.astype('float32').reshape(len(x_train), -1) / 255.
    x_test = x_test.astype('float32').reshape(len(x_test), -1) / 255.
    return x_train, x_test


def _plot_reconstructions(originals, reconstructions, image_shape=(28, 28)):
    """Show *originals* in the top row above their *reconstructions* below."""
    num_images = len(originals)
    plt.figure(figsize=(2 * num_images, 4))
    for i, (original, reconstructed) in enumerate(zip(originals, reconstructions)):
        # Row 0 holds the original image, row 1 its reconstruction.
        for row, image in enumerate((original, reconstructed)):
            ax = plt.subplot(2, num_images, row * num_images + i + 1)
            ax.imshow(image.reshape(image_shape), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()


def test_mnist(epochs=10, batch_size=256, num_images=10, seed=42):
    """Train a dense autoencoder on MNIST and plot sample reconstructions.

    Loads MNIST, then trains a ``[256, 64]`` autoencoder with MSE loss and
    Adam (lr=0.001), using the test split as validation data. It then shows
    *num_images* distinct, randomly chosen test digits above their
    reconstructions.

    NOTE: the ``test_`` prefix means pytest will collect and run this
    function (downloading MNIST and training); rename it if that is unwanted.

    Args:
        epochs: Number of training epochs.
        batch_size: Mini-batch size used by ``fit``.
        num_images: How many test images to display; must be >= 1.
        seed: Seed for the local RNG that picks the displayed images.

    Raises:
        ValueError: If *num_images* is less than 1.
    """
    if num_images < 1:
        raise ValueError("num_images must be at least 1")

    x_train, x_test = _load_flat_mnist()

    # Derive the input width from the data instead of hard-coding 784.
    autoencoder, _ = create_autoencoder(x_train.shape[1], [256, 64])
    autoencoder.compile(optimizer=optimizers.Adam(0.001),
                        loss=losses.MeanSquaredError())
    autoencoder.fit(x_train, x_train, epochs=epochs, batch_size=batch_size,
                    shuffle=True, validation_data=(x_test, x_test))

    # A local Generator avoids mutating NumPy's global seed.
    # replace=False keeps the displayed digits distinct.
    rng = np.random.default_rng(seed)
    image_idx = rng.choice(x_test.shape[0], size=num_images, replace=False)
    samples = x_test[image_idx]

    # Reconstruct only the images being shown, not the whole test set.
    reconstructions = autoencoder.predict(samples)

    _plot_reconstructions(samples, reconstructions)

if __name__ == "__main__":
    # Run the MNIST demo only when executed as a script, not on import.
    test_mnist()