In [12]:
import keras


In [13]:
import keras.backend as K
K.clear_session()

Variational autoencoders
Variational autoencoders, discovered simultaneously by Kingma & Welling in December 2013 and by Rezende, Mohamed & Wierstra in January 2014, are a kind of generative model that is especially appropriate for the task of image editing via concept vectors. They are a modern take on autoencoders (networks that aim to "encode" an input into a low-dimensional latent space and then "decode" it back) that mixes ideas from deep learning with Bayesian inference.

A classical image autoencoder takes an image, maps it to a latent vector space via an "encoder" module, and then decodes it back, via a "decoder" module, to an output with the same dimensions as the original image. It is trained using the input images themselves as target data, meaning that the autoencoder learns to reconstruct the original inputs. By imposing various constraints on the "code" (the output of the encoder), one can get the autoencoder to learn more or less interesting latent representations of the data. Most commonly, one would constrain the code to be very low-dimensional and sparse (i.e. mostly zeros), in which case the encoder acts as a way to compress the input data into fewer bits of information.
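
To make the encode/decode/reconstruct loop concrete, here is a minimal sketch of a classical autoencoder. It deliberately simplifies the setup described above: it is a purely linear autoencoder written in plain NumPy (not Keras, not convolutional), trained by gradient descent on hypothetical toy data that lies near a 2D subspace of a 10D space. The only training target is the input itself, and the 2-number code is the "compressed" representation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data (assumption): 200 samples in 10-D that lie close to a 2-D subspace.
n, d, k = 200, 10, 2
basis = rng.normal(size=(k, d))
data = rng.normal(size=(n, k)) @ basis + 0.01 * rng.normal(size=(n, d))

# Linear "encoder" (d -> k) and "decoder" (k -> d) weights, small random init.
W_enc = 0.1 * rng.normal(size=(d, k))
W_dec = 0.1 * rng.normal(size=(k, d))


def reconstruction_error(X, W_enc, W_dec):
    """Mean over samples of the squared reconstruction error."""
    code = X @ W_enc        # encode: k numbers per sample
    X_hat = code @ W_dec    # decode: back to d numbers per sample
    return np.mean(np.sum((X_hat - X) ** 2, axis=1))


initial_error = reconstruction_error(data, W_enc, W_dec)

# Plain gradient descent, using the inputs themselves as the targets.
lr = 0.005
for _ in range(3000):
    code = data @ W_enc
    X_hat = code @ W_dec
    grad_out = 2 * (X_hat - data) / n           # dL/dX_hat
    grad_dec = code.T @ grad_out                # dL/dW_dec
    grad_enc = data.T @ (grad_out @ W_dec.T)    # dL/dW_enc
    W_dec -= lr * grad_dec
    W_enc -= lr * grad_enc

final_error = reconstruction_error(data, W_enc, W_dec)
print(initial_error, final_error)
```

Because the toy data really is (almost) two-dimensional, a 2-number code is enough to reconstruct it; on real images, such a code says nothing about how the latent space is organized, which is the weakness discussed next.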

(Figure: Autoencoder)

In practice, such classical autoencoders don't lead to particularly useful or well-structured latent spaces, and they're not particularly good at compression either. For these reasons, they have largely fallen out of fashion in recent years. Variational autoencoders, however, augment autoencoders with a probabilistic treatment of the latent code that forces them to learn continuous, highly structured latent spaces. They have turned out to be a very powerful tool for image generation.

Instead of compressing its input image into a fixed "code" in the latent space, a VAE turns the image into the parameters of a statistical distribution: a mean and a variance. Essentially, this means that we assume the input image has been generated by a statistical process, and that the randomness of this process should be taken into account during encoding and decoding. The VAE then uses the mean and variance parameters to randomly sample one element of the distribution, and decodes that element back to the original input. The stochasticity of this process improves robustness and forces the latent space to encode meaningful representations everywhere, i.e. every point sampled in the latent space will be decoded to a valid output.
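
Written out, and assuming (as the encoder in this notebook does with `z_mean` and `z_log_var`) that the network predicts the mean and the log-variance of a diagonal Gaussian over a d-dimensional latent space, the sampling step and the usual regularization term look like this. This is the standard reparameterization trick and the closed-form KL divergence to a standard normal prior; the notation is ours:

```latex
\begin{aligned}
\boldsymbol{\sigma} &= \exp\!\left(\tfrac{1}{2}\log \boldsymbol{\sigma}^2\right), \qquad
\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
\mathbf{z} &= \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon} \\
D_{\mathrm{KL}}\!\left(\mathcal{N}(\boldsymbol{\mu}, \operatorname{diag}\boldsymbol{\sigma}^2)\,\|\,\mathcal{N}(\mathbf{0}, \mathbf{I})\right)
&= -\tfrac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\end{aligned}
```

Because z is a deterministic function of the mean, the standard deviation, and an independent noise sample, gradients can flow through both distribution parameters during training.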

In [14]:
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as np

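img_shape=(28,28,1)   # MNIST: 28x28 grayscale images
batch_size=16
latent_dim=2          # dimensionality of the latent space: a 2D plane

input_img=keras.Input(shape=img_shape)

# Convolutional encoder; the strided layer halves the spatial size to 14x14
x=layers.Conv2D(32,3,
               padding='same',activation='relu')(input_img)
x=layers.Conv2D(64,3,
               padding='same',activation='relu',
               strides=(2,2))(x)
x=layers.Conv2D(64,3,
               padding='same',activation='relu')(x)
x=layers.Conv2D(64,3,
               padding='same',activation='relu')(x)

# Remember the (batch, 14, 14, 64) shape so the decoder can reshape back to it
shape_before_flattening=K.int_shape(x)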
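
As a sanity check on what `shape_before_flattening` should contain, the sketch below redoes the shape arithmetic by hand. It assumes the TensorFlow/Keras rule that a `Conv2D` with `padding='same'` produces ceil(input / stride) rows and columns; the helper `conv_same_output` is hypothetical and only for illustration.

```python
import math


def conv_same_output(size, stride):
    # With padding='same', a Conv2D output has ceil(input / stride) rows/columns.
    return math.ceil(size / stride)


# (filters, stride) for the four encoder Conv2D layers above.
encoder_convs = [(32, 1), (64, 2), (64, 1), (64, 1)]

h, w, c = 28, 28, 1
for filters, stride in encoder_convs:
    h, w, c = conv_same_output(h, stride), conv_same_output(w, stride), filters

print((None, h, w, c))  # (None, 14, 14, 64)
```

Only the strided layer changes the spatial size, so `Flatten` will produce 14 * 14 * 64 = 12544 features per image.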

In [15]:
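x=layers.Flatten()(x)
x=layers.Dense(32,activation='relu')(x)

# The input image is encoded into the two parameters of a Gaussian:
# its mean and the log of its variance
z_mean=layers.Dense(latent_dim)(x)
z_log_var=layers.Dense(latent_dim)(x)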
In [17]:
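def sampling(args):
    z_mean,z_log_var=args
    # epsilon is drawn from a standard normal distribution
    epsilon=K.random_normal(shape=(K.shape(z_mean)[0],latent_dim),
                            mean=0.,stddev=1.)
    # z_log_var is log(sigma^2), so the standard deviation is exp(0.5*z_log_var)
    return z_mean+K.exp(0.5*z_log_var)*epsilon

# Wrap the sampling function in a Lambda layer to get the latent point z
z=layers.Lambda(sampling)([z_mean,z_log_var])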
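
The sketch below reproduces the sampling step in NumPy so its statistics can be checked directly. It is a stand-in, not the Keras code: `sample_latent` is a hypothetical helper, and NumPy's `Generator.standard_normal` replaces `K.random_normal`. Because `z_log_var` holds a log-variance, the standard deviation is exp(0.5 * log_var), so samples drawn with variances 4 and 0.25 should have standard deviations of about 2 and 0.5.

```python
import numpy as np

rng = np.random.default_rng(42)


def sample_latent(z_mean, z_log_var, rng):
    # Reparameterization: z = mean + std * epsilon, with epsilon ~ N(0, 1).
    # z_log_var holds log(sigma^2), so std = exp(0.5 * z_log_var).
    epsilon = rng.standard_normal(size=z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * epsilon


n = 200_000
z_mean = np.tile([1.0, -2.0], (n, 1))
z_log_var = np.tile([np.log(4.0), np.log(0.25)], (n, 1))  # variances 4 and 0.25

z = sample_latent(z_mean, z_log_var, rng)
print(z.mean(axis=0))  # approximately [1, -2]
print(z.std(axis=0))   # approximately [2, 0.5]
```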

In [20]:
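# The decoder is a separate model mapping a latent point back to an image
decoder_input=layers.Input(K.int_shape(z)[1:])
# Upsample the latent point to as many units as the encoder's last feature map...
x=layers.Dense(np.prod(shape_before_flattening[1:]),
              activation='relu')(decoder_input)
# ...and reshape it into a feature map of the same shape
x=layers.Reshape(shape_before_flattening[1:])(x)

# A strided transposed convolution brings the feature map back to 28x28
x=layers.Conv2DTranspose(32,3,padding='same',activation='relu',strides=(2,2))(x)

# Single-channel output in [0, 1], matching the normalized pixel values
x=layers.Conv2D(1,3,padding='same',activation='sigmoid')(x)

decoder=Model(decoder_input,x)
# Apply the decoder to the sampled z to get the reconstruction
z_decoded=decoder(z)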
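
The decoder mirrors the encoder's shapes. The sketch below checks that arithmetic by hand, assuming the Keras rule that `Conv2DTranspose` with `padding='same'` (and no `output_padding`) multiplies each spatial dimension by the stride; `conv_transpose_same_output` is a hypothetical helper for illustration.

```python
import numpy as np

shape_before_flattening = (None, 14, 14, 64)  # as computed for the encoder above

# Units of the first Dense layer: enough values to fill a 14x14x64 feature map.
dense_units = int(np.prod(shape_before_flattening[1:]))


def conv_transpose_same_output(size, stride):
    # With padding='same', Conv2DTranspose multiplies the spatial size by the stride.
    return size * stride


h, w = shape_before_flattening[1:3]
h, w = conv_transpose_same_output(h, 2), conv_transpose_same_output(w, 2)
# The final Conv2D(1, 3, padding='same') keeps 28x28 and outputs one channel.
decoder_output_shape = (h, w, 1)

print(dense_units)           # 12544
print(decoder_output_shape)  # (28, 28, 1)
```

The decoder output therefore has the same shape as `img_shape`, which is what lets the loss compare it pixel by pixel with the input.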

In [21]:
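class CustomVariationalLayer(keras.layers.Layer):
    # Computes the VAE loss and registers it via add_loss

    def vae_loss(self,x,z_decoded):
        x=K.flatten(x)
        z_decoded=K.flatten(z_decoded)
        # Reconstruction loss: pixel-wise binary cross-entropy
        xent_loss=keras.metrics.binary_crossentropy(x,z_decoded)
        # Regularization loss: a down-weighted KL divergence between
        # N(z_mean, exp(z_log_var)) and N(0, 1); z_mean and z_log_var
        # are the encoder tensors defined earlier
        kl_loss=-5e-4 * K.mean(
            1+z_log_var-K.square(z_mean)-K.exp(z_log_var),axis=-1)
        return K.mean(xent_loss+kl_loss)

    def call(self,inputs):
        x=inputs[0]
        z_decoded=inputs[1]
        loss=self.vae_loss(x,z_decoded)
        self.add_loss(loss,inputs=inputs)
        # The returned output is not used; only the added loss matters
        return x


y=CustomVariationalLayer()([input_img,z_decoded])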
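
To check the arithmetic of `vae_loss` outside Keras, here is a NumPy re-implementation sketch. `vae_loss_np` is a hypothetical helper, not part of Keras; it assumes the same conventions as the layer above (pixel values in [0, 1], `z_log_var` holding log-variances, KL weight 5e-4) and clips predictions to avoid log(0). When the encoder outputs exactly N(0, 1) (zero mean, zero log-variance), the KL term vanishes and only the reconstruction loss remains.

```python
import numpy as np


def vae_loss_np(x, x_decoded, z_mean, z_log_var, kl_weight=5e-4, eps=1e-7):
    x = x.reshape(-1)
    x_decoded = np.clip(x_decoded.reshape(-1), eps, 1 - eps)
    # Binary cross-entropy averaged over every pixel in the batch.
    xent = -np.mean(x * np.log(x_decoded) + (1 - x) * np.log(1 - x_decoded))
    # Per-sample KL term, averaged over latent dimensions and scaled by kl_weight.
    kl = -kl_weight * np.mean(1 + z_log_var - z_mean**2 - np.exp(z_log_var), axis=-1)
    return np.mean(xent + kl)


x = np.array([[0.0, 1.0]])
x_decoded = np.array([[0.5, 0.5]])
print(vae_loss_np(x, x_decoded, np.zeros((1, 2)), np.zeros((1, 2))))  # log(2), about 0.6931
```

Moving `z_mean` away from zero adds a small positive KL penalty on top of the reconstruction loss; this pull toward N(0, 1) is what keeps the latent space compact and continuous.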





In [22]:
from keras.datasets import mnist

vae=Model(input_img,y)


In [None]:
vae.compile(optimizer='rmsprop',loss=None)
vae.summary()

(x_train,_),(x_test,y_test)=mnist.load_data()
x_train=x_train.astype('float32')/255.
x_train=x_train.reshape(x_train.shape+(1,))
x_test=x_test.astype('float32')/255.
x_test=