In [3]:
import tensorflow as tf
from tensorflow.keras import layers, losses
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model

In [4]:
# Load MNIST images only (labels are discarded) and rescale pixel values by 1/255.
(x_train, _), (x_test, _) = mnist.load_data()
x_train, x_test = [split.astype('float32') / 255. for split in (x_train, x_test)]
# The first 100 test images are the held-out set whose loss is reported during training.
x_validate = x_test[:100]

In [5]:
# Adam optimizer passed to Autoencoder.fit below. Note the unusually large
# epsilon and learning rate (both 0.1) and beta_1 of 0.99.
# NOTE(review): optimizers keep per-variable slot state, so re-running the
# model-construction cell without re-running this one may make a later fit
# fail or misbehave — confirm against the installed Keras version.
tf_optimizer = tf.keras.optimizers.Adam(
    epsilon=1e-1,
    beta_1=0.99,
    learning_rate=1e-1,
)

In [49]:
class Autoencoder(object):
    """Single-hidden-layer dense autoencoder trained on mean squared error.

    The encoder flattens each input and maps it to ``latent_dim`` ReLU units;
    the decoder maps the code back to one sigmoid unit per pixel and reshapes
    the result to ``image_shape``.

    NOTE(review): neither Sequential model is given an input shape, so the
    layers are presumably only built on their first call. ``get_weights`` /
    ``set_weights`` rely on ``count_params`` and are therefore expected to
    work only after the model has seen data (e.g. after ``fit``) — confirm
    against the installed Keras version.
    """

    def __init__(self, latent_dim=64, image_shape=(28, 28)):
        """Build the encoder/decoder pair.

        Args:
            latent_dim: number of units in the latent (code) layer.
            image_shape: shape of one input image. The decoder emits
                ``prod(image_shape)`` values and reshapes them back to this
                shape. The default (28, 28) gives 784 outputs.
        """
        n_pixels = 1
        for dim in image_shape:
            n_pixels *= dim
        self.encoder = tf.keras.Sequential([
            layers.Flatten(),
            layers.Dense(latent_dim, activation='relu'),
        ])
        self.decoder = tf.keras.Sequential([
            layers.Dense(n_pixels, activation='sigmoid'),
            layers.Reshape(tuple(image_shape)),
        ])
        # dtype of the flat parameter vector returned by get_weights.
        self.dtype = tf.float32

    def __loss(self, x_input):
        """Return the mean squared error between ``x_input`` and its reconstruction."""
        latent_represent = self.encoder(x_input)
        recon = self.decoder(latent_represent)
        return losses.MeanSquaredError()(x_input, recon)

    def __grad(self, x_input):
        """Return ``(loss, gradients)`` of the reconstruction loss.

        Gradients are taken w.r.t. ``__wrap_training_variables()`` and come
        back in the same order.
        """
        with tf.GradientTape() as tape:
            loss_value = self.__loss(x_input)
        return loss_value, tape.gradient(loss_value, self.__wrap_training_variables())

    def __flat_grad(self, grad):
        """Concatenate a sequence of gradient tensors into one flat 1-D tensor."""
        return tf.concat([tf.reshape(g, [-1]) for g in grad], 0)

    def __wrap_training_variables(self):
        """Return encoder trainable variables followed by decoder trainable variables."""
        return self.encoder.trainable_variables + self.decoder.trainable_variables

    def __parametrized_layers(self):
        """Yield every layer that owns parameters, encoder first, then decoder.

        Sub-models whose ``trainable`` flag is False are skipped. This ordering
        defines the flat layout shared by get_weights and set_weights.
        """
        for model in (self.encoder, self.decoder):
            if not model.trainable:
                continue
            for layer in model.layers:
                if layer.count_params() > 0:
                    yield layer

    def get_weights(self):
        """Return all parameters of the trainable sub-models as one flat 1-D tensor.

        For each parameterized layer (see ``__parametrized_layers``), the
        kernel (``get_weights()[0]``) is flattened row-major and appended,
        followed by the bias (``get_weights()[1]``). If no layer qualifies,
        an empty tensor is returned.

        Returns:
            1-D tensor of dtype ``self.dtype``.
        """
        pieces = []
        for layer in self.__parametrized_layers():
            weights_biases = layer.get_weights()
            pieces.append(tf.reshape(weights_biases[0], [-1]))
            pieces.append(tf.reshape(weights_biases[1], [-1]))
        if not pieces:
            # tf.concat rejects an empty list; mirror the empty result of
            # converting an empty Python list.
            return tf.zeros([0], dtype=self.dtype)
        # One vectorised concat instead of extending a Python list element by element.
        return tf.cast(tf.concat(pieces, 0), self.dtype)

    def set_weights(self, w):
        """Load parameters from a flat vector laid out as produced by ``get_weights``.

        Args:
            w: sliceable 1-D tensor/array. For each parameterized layer it
               holds the row-major kernel followed by the bias. Entries past
               the last layer are ignored.
        """
        start = 0
        for layer in self.__parametrized_layers():
            weights_biases = layer.get_weights()
            kernel_shape = weights_biases[0].shape
            # Plain Python ints keep slice bounds and `start` out of tensor land.
            n_kernel = int(weights_biases[0].size)
            n_bias = int(weights_biases[1].size)
            kernel = tf.reshape(w[start:start + n_kernel], kernel_shape)
            bias = w[start + n_kernel:start + n_kernel + n_bias]
            layer.set_weights([kernel, bias])
            start += n_kernel + n_bias

    def summary(self):
        """Print the Keras summaries of the encoder and the decoder."""
        self.encoder.summary()
        self.decoder.summary()

    def fit(self, x_train, x_validate, tf_optimizer, tf_epochs=5000, batch_size=512, shuffle_buffer_size=10 * 512):
        """Train on ``x_train`` with mini-batches, printing the validation loss after each epoch.

        Args:
            x_train: training inputs; one dataset element per leading-axis entry.
            x_validate: inputs whose reconstruction loss is printed every epoch.
            tf_optimizer: optimizer used via ``apply_gradients``.
            tf_epochs: number of passes over ``x_train``.
            batch_size: mini-batch size.
            shuffle_buffer_size: buffer size for ``Dataset.shuffle``.
        """
        train_dataset = (tf.data.Dataset.from_tensor_slices(x_train)
                         .shuffle(shuffle_buffer_size)
                         .batch(batch_size))
        # Convert once up front rather than re-converting the raw array every epoch.
        validate_data = tf.convert_to_tensor(x_validate, dtype='float32')
        for epoch in range(tf_epochs):
            for batch in train_dataset:
                _, grads = self.__grad(batch)
                tf_optimizer.apply_gradients(zip(grads, self.__wrap_training_variables()))
            print(f"epoch: {epoch}, loss_value: {self.__loss(validate_data)}")

In [50]:
net = Autoencoder(latent_dim=64)  # 64-unit latent code (same as the class default)

In [51]:
net.fit(x_train, x_validate, tf_optimizer, tf_epochs=10)  # short 10-epoch run; prints validation loss per epoch

epoch: 0, loss_value: 0.21620357036590576
epoch: 1, loss_value: 0.09263485670089722
epoch: 2, loss_value: 0.0695282444357872
epoch: 3, loss_value: 0.0682690218091011
epoch: 4, loss_value: 0.06695490330457687
epoch: 5, loss_value: 0.06575649231672287
epoch: 6, loss_value: 0.06442543864250183
epoch: 7, loss_value: 0.06294121593236923
epoch: 8, loss_value: 0.06127731502056122
epoch: 9, loss_value: 0.059446100145578384
