In [None]:
# *************************** RECURRENT AUTOENCODER *********************************** #

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras as K

In [2]:
# load data set
(X_train_full, y_train_full), (X_test, y_test) = K.datasets.mnist.load_data()
X_train_full = X_train_full.astype(np.float32) / 255
X_test = X_test.astype(np.float32) / 255
X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]

In [23]:
# define class
class Recurrent:
    def __init__(self):
        
        self.encoder = K.models.Sequential([
            K.layers.LSTM(units=100, return_sequences=True, input_shape=[28, 28]),
            K.layers.LSTM(units=30)
        ])
        
        self.decoder = K.models.Sequential([
            K.layers.RepeatVector(28, input_shape=[30]),
            K.layers.LSTM(units=100, return_sequences=True),
            
            K.layers.TimeDistributed(K.layers.Dense(28, activation='sigmoid'))
            
        ])
        
        self.model = K.models.Sequential([self.encoder, self.decoder])
        
    def train(self, X_train, X_test, epochs=10, batch_size=32):
        
        self.model.compile(loss='binary_crossentropy', 
                           optimizer='adam', metrics=['accuracy'])
        
        self.model.fit(X_train, X_train, 
                       epochs=epochs, 
                       batch_size=batch_size,
                       validation_data=(X_test, X_test))
        
    # Plot Accuracy & Loss curves
    def plot(self):
        fig = plt.figure(figsize=(15, 6))
        
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.set_title("Accuracy")
        ax1.set_xlabel("Number of epochs")
        ax1.plot(self.h.history['accuracy'], color='blue')
        
        ax2 = fig.add_subplot(1, 2, 2)
        ax2.set_title("Loss")
        ax2.set_xlabel("Number of epochs")
        ax2.plot(self.h.history['loss'], color='red')
        
        plt.grid(True)
        plt.show()
        
    def show_reconstructions(self, images=X_valid, n_images=5):
    
        reconstructions = self.model.predict(images[:n_images])
        fig = plt.figure(figsize=(n_images * 1.5, 3))
        for image_index in range(n_images):
            plt.subplot(2, n_images, 1 + image_index)
            plot_image(images[image_index])
            plt.subplot(2, n_images, 1 + n_images + image_index)
            plot_image(reconstructions[image_index])   
        
        

In [24]:
model = Recurrent()

In [25]:
model.train(X_train, X_test, epochs=5, batch_size=512)

Train on 55000 samples, validate on 10000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
 1536/55000 [..............................] - ETA: 1:10 - loss: 0.2906 - accuracy: 0.8080

KeyboardInterrupt: 