In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

In [5]:
# Load CelebA (pinned to dataset version 2.0.1) as separate train and test splits.
# `ds_info` is kept because the next cell reads the split sizes from it.
(ds_train, ds_test), ds_info = tfds.load(
    'celeb_a:2.0.1',
    split=['train', 'test'],
    shuffle_files=True,  # read the underlying shard files in random order
    with_info=True,
)

In [6]:
batch_size = 128
# Must equal 7 * 2**4: the decoder reshapes to a 7x7 map and upsamples it 2x four times.
image_size = 112

def preprocess(sample):
    """Turn a CelebA example into an ``(input, target)`` pair for autoencoder training.

    Resizes ``sample['image']`` to ``image_size`` x ``image_size`` and divides by 255
    (assuming uint8 pixels, this maps values into [0, 1]). The same tensor is returned
    twice because the VAE is trained to reconstruct its own input.
    """
    image = sample['image']
    image = tf.image.resize(image, [image_size, image_size])
    image = tf.cast(image, tf.float32) / 255.
    return image, image

# NOTE: this cell rebinds ds_train/ds_test, so it is not idempotent: re-running it
# without first re-running the loading cell maps `preprocess` over already
# preprocessed (image, image) pairs and fails.
AUTOTUNE = tf.data.AUTOTUNE

ds_train = (ds_train
            .map(preprocess, num_parallel_calls=AUTOTUNE)  # order is kept (deterministic by default)
            .shuffle(batch_size * 4)
            .batch(batch_size)
            # Let tf.data size the prefetch buffer; prefetch(batch_size) would hold up
            # to 128 whole float32 batches (several GB) in memory.
            .prefetch(AUTOTUNE))

ds_test = (ds_test
           .map(preprocess, num_parallel_calls=AUTOTUNE)
           .batch(batch_size)
           .prefetch(AUTOTUNE))

train_num = ds_info.splits['train'].num_examples
test_num = ds_info.splits['test'].num_examples

In [7]:
class GaussianSampling(keras.layers.Layer):
    """Reparameterisation-trick sampler: draws z ~ N(mean, exp(logvar)).

    ``inputs`` is a ``(mean, logvar)`` pair. The layer returns
    ``mean + exp(0.5 * logvar) * eps``, where ``eps`` is standard-normal noise
    with the shape of ``mean``. Because the noise does not depend on the inputs,
    gradients flow through to ``mean`` and ``logvar``.
    """

    def call(self, inputs):
        mu, log_var = inputs
        # Draw the noise first, then build the standard deviation.
        noise = tf.random.normal(shape=tf.shape(mu), mean=0., stddev=1.)
        std = tf.exp(0.5 * log_var)
        return mu + std * noise

In [9]:
class DownConvBlock(layers.Layer):
    """Conv2D -> BatchNormalization -> LeakyReLU(0.2) block.

    The encoder uses it with ``strides=2`` to downsample by a factor of two.
    Each instance is named ``DownConvBlock_<k>``, where ``k`` comes from a
    class-wide counter that increases with every construction in the session.
    """
    count = 0  # class-wide counter used to build unique layer names

    def __init__(self, filters, kernel_size=(3,3), strides=1, padding='same'):
        block_name = f'DownConvBlock_{DownConvBlock.count}'
        super().__init__(name=block_name)
        DownConvBlock.count += 1
        self.forward = keras.Sequential([
            layers.Conv2D(filters=filters, kernel_size=kernel_size,
                          strides=strides, padding=padding),
            layers.BatchNormalization(),
            layers.LeakyReLU(0.2),
        ])

    def call(self, inputs):
        return self.forward(inputs)

In [10]:
class UpConvBlock(layers.Layer):
    """Conv2D -> LeakyReLU(0.2) -> UpSampling2D((2, 2)) block.

    With ``padding='same'`` and stride 1, the convolution keeps the spatial size.
    The final upsampling then doubles height and width. Unlike ``DownConvBlock``,
    there is no batch normalisation. Instances are named ``UpConvBlock_<k>``
    using a class-wide counter.
    """
    count = 0  # class-wide counter used to build unique layer names

    def __init__(self, filters, kernel_size=(3,3), strides=1, padding='same'):
        block_name = f'UpConvBlock_{UpConvBlock.count}'
        super().__init__(name=block_name)
        UpConvBlock.count += 1
        self.forward = keras.Sequential([
            layers.Conv2D(filters=filters, kernel_size=kernel_size,
                          strides=strides, padding=padding),
            layers.LeakyReLU(0.2),
            layers.UpSampling2D((2,2)),
        ])

    def call(self, inputs):
        return self.forward(inputs)

In [11]:
class Encoder(layers.Layer):
    """Convolutional VAE encoder: image -> ``(z, mean, logvar)``.

    Four stride-2 ``DownConvBlock``s shrink the feature map. For the 112x112
    images used in this notebook, that leaves 7x7x64 after the last block. The
    result is flattened and projected by two Dense layers to the ``z_dim``-sized
    mean and log-variance. ``z`` is then sampled from them with ``GaussianSampling``.

    Args:
        z_dim: size of the latent vector.
        name: Keras layer name.
    """
    def __init__(self, z_dim, name='encoder'):
        # Pass the name as a keyword. `name` is a str, so the previous `name.name`
        # raised AttributeError. Also, Layer.__init__'s first positional parameter
        # is `trainable`, not `name`.
        super(Encoder, self).__init__(name=name)
        self.features_extract = keras.Sequential([
            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),
            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),
            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),
            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),
            layers.Flatten()])
        self.dense_mean = layers.Dense(z_dim, name='mean')
        self.dense_logvar = layers.Dense(z_dim, name='logvar')
        self.sampler = GaussianSampling()

    def call(self, inputs):
        """Encode ``inputs``; return the sampled latent ``z`` plus its ``mean`` and ``logvar``."""
        features = self.features_extract(inputs)
        mean = self.dense_mean(features)
        logvar = self.dense_logvar(features)
        z = self.sampler([mean, logvar])
        return z, mean, logvar

In [12]:
class Decoder(layers.Layer):
    """Convolutional VAE decoder: latent vector -> 3-channel image with sigmoid outputs.

    A ReLU Dense layer projects ``z`` to 7*7*64 values, which are reshaped to a
    7x7x64 map. Four ``UpConvBlock``s (filters 64, 64, 32, 32) each double the
    spatial size, giving 7 -> 112. A final 3-filter sigmoid Conv2D produces the image.

    Args:
        z_dim: latent size. It is not used here (the Dense layer infers its input
            size when built); it is kept for symmetry with ``Encoder``.
        name: Keras layer name.
    """
    def __init__(self, z_dim, name='decoder'):
        super().__init__(name=name)
        stack = [
            layers.Dense(7*7*64, activation='relu'),
            layers.Reshape((7,7,64)),
        ]
        stack += [UpConvBlock(filters=n_filters, kernel_size=(3,3))
                  for n_filters in (64, 64, 32, 32)]
        stack.append(layers.Conv2D(filters=3, kernel_size=(3,3), strides=1,
                                   padding='same', activation='sigmoid'))
        self.forward = keras.Sequential(stack)

    def call(self, inputs):
        return self.forward(inputs)

In [13]:
class VAE(keras.Model):
    """Variational autoencoder that chains ``Encoder`` and ``Decoder``.

    ``call`` returns only the reconstruction. As a side effect, it stores the
    encoder's ``mean`` and ``logvar`` from the most recent forward pass on the
    instance. The module-level ``vae_kl_loss`` reads them through a global
    named ``vae``.
    """
    def __init__(self, z_dim, name='VAE'):
        super().__init__(name=name)
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)
        # Populated by `call`; None until the model has run at least once.
        self.mean = None
        self.logvar = None

    def call(self, inputs):
        z, mean, logvar = self.encoder(inputs)
        # Stash the posterior parameters so the KL loss term can reach them.
        self.mean, self.logvar = mean, logvar
        return self.decoder(z)

In [14]:
def vae_kl_loss(y_true, y_pred):
    """KL divergence between the encoder posterior N(mean, exp(logvar)) and N(0, I).

    The function has the Keras loss signature, but ``y_true`` and ``y_pred`` are
    ignored. The latent statistics are read from the global ``vae`` model, which
    stores them during its most recent forward pass. The model must therefore be
    called on the current batch before this runs. The result is averaged over all
    elements (batch and latent dimensions) with ``reduce_mean``, not summed per sample.
    """
    mean, logvar = vae.mean, vae.logvar
    per_element = 1 + logvar - tf.square(mean) - tf.exp(logvar)
    return -0.5 * tf.reduce_mean(per_element)

In [15]:
def vae_rc_loss(y_true, y_pred):
    """Reconstruction loss: mean squared error between target and reconstructed images.

    ``tf.keras.losses.MSE`` averages over the last axis only, so the returned
    tensor keeps the leading axes. Keras applies the remaining reduction when
    this is used as a compiled loss.
    """
    return tf.keras.losses.MSE(y_true, y_pred)

In [16]:
def vae_loss(y_true, y_pred, kl_weight=0.01):
    """Total VAE loss: reconstruction MSE plus a down-weighted KL term.

    Args:
        y_true: target images.
        y_pred: reconstructed images.
        kl_weight: multiplier on the KL term. The default 0.01 is the value that
            was previously hard-coded. Keras calls losses with two arguments, so
            wrap this function (e.g. with ``functools.partial``) to compile with a
            different weight.

    Returns:
        ``kl_weight * KL + MSE``. The scalar KL term broadcasts against the
        per-pixel reconstruction tensor.
    """
    kl_loss = vae_kl_loss(y_true, y_pred)
    rc_loss = vae_rc_loss(y_true, y_pred)
    return kl_weight * kl_loss + rc_loss