In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

tfk = tf.keras
tfl = tfk.layers
tfd = tfp.distributions
tfs = tf.summary

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm_notebook as tqdm

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Capacity of the buffer that the input pipeline shuffles examples within.
buffer_size = 500
# Number of images grouped into each training batch.
batch_size = 64

def normalise_image(img):
    """Cast an image tensor to float32 and divide every value by 255.

    Args:
        img: Image tensor (assumed to hold 8-bit pixel intensities, so that
            the result lies in [0, 1] -- TODO confirm against the dataset).

    Returns:
        A `tf.float32` tensor with the same shape as `img`.
    """
    as_float = tf.cast(img, tf.float32)
    return as_float / 255.

# Input pipeline: load every CIFAR-10 split, shuffle within a bounded buffer,
# repeat indefinitely, keep only the image of each example, rescale it and
# finally group the images into batches.
# NOTE(review): `tfds.Split.ALL` may not exist in every TFDS release -- confirm
# against the installed version.
train_dataset = (
    tfds.load(name="cifar10", split=tfds.Split.ALL)
    .shuffle(buffer_size=buffer_size)
    .repeat()
    .map(lambda example: example['image'])
    .map(normalise_image)
    .batch(batch_size)
)

In [3]:
class Encoder(tfl.Layer):
    """Convolutional encoder producing a diagonal-Gaussian posterior.

    The input passes through four 5x5 convolutions (the last three with
    stride 2), is flattened, and goes through a 512-unit dense layer. Two
    dense heads then produce the location and scale of a `tfd.Normal`
    posterior over `latent_size` latent dimensions.

    Attributes:
        latent_size: Number of latent dimensions.
        posterior: The `tfd.Normal` built by the most recent `call`; it does
            not exist until the layer has been called at least once.
    """

    def __init__(self, latent_size, name='encoder', **kwargs):
        """Create the encoder.

        Args:
            latent_size: Number of latent dimensions of the posterior.
            name: Layer name forwarded to `tfl.Layer`.
            **kwargs: Extra keyword arguments forwarded to `tfl.Layer`.
        """
        # Initialise the Keras base class first so that attribute tracking is
        # set up before any attribute is assigned (same order as Decoder).
        super(Encoder, self).__init__(name=name, **kwargs)

        self.latent_size = latent_size

    def build(self, input_shape):
        """Instantiate the sub-layers; `input_shape` is not inspected."""

        self.layers = [

            tfl.Conv2D(filters=64,
                       kernel_size=(5, 5),
                       padding='same'),

            tf.nn.relu,

            tfl.Conv2D(filters=128,
                       kernel_size=(5, 5),
                       strides=(2, 2),
                       padding='same'),

            tf.nn.relu,

            tfl.Conv2D(filters=256,
                       kernel_size=(5, 5),
                       strides=(2, 2),
                       padding='same'),

            tf.nn.relu,

            tfl.Conv2D(filters=512,
                       kernel_size=(5, 5),
                       strides=(2, 2),
                       padding='same'),

            tf.nn.relu,

            # Hard-codes a 4x4x512 feature map, i.e. inputs whose spatial size
            # is reduced to 4x4 by the three stride-2 convolutions above
            # (e.g. 32x32 images).
            tfl.Reshape((4 * 4 * 512,)),

            tfl.Dense(512),

            tf.nn.relu
        ]

        self.loc_head = tfl.Dense(self.latent_size)

        # Despite the name, the output of this head is passed through a
        # softplus (not an exponential) in `call` to obtain the scale.
        self.log_scale_head = tfl.Dense(self.latent_size)

    def call(self, tensor):
        """Encode `tensor` and return one sample from the posterior.

        Side effect: stores the posterior distribution in `self.posterior`
        so that callers can evaluate it after the forward pass.

        Args:
            tensor: Batch of input images.

        Returns:
            A sample drawn from the posterior via `posterior.sample()`.
        """

        for layer in self.layers:
            tensor = layer(tensor)

        loc = self.loc_head(tensor)
        # Softplus keeps the scale positive.
        scale = tf.nn.softplus(self.log_scale_head(tensor))

        self.posterior = tfd.Normal(loc=loc, scale=scale)

        return self.posterior.sample()
    
    
class Decoder(tfl.Layer):
    """Transposed-convolution decoder mapping latent codes to images.

    Two dense layers expand the latent code to a 4x4x512 feature map, three
    stride-2 transposed convolutions upsample it, and a final 3-filter
    transposed convolution followed by a sigmoid produces the output.
    """

    def __init__(self, name='decoder', **kwargs):
        """Forward `name` and any extra keyword arguments to `tfl.Layer`."""
        super(Decoder, self).__init__(name=name, **kwargs)

    def build(self, input_shape):
        """Instantiate the sub-layers; `input_shape` is not inspected."""

        # Dense stem that ends in a 4x4x512 feature map.
        stack = [
            tfl.Dense(512),
            tf.nn.relu,
            tfl.Dense(4 * 4 * 512),
            tf.nn.relu,
            tfl.Reshape((4, 4, 512)),
        ]

        # Three up-sampling stages, each followed by a ReLU.
        for n_filters in (256, 128, 64):
            stack.append(tfl.Conv2DTranspose(filters=n_filters,
                                             kernel_size=(5, 5),
                                             strides=(2, 2),
                                             padding='same'))
            stack.append(tf.nn.relu)

        # Output stage: 3 channels squashed into (0, 1) by the sigmoid.
        stack.append(tfl.Conv2DTranspose(filters=3,
                                         kernel_size=(5, 5),
                                         padding='same'))
        stack.append(tf.nn.sigmoid)

        self.layers = stack

    def call(self, tensor):
        """Apply every stage of the decoder in order and return the result."""
        output = tensor
        for stage in self.layers:
            output = stage(output)
        return output
    
    
class VAE(tfk.Model):
    """Variational auto-encoder with a Gaussian likelihood.

    The model encodes a batch of images to a sampled latent code, decodes the
    code to the mean of a `tfd.Normal` likelihood, and uses a single learned
    log standard deviation (`log_noise`) for every output value.

    The `log_prob`, `kl_divergence` and `posterior` properties read state
    that is written during `call`, so they are only valid after the model
    has been called on a batch.

    Attributes:
        latent_size: Number of latent dimensions.
        log_noise: Scalar variable; `exp(log_noise)` is the likelihood scale.
        prior: Standard-normal prior over the latents (created in `build`).
        likelihood: The `tfd.Normal` built by the most recent `call`.
    """

    def __init__(self, latent_size, name='vae', **kwargs):
        """Create the model.

        Args:
            latent_size: Number of latent dimensions.
            name: Model name forwarded to `tfk.Model`.
            **kwargs: Extra keyword arguments forwarded to `tfk.Model`.
        """
        # The Keras base class must be initialised before any variable or
        # layer is assigned as an attribute. Assigning `log_noise` first can
        # leave it out of `trainable_variables` (so it is never trained) or
        # raise an error, depending on the TensorFlow version.
        super(VAE, self).__init__(name=name, **kwargs)

        self.latent_size = latent_size

        # Initial value 0.0 corresponds to a likelihood scale of exp(0) = 1.
        self.log_noise = tf.Variable(0.0)

    @property
    def log_prob(self):
        """Mean over the batch of the per-example log-likelihood."""

        return tf.reduce_mean(self._log_prob)

    @property
    def kl_divergence(self):
        """Scalar KL(posterior || prior): summed over latents, averaged over the batch."""

        kl_each_latent = tfd.kl_divergence(self.posterior, self.prior)

        kl_each_example = tf.reduce_sum(kl_each_latent, axis=-1)

        return tf.reduce_mean(kl_each_example)

    @property
    def posterior(self):
        """Posterior stored by the encoder during its most recent call."""

        return self.encoder.posterior

    def build(self, input_shape):
        """Create the encoder, the decoder and the prior; `input_shape` is not inspected."""

        self.encoder = Encoder(self.latent_size)
        self.decoder = Decoder()

        self.prior = tfd.Normal(loc=tf.zeros(self.latent_size),
                                scale=tf.ones(self.latent_size))

    def call(self, tensor):
        """Reconstruct `tensor` and record its log-likelihood.

        Side effects: stores the likelihood distribution in `self.likelihood`
        and the per-example log-likelihood in `self._log_prob`.

        Args:
            tensor: Batch of images; the log-probability is summed over
                axes 1, 2 and 3, so a rank-4 tensor is expected.

        Returns:
            The decoder output, i.e. the mean of the likelihood.
        """

        latents = self.encoder(tensor)

        loc = self.decoder(latents)

        scale = tf.exp(self.log_noise)

        self.likelihood = tfd.Normal(loc=loc, scale=scale)
        # Sum over all non-batch axes to get one log-likelihood per example.
        self._log_prob = tf.reduce_sum(self.likelihood.log_prob(tensor), axis=(1, 2, 3))

        return loc

In [None]:
# Optimisation hyper-parameters.
learn_rate = 1e-3
train_steps = 10 ** 6   # number of batches consumed by the training loop
beta = 1.0              # weight applied to the KL term of the objective
log_freq = 10           # summaries are written every `log_freq` optimizer steps

# Model with a 64-dimensional latent space.
vae = VAE(64)

optimizer = tfk.optimizers.Adam(learn_rate)

# Training summaries are written under this directory.
train_summary_writer = tfs.create_file_writer('summaries/train/cifar10')

# Train by minimising the negative (beta-weighted) ELBO, writing summaries to
# `train_summary_writer` every `log_freq` optimizer steps.
with train_summary_writer.as_default():

    # `take` bounds the repeated dataset to exactly `train_steps` batches.
    for batch in tqdm(train_dataset.take(train_steps), total=train_steps):

        with tf.GradientTape() as tape:

            # The forward pass also refreshes the posterior and the
            # log-likelihood that the two properties below read.
            reconstructions = vae(batch)

            likelihood = vae.log_prob

            # `vae.kl_divergence` is already a scalar (mean over the batch),
            # so it is evaluated once and reused instead of being rebuilt
            # and re-reduced for the loss.
            kl_divergence = vae.kl_divergence

            neg_elbo = - likelihood + beta * kl_divergence

        trainable_variables = vae.trainable_variables

        gradients = tape.gradient(neg_elbo, trainable_variables)

        optimizer.apply_gradients(zip(gradients, trainable_variables))

        # `optimizer.iterations` has just been incremented by `apply_gradients`.
        step = optimizer.iterations

        if tf.equal(step % log_freq, 0):

            tfs.scalar('ELBO', - neg_elbo, step=step)
            tfs.scalar('Likelihood', likelihood, step=step)
            tfs.scalar('KL-divergence', kl_divergence, step=step)
            tfs.scalar('log-noise', vae.log_noise, step=step)
            tfs.image('Original', batch, step=step)
            tfs.image('Reconstruction', reconstructions, step=step)

HBox(children=(IntProgress(value=0, max=1000000), HTML(value='')))