In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Dense, InputLayer, Conv2D, Flatten, Reshape, Conv2DTranspose
from tensorflow.keras import Model, layers
import matplotlib.pyplot as plt
import numpy as np

In [None]:
##Parameters
# Training hyperparameters shared by the dataset and training cells.
epochs = 50        # number of passes over the training set
batch_size = 128   # images per batch, used for both train and test datasets

In [None]:
##Datasets
# MNIST digits, binarized to {0., 1.} so each pixel can be modeled as a Bernoulli
# variable (compute_loss uses sigmoid cross-entropy on the decoder's logits).
def _binarize(images, threshold=.5):
    '''Scale images to [0, 1], add a channel axis and threshold them.

    Args:
        images: array of shape (N, 28, 28) with pixel values in [0, 255], as
            returned by tf.keras.datasets.mnist.load_data().
        threshold: scaled pixels >= threshold become 1., the rest become 0.

    Returns:
        float32 array of shape (N, 28, 28, 1) containing only 0. and 1.
    '''
    scaled = images.reshape(images.shape[0], 28, 28, 1).astype('float32') / 255
    return (scaled >= threshold).astype('float32')


(train_images_raw, _), (test_images_raw, _) = tf.keras.datasets.mnist.load_data()
train_images = _binarize(train_images_raw)
test_images = _binarize(test_images_raw)

mnist_train = (tf.data.Dataset.from_tensor_slices(train_images)
               .shuffle(10000)
               .batch(batch_size))
# The test split is only evaluated, so it is not shuffled; a fixed order also
# keeps the images picked by the visualization cells the same on every run.
mnist_test = tf.data.Dataset.from_tensor_slices(test_images).batch(batch_size)

In [None]:
##Model
class Encoder(layers.Layer):
    '''Convolutional encoder that outputs the parameters of q(z|x).

    Two 3x3, stride-2, ReLU convolutions (default 'valid' padding) are followed
    by a flatten and two parallel 50-unit linear heads.

    call(x) returns (mean, logvar), each with 50 features per example.
    x is presumably a batch of (28, 28, 1) images, matching the dataset cell;
    nothing here enforces that.
    '''
    def __init__(self):
        super(Encoder, self).__init__()
        conv_kwargs = dict(kernel_size=3, strides=(2, 2), activation='relu')
        self.c1 = Conv2D(filters=32, **conv_kwargs)
        self.c2 = Conv2D(filters=64, **conv_kwargs)
        self.f = Flatten()
        self.d1 = Dense(units=50)  # mean head
        self.d2 = Dense(units=50)  # log-variance head

    def call(self, x):
        features = self.f(self.c2(self.c1(x)))
        return self.d1(features), self.d2(features)
    
class ReparameterizationTrick(layers.Layer):
    '''Draws z = mean + exp(logvar / 2) * eps with eps ~ N(0, I).

    Sampling through this layer keeps z differentiable with respect to
    ``mean`` and ``logvar`` (the reparameterization trick).
    '''
    def __init__(self):
        super(ReparameterizationTrick, self).__init__()

    def call(self, mean, logvar):
        # Use the dynamic shape: the static ``mean.shape`` may have an unknown
        # (None) batch dimension when the layer is traced symbolically (for
        # example by tf.saved_model.save), and tf.random.normal would then fail.
        eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
        z = eps * tf.exp(logvar * .5) + mean
        return z

class Decoder(layers.Layer):
    '''Maps a latent vector back to per-pixel logits.

    Dense(7*7*32, relu) -> reshape to a 7x7x32 feature map -> two stride-2
    transposed convolutions with 'SAME' padding (7 -> 14 -> 28) -> a 1-filter
    stride-1 transposed convolution.

    The last layer has no activation, so call(z) returns logits of shape
    (batch, 28, 28, 1); apply tf.sigmoid to get pixel probabilities.
    '''
    def __init__(self):
        super(Decoder, self).__init__()

        def upsample(n_filters):
            # 3x3 stride-2 transposed conv that doubles height and width.
            return Conv2DTranspose(filters=n_filters, kernel_size=3, strides=(2, 2),
                                   padding="SAME", activation='relu')

        self.d3 = Dense(units=7 * 7 * 32, activation='relu')
        self.r = Reshape(target_shape=(7, 7, 32))
        self.c3 = upsample(64)
        self.c4 = upsample(32)
        self.c5 = Conv2DTranspose(filters=1, kernel_size=3, strides=(1, 1), padding="SAME")

    def call(self, z):
        out = z
        for stage in (self.d3, self.r, self.c3, self.c4, self.c5):
            out = stage(out)
        return out

class Autoencorder(Model):
    '''Variational autoencoder: encoder -> reparameterized sample -> decoder.

    call(x) encodes x into (mean, logvar), draws one sample z and returns the
    decoder output for it. Those are per-pixel logits; no sigmoid is applied.
    '''
    def __init__(self):
        super(Autoencorder, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.reparameterizationtrick = ReparameterizationTrick()

    def call(self, x):
        mean_and_logvar = self.encoder(x)
        sample = self.reparameterizationtrick(*mean_and_logvar)
        return self.decoder(sample)


model = Autoencorder()

In [None]:
##Optimizer and metrics
# Follows https://www.tensorflow.org/tutorials/generative/cvae#define_the_loss_function_and_the_optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Running means of the per-batch losses, updated by train_step / test_step.
train_loss, test_loss = (tf.keras.metrics.Mean(name='train_loss'),
                         tf.keras.metrics.Mean(name='test_loss'))

def log_normal_pdf(sample, mean, logvar, raxis=1):
    '''Log-density of `sample` under a diagonal Gaussian N(mean, exp(logvar)).

    Returns the sum over axis `raxis` of
        -0.5 * ((sample - mean)**2 * exp(-logvar) + logvar + log(2*pi)).
    `mean` and `logvar` broadcast against `sample`; passing 0. for both gives
    the standard normal.
    '''
    log_two_pi = tf.math.log(2. * np.pi)
    scaled_sq_dist = (sample - mean) ** 2. * tf.exp(-logvar)
    per_dim = -.5 * (scaled_sq_dist + logvar + log_two_pi)
    return tf.reduce_sum(per_dim, axis=raxis)

def compute_loss(model, x):
    '''Negative ELBO averaged over the batch, from a single Monte Carlo sample.

    loss = -mean_over_batch[ log p(x|z) + log p(z) - log q(z|x) ],  z ~ q(z|x)

    Here p(x|z) is a per-pixel Bernoulli parameterized by the decoder logits,
    p(z) = N(0, I), and q(z|x) = N(mean, exp(logvar)) from the encoder.
    '''
    mean, logvar = model.encoder(x)
    z = model.reparameterizationtrick(mean, logvar)
    x_logit = model.decoder(z)

    # log p(x|z): Bernoulli log-likelihood summed over height, width and channel.
    pixel_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
    log_likelihood = -tf.reduce_sum(pixel_xent, axis=[1, 2, 3])

    log_prior = log_normal_pdf(z, 0., 0.)
    log_posterior = log_normal_pdf(z, mean, logvar)
    elbo_per_example = log_likelihood + log_prior - log_posterior
    return -tf.reduce_mean(elbo_per_example)

In [None]:
##Define train & test
@tf.function
def train_step(image):
    '''Run one optimizer step on a batch of images.

    Minimizes compute_loss (the negative ELBO) with `optimizer`, then adds the
    batch loss to the `train_loss` running mean.
    '''
    with tf.GradientTape() as tape:
        loss = compute_loss(model, image)  # negative ELBO for this batch
    # Only the forward pass needs to be recorded, so differentiate after
    # leaving the tape context.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss)
        
@tf.function
def test_step(image):
    '''Add the reconstruction mean-squared error of a test batch to test_loss.

    The model returns per-pixel logits: the decoder's last layer has no
    activation, and training uses sigmoid_cross_entropy_with_logits. They are
    therefore passed through a sigmoid before being compared with the {0, 1}
    pixel targets. Comparing raw logits, as before, measured nothing meaningful.
    '''
    reconstruction = tf.sigmoid(model(image))
    t_loss = tf.reduce_mean(tf.square(reconstruction - image))
    test_loss(t_loss)

In [None]:
##Do train & test
for epoch in range(epochs):
    # Keras metrics accumulate across calls. Reset them so each epoch reports
    # only its own batches instead of a running mean over all previous epochs.
    # (reset_state() needs TF >= 2.5; older versions call it reset_states().)
    train_loss.reset_state()
    test_loss.reset_state()

    for image in mnist_train:
        train_step(image)

    for test_batch in mnist_test:
        test_step(test_batch)

    # train_loss tracks the negative ELBO, so negate it to report the ELBO.
    print('Epoch {}, ELBO: {:.4f}, Test loss: {:.4f}'.format(
        epoch + 1, -float(train_loss.result()), float(test_loss.result())))

In [None]:
##Test image
# Only the first test batch is needed; take(1) avoids materializing every batch.
# test_image stays a list of batches, so test_image[0] is that first batch.
test_image = list(mnist_test.take(1))
plt.imshow(np.array(test_image[0][3]).reshape(28, 28), cmap='gray')
plt.title('Test image (batch 0, index 3)')
plt.show()

In [None]:
##Decoded test image
# The model outputs per-pixel logits; apply a sigmoid so the image shows
# pixel probabilities in [0, 1] rather than auto-scaled logits.
decorded_images = tf.sigmoid(model(test_image[0]))
plt.imshow(np.array(decorded_images[3]).reshape(28, 28), cmap='gray', vmin=0., vmax=1.)
plt.title('Reconstruction (batch 0, index 3)')
plt.show()

In [None]:
##Save model
# See https://www.tensorflow.org/guide/saved_model
# Export to a dedicated directory instead of "./", which would write
# saved_model.pb, variables/ and assets/ directly among the notebook's own files.
export_dir = "./saved_model"
tf.saved_model.save(model, export_dir)