In [22]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, losses
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model

In [66]:
# Load the MNIST images (the labels are discarded) and rescale both splits
# from their raw values to float32 by dividing by 255.
(x_train, _), (x_test, _) = mnist.load_data()
x_train, x_test = (split.astype('float32') / 255. for split in (x_train, x_test))
# The first 100 test images are kept aside as the validation set.
x_validate = x_test[0:100]

In [47]:
# Adam optimizer instance; it is handed to the Autoencoder created further down.
tf_optimizer = tf.keras.optimizers.Adam(learning_rate=0.1, beta_1=0.99, epsilon=0.1)

In [62]:
class Autoencoder(object):
    """Dense autoencoder with one hidden layer and a hand-written training loop.

    The encoder flattens its input and maps it to ``latent_dim`` ReLU units.
    The decoder maps the latent vector to 784 sigmoid units and reshapes them
    to ``(28, 28)``. Training minimises the mean squared error between an
    input batch and its reconstruction.

    NOTE(review): neither sub-model is given an input shape, so their
    variables presumably only exist after the first forward pass (e.g. the
    first step of ``fit``) - confirm before calling ``get_weights`` or
    ``set_weights`` on a fresh instance.
    """

    def __init__(self, optimizer, latent_dim=64):
        """Create the encoder/decoder pair.

        Args:
            optimizer: object exposing ``apply_gradients``; used by ``fit``.
            latent_dim: number of units in the encoder's dense layer.
        """
        self.encoder = tf.keras.Sequential([
          layers.Flatten(),
          layers.Dense(latent_dim, activation='relu'),
        ])
        self.decoder = tf.keras.Sequential([
          layers.Dense(784, activation='sigmoid'),
          layers.Reshape((28, 28))
        ])
        self.optimizer = optimizer
        # dtype of the flat parameter tensor returned by get_weights().
        self.dtype = tf.float32

    # Defining custom loss
    def __loss(self, x_input):
        """Return the mean squared reconstruction error for ``x_input``."""
        latent_represent = self.encoder(x_input)
        recon = self.decoder(latent_represent)
        loss = losses.MeanSquaredError()(x_input, recon)
        return loss

    def __grad(self, x_input):
        """Return ``(loss, gradients)`` for one batch.

        The gradients are ordered like ``__wrap_training_variables()``.
        """
        with tf.GradientTape() as tape:
            loss_value = self.__loss(x_input)
        return loss_value, tape.gradient(loss_value, self.__wrap_training_variables())

    def __wrap_training_variables(self):
        """Return encoder trainable variables followed by decoder ones."""
        var = self.encoder.trainable_variables + self.decoder.trainable_variables
        return var

    def __iter_param_layers(self):
        """Yield every parameterised layer of each trainable sub-model.

        Encoder layers come first, then decoder layers; a sub-model whose
        ``trainable`` flag is false is skipped entirely. get_weights() and
        set_weights() both walk this sequence, which keeps their flat layouts
        in agreement.
        """
        for model in (self.encoder, self.decoder):
            if not model.trainable:
                continue
            for layer in model.layers:
                if layer.count_params() > 0:
                    yield layer

    def get_weights(self):
        """Return all parameters as one flat tensor of dtype ``self.dtype``.

        Layout: for each layer (encoder first, then decoder) the flattened
        first weight array followed by the second one (kernel, then bias for
        the Dense layers built in ``__init__``).
        """
        pieces = []
        for layer in self.__iter_param_layers():
            weights_biases = layer.get_weights()
            pieces.append(np.ravel(weights_biases[0]))
            pieces.append(np.ravel(weights_biases[1]))
        if not pieces:
            # Nothing trainable: return an empty vector rather than failing
            # inside np.concatenate.
            return tf.zeros([0], dtype=self.dtype)
        # Concatenating arrays avoids building a Python list with one entry
        # per scalar parameter.
        return tf.convert_to_tensor(np.concatenate(pieces), dtype=self.dtype)

    def set_weights(self, w):
        """Load parameters from a flat vector laid out as in get_weights().

        Args:
            w: 1-D array-like (NumPy array, list or eager tensor) holding at
                least as many values as the parameterised layers need.

        Raises:
            ValueError: if ``w`` runs out of values before every layer has
                been filled.
        """
        flat = np.asarray(w).reshape(-1)
        start = 0
        for layer in self.__iter_param_layers():
            weights_biases = layer.get_weights()
            kernel, bias = weights_biases[0], weights_biases[1]
            kernel_end = start + kernel.size
            bias_end = kernel_end + bias.size
            if bias_end > flat.size:
                raise ValueError(
                    f"weight vector has {flat.size} values but at least "
                    f"{bias_end} are required"
                )
            layer.set_weights([
                flat[start:kernel_end].reshape(kernel.shape),
                flat[kernel_end:bias_end],
            ])
            start = bias_end

    def summary(self):
        """Print the Keras summaries of the encoder and then the decoder."""
        self.encoder.summary()
        self.decoder.summary()

    # The training function
    def fit(self, x_train, x_validate, tf_epochs=5000, batch_size=512, shuffle_buffer_size=10 * 512):
        """Train on ``x_train`` and print the validation loss after each epoch.

        Args:
            x_train: training samples, sliced along the first axis.
            x_validate: samples on which the loss is reported.
            tf_epochs: number of passes over the training set.
            batch_size: number of samples per optimisation step.
            shuffle_buffer_size: buffer size passed to ``Dataset.shuffle``.
        """
        train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
        train_dataset = train_dataset.shuffle(shuffle_buffer_size).batch(batch_size)
        # Converted once up front so the per-epoch report reuses one tensor.
        validate_data = tf.convert_to_tensor(x_validate, dtype='float32')
        for epoch in range(tf_epochs):
            # Optimization step
            for data in train_dataset:
                _, grads = self.__grad(data)
                self.optimizer.apply_gradients(zip(grads, self.__wrap_training_variables()))
            validation_loss = self.__loss(validate_data)
            print(f"epoch: {epoch}, loss_value: {validation_loss}")

In [67]:
# Build the autoencoder with its default latent size, using the optimizer defined above.
net = Autoencoder(optimizer=tf_optimizer)

In [68]:
# Train with fit()'s default settings (tf_epochs=5000, batch_size=512); the
# recorded output below ends in a KeyboardInterrupt.
net.fit(x_train=x_train, x_validate=x_validate)

epoch: 0, loss_value: 0.21587608754634857
epoch: 1, loss_value: 0.09167254716157913
epoch: 2, loss_value: 0.06949993968009949
epoch: 3, loss_value: 0.06831905990839005
epoch: 4, loss_value: 0.0667756199836731
epoch: 5, loss_value: 0.06564392894506454
epoch: 6, loss_value: 0.06444252282381058
epoch: 7, loss_value: 0.06301336735486984
epoch: 8, loss_value: 0.061372675001621246
epoch: 9, loss_value: 0.05952657014131546
epoch: 10, loss_value: 0.05759640783071518
epoch: 11, loss_value: 0.055706657469272614
epoch: 12, loss_value: 0.05391809716820717
epoch: 13, loss_value: 0.05227965861558914
epoch: 14, loss_value: 0.05080657824873924
epoch: 15, loss_value: 0.04952077567577362
epoch: 16, loss_value: 0.04831046983599663
epoch: 17, loss_value: 0.04726816341280937
epoch: 18, loss_value: 0.04626231640577316
epoch: 19, loss_value: 0.045367587357759476
epoch: 20, loss_value: 0.04450390487909317
epoch: 21, loss_value: 0.043715331703424454
epoch: 22, loss_value: 0.04294751584529877


KeyboardInterrupt: 