In [28]:
import tensorflow as tf
from tensorflow.keras import layers, losses
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
import tensorflow_probability as tfp
# Switch every visible GPU to on-demand memory allocation ("memory growth")
# and report how many physical / logical GPU devices TensorFlow sees.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth currently has to be configured identically on all GPUs,
        # so the flag is applied to each device in turn.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Raised when the GPUs were already initialized: memory growth can
        # only be set beforehand. The message is printed and the cell continues.
        print(e)

1 Physical GPUs, 1 Logical GPUs


In [2]:
# Only the images are kept; the labels are discarded because the autoencoder
# is trained to reconstruct its own input.
(x_train, _), (x_test, _) = mnist.load_data()
# Convert both splits to float32 and divide by 255 in a single pass.
x_train, x_test = (images.astype('float32') / 255. for images in (x_train, x_test))
# The first 100 test images are used as the validation set.
x_validate = x_test[:100]

In [186]:
def relu_grad(x):
    """Return the elementwise derivative of ReLU at the pre-activation ``x``.

    Args:
        x: Tensor of pre-activation values.

    Returns:
        A tensor with the shape and dtype of ``x`` holding 1 where ``x > 0``
        and 0 everywhere else.

    The derivative at exactly 0 is taken to be 0, which is the convention
    TensorFlow's own ReLU backward pass uses (``features > 0``).  This matters
    for bias-free layers: when every unit of the preceding layer is inactive
    the pre-activation is exactly 0, and returning 1 there would make a
    hand-written backward pass disagree with automatic differentiation.
    """
    return tf.where(tf.greater(x, 0), tf.ones_like(x), tf.zeros_like(x))

In [376]:
class Autoencoder(object):
    """Bias-free, fully connected ReLU autoencoder with a hand-written training loop.

    The flattened input is narrowed through Dense layers of 512, 256, 128, 64
    and 32 units (the latent code) and widened again through 64, 128, 256 and
    512 units back to the number of input values.  Every Dense layer is named
    ``pre_x<k>`` and the ReLU that follows it ``x<k>`` so that pre-activations
    and activations can be collected by name.

    With ``ngd=True`` every kernel gradient ``G`` is replaced, before it is
    handed to the optimizer, by::

        inv(A + damping * I) @ G @ inv(S + damping * I) / L

    where ``A`` is the batch second-moment matrix of the layer's input
    activations, ``S`` is the batch second-moment matrix of Gaussian noise
    backpropagated from the output to the layer's pre-activation, and ``L`` is
    the number of Dense layers.  (``ngd`` presumably stands for natural
    gradient descent -- confirm with the author.)

    Args:
        input_shape: Shape of one sample, without the batch dimension.
        ngd: Whether to apply the preconditioning described above.
        sigma: Standard deviation of the Gaussian noise injected at the output
            when ``ngd`` is enabled.
        damping: Multiple of the identity added to ``A`` and ``S`` before they
            are inverted.
    """

    def __init__(self, input_shape = (28,28), ngd=False, sigma=1e-6, damping=1e-3):
        # Size the last Dense layer and the final Reshape from ``input_shape``
        # so the reconstruction always has the shape of the input; the default
        # (28, 28) gives the original 784 units / Reshape((28, 28)).
        input_shape = tuple(input_shape)
        n_inputs = 1
        for dim in input_shape:
            n_inputs *= dim

        self.model = tf.keras.Sequential([
          layers.InputLayer(input_shape),
          layers.Flatten(name="x0"),
          layers.Dense(512, use_bias=False, name="pre_x1"),
          layers.Activation("relu", name="x1"),
          layers.Dense(256, use_bias=False, name="pre_x2"),
          layers.Activation("relu", name="x2"),
          layers.Dense(128, use_bias=False, name="pre_x3"),
          layers.Activation("relu", name="x3"),
          layers.Dense(64, use_bias=False, name="pre_x4"),
          layers.Activation("relu", name="x4"),
          layers.Dense(32, use_bias=False, name="pre_x5"),
          layers.Activation("relu", name="x5"), # latent representation
          layers.Dense(64, use_bias=False, name="pre_x6"),
          layers.Activation("relu", name="x6"),
          layers.Dense(128, use_bias=False, name="pre_x7"),
          layers.Activation("relu", name="x7"),
          layers.Dense(256, use_bias=False, name="pre_x8"),
          layers.Activation("relu", name="x8"),
          layers.Dense(512, use_bias=False, name="pre_x9"),
          layers.Activation("relu", name="x9"),
          layers.Dense(n_inputs, use_bias=False, name="pre_x10"),
          layers.Activation("relu", name="x10"),
          layers.Reshape(input_shape)
        ])

        # Collect, in layer order, the activations (x0..x10), the
        # pre-activations (pre_x1..pre_x10) and the Dense layers themselves.
        # The layer list never changes, so it is built once here rather than
        # on every gradient step.
        outputs = []
        pre_outputs = []
        self._pre_layers = []
        for layer in self.model.layers:
            if layer.name.startswith("pre"):
                pre_outputs.append(layer.output)
                self._pre_layers.append(layer)
            elif layer.name.startswith("x"):
                outputs.append(layer.output)

        # Same weights as ``self.model``, but one call also returns every
        # intermediate activation and pre-activation.
        self.extend_model = tf.keras.Model(self.model.input, [self.model.output, outputs, pre_outputs])
        self.dtype = tf.float32
        self.ngd = ngd
        self.sigma = sigma
        self.damping = damping

    def __loss(self, x_input):
        """Return the mean squared reconstruction error for ``x_input``.

        Side effect: stores the activations in ``self.x`` and the
        pre-activations in ``self.pre_x`` for use by the gradient step.
        """
        outputs = self.extend_model(x_input)
        loss = losses.MeanSquaredError()(x_input, outputs[0])
        self.x = outputs[1]
        self.pre_x = outputs[2]
        return loss

    def __grad(self, x_input):
        """Return ``(loss, gradients)`` for one batch.

        The gradients are ordered like ``self.model.trainable_variables`` and
        are preconditioned when ``self.ngd`` is set.
        """
        with tf.GradientTape() as tape:
            loss_value = self.__loss(x_input)
        grad = tape.gradient(loss_value, self.__wrap_training_variables())

        if self.ngd:
            self.__precondition(grad, x_input.shape[0])
        return loss_value, grad

    def __precondition(self, grad, batch_size):
        """Rescale ``grad`` in place using activation and noise statistics.

        Relies on ``self.x`` / ``self.pre_x`` from the forward pass of the
        same batch, and on ``grad[i]`` being the kernel gradient of the i-th
        Dense layer (each layer is bias-free, so it owns exactly one variable).
        """
        L = len(self.x) - 1  # number of Dense layers
        dist = tfp.distributions.Normal(loc=0., scale=self.sigma)
        pre_x_reverse = self.pre_x[::-1]

        # Backpropagate random noise from the output down to the first
        # pre-activation; ``errors`` is ordered from the last layer to the first.
        errors = [relu_grad(pre_x_reverse[0]) * dist.sample(self.x[-1].shape) / batch_size * 2]
        for i, layer in enumerate(self._pre_layers[::-1][:-1]):
            errors.append(relu_grad(pre_x_reverse[i+1]) * tf.matmul(errors[-1], tf.transpose(layer.kernel)))
        # The leading None pads index 0 so that self.e[k] lines up with pre_x<k>.
        errors.append(None)
        self.e = errors[::-1]

        for i in range(len(self._pre_layers)):
            right = tf.tensordot(self.e[i+1], self.e[i+1], axes=[[0], [0]]) / batch_size
            left = tf.tensordot(self.x[i], self.x[i], axes=[[0], [0]]) / batch_size
            # Damping keeps both second-moment matrices invertible.
            left_inverse = tf.linalg.inv(left + self.damping * tf.eye(left.shape[0]))
            right_inverse = tf.linalg.inv(right + self.damping * tf.eye(right.shape[0]))
            grad[i] = 1 / L * tf.linalg.matmul(tf.linalg.matmul(left_inverse, grad[i]), right_inverse)

    def __wrap_training_variables(self):
        """Return the trainable variables of the underlying Keras model."""
        var = self.model.trainable_variables
        return var

    def summary(self):
        """Print the Keras summary of the underlying sequential model."""
        self.model.summary()

    # The training function
    def fit(self, x_train, x_validate, tf_optimizer, tf_epochs=5000, batch_size=1024, shuffle_buffer_size=10 * 512):
        """Train on ``x_train`` and print the validation loss after every epoch.

        Args:
            x_train: Training samples; sliced along the first axis into batches.
            x_validate: Samples used only for the per-epoch loss report.
            tf_optimizer: Optimizer whose ``apply_gradients`` receives the
                (possibly preconditioned) gradients.
            tf_epochs: Number of passes over the training data.
            batch_size: Samples per optimization step.
            shuffle_buffer_size: Buffer size handed to ``Dataset.shuffle``.
        """
        train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
        train_dataset = train_dataset.shuffle(shuffle_buffer_size).batch(batch_size)
        # Convert once so the validation pass does not re-convert every epoch.
        validate_data = tf.convert_to_tensor(x_validate, dtype='float32')
        for epoch in range(tf_epochs):
            # Optimization step
            for data in train_dataset:
                loss_value, grads = self.__grad(data)
                tf_optimizer.apply_gradients(zip(grads, self.__wrap_training_variables()))
            # This call overwrites self.x / self.pre_x, which is harmless
            # because every gradient step recomputes them first.
            print(f"epoch: {epoch}, loss_value: {self.__loss(validate_data)}")

In [388]:
# Autoencoder whose gradients go through the class's `ngd` branch before
# being handed to the optimizer.
net1 = Autoencoder(ngd=True)
net1.summary()

Model: "sequential_134"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 x0 (Flatten)                (None, 784)               0         
                                                                 
 pre_x1 (Dense)              (None, 512)               401408    
                                                                 
 x1 (Activation)             (None, 512)               0         
                                                                 
 pre_x2 (Dense)              (None, 256)               131072    
                                                                 
 x2 (Activation)             (None, 256)               0         
                                                                 
 pre_x3 (Dense)              (None, 128)               32768     
                                                                 
 x3 (Activation)             (None, 128)            

In [389]:
# Baseline autoencoder built with the same constructor but with the `ngd`
# branch disabled, so the optimizer receives the raw gradients.
net2 = Autoencoder(ngd=False)
net2.summary()

Model: "sequential_135"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 x0 (Flatten)                (None, 784)               0         
                                                                 
 pre_x1 (Dense)              (None, 512)               401408    
                                                                 
 x1 (Activation)             (None, 512)               0         
                                                                 
 pre_x2 (Dense)              (None, 256)               131072    
                                                                 
 x2 (Activation)             (None, 256)               0         
                                                                 
 pre_x3 (Dense)              (None, 128)               32768     
                                                                 
 x3 (Activation)             (None, 128)            

In [390]:
# Baseline run: train net2 (ngd disabled) with Adam for 10 epochs.
# NOTE(review): learning_rate is 1e-1 here but 1e-2 in the net1 run, so the
# two printed loss curves are not a like-for-like comparison -- confirm that
# the difference is intentional.
tf_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-1, beta_1=0.99, epsilon=1e-1)
net2.fit(x_train, x_validate, tf_optimizer, tf_epochs=10)

epoch: 0, loss_value: 0.10087733715772629
epoch: 1, loss_value: 0.09958314150571823
epoch: 2, loss_value: 0.0939842015504837
epoch: 3, loss_value: 0.08475125581026077
epoch: 4, loss_value: 0.07674574106931686
epoch: 5, loss_value: 0.07522127032279968
epoch: 6, loss_value: 0.07492201775312424
epoch: 7, loss_value: 0.07425510883331299
epoch: 8, loss_value: 0.07387492805719376
epoch: 9, loss_value: 0.0737016573548317


In [391]:
# ngd run: train net1 with a fresh Adam instance for 10 epochs.  Rebinding
# `tf_optimizer` discards the optimizer (and its moment estimates) that was
# used for net2.
tf_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2, beta_1=0.99, epsilon=1e-1)
net1.fit(x_train, x_validate, tf_optimizer, tf_epochs=10)

epoch: 0, loss_value: 0.07039278000593185
epoch: 1, loss_value: 0.05736871436238289
epoch: 2, loss_value: 0.04855325445532799
epoch: 3, loss_value: 0.040112946182489395
epoch: 4, loss_value: 0.03355675935745239
epoch: 5, loss_value: 0.028497595340013504
epoch: 6, loss_value: 0.02399457059800625
epoch: 7, loss_value: 0.02039281278848648
epoch: 8, loss_value: 0.017971782013773918
epoch: 9, loss_value: 0.01658797077834606
