# Teaching a Variational Autoencoder (VAE) to draw MNIST characters

Autoencoders are a type of neural network that can be used to learn efficient codings of input data.
Given some inputs, the network first applies a series of transformations that map the input data into a lower-dimensional space. This part of the network is called the _encoder_. The network then uses the encoded data to try to recreate the inputs. This part of the network is the _decoder_. Using the encoder, we can later compress data of the type the network has been trained on. However, autoencoders are rarely used for this purpose, as hand-crafted algorithms (like JPEG compression) are usually more efficient. Instead, autoencoders have repeatedly been applied to denoising tasks: the encoder receives pictures that have been corrupted with noise, and the network learns to reconstruct the original images.


## Variational Autoencoders put simply
But there exists a much more interesting application for autoencoders. This application is called the _variational autoencoder_. Using variational autoencoders, it's not only possible to compress data -- it's also possible to generate new objects of the type the autoencoder has seen before.

Using a general autoencoder, we don't know anything about the coding that's been generated by our network. We could take a look at and compare different encoded objects, but it's unlikely that we'll be able to understand what's going on. This means that we won't be able to use our decoder for creating new objects -- we simply don't know what the inputs should look like.

With a variational autoencoder, we take the opposite approach. We don't try to guess the distribution that the latent vectors follow. Instead, we tell our network what we want this distribution to look like. Usually, we constrain the network to produce latent vectors whose entries follow a standard normal distribution (zero mean, unit variance). Then, when generating data, we can simply sample some values from this distribution and feed them to the decoder, and the decoder will return new objects that resemble the objects our network has been trained with.

Let's see how this can be done using Python and TensorFlow. We are going to teach our network how to draw MNIST characters.

## First steps -- Loading the training data
First, we perform some basic imports. TensorFlow has a handy function, `tf.keras.datasets.mnist.load_data()`, that gives us easy access to the MNIST data set.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# from tensorflow.compat.v1.examples.tutorials.mnist import input_data
TRAIN_BUF = 60000
TEST_BUF = 10000

BATCH_SIZE = 100

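# x_* hold the images with shape (N, 28, 28); y_* hold the digit labels.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
# The test pipeline is built from the test images, not from the training labels.
test_dataset = tf.data.Dataset.from_tensor_slices(x_test).shuffle(TEST_BUF).batch(BATCH_SIZE)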
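
The images loaded this way are `uint8` arrays with pixel values from 0 to 255, while a sigmoid output layer such as the one used in the decoder further down produces values between 0 and 1. A minimal sketch of the rescaling this suggests is shown here. It uses random stand-in data instead of the real data set, and the names `fake_images` and `to_unit_range` are made up for this illustration.

```python
import numpy as np

# Stand-in for x_train: random uint8 "images" with the same dtype and
# per-image shape as the MNIST arrays (an assumption made for illustration).
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)


def to_unit_range(images):
    """Convert uint8 pixel values in [0, 255] to float32 values in [0, 1]."""
    return images.astype(np.float32) / 255.0


scaled = to_unit_range(fake_images)
print(scaled.dtype, scaled.min(), scaled.max())
```

With the real data, the same function could be applied to `x_train` once after loading, or to each batch just before it is fed to the network.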

In [2]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.3
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

## Defining our input and output data
MNIST images are 28 * 28 pixels with one color channel. Our inputs _X_in_ will be batches of MNIST characters. The network learns to reconstruct them, and the images it should reproduce are fed through a second placeholder _Y_, which thus has the same dimensions. _Y_flat_ will be used later, when computing losses. _keep_prob_ will be used when applying dropout as a means of regularization. During training, it will have a value of 0.8. When generating new data, we won't apply dropout, so the value will be 1. The function _lrelu_ is a small hand-written leaky ReLU; TensorFlow also provides `tf.nn.leaky_relu`, which could be used instead.

In [3]:
tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

batch_size = 64

X_in = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='X')
Y    = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 28, 28], name='Y')
Y_flat = tf.reshape(Y, shape=[-1, 28 * 28])
keep_prob = tf.compat.v1.placeholder(dtype=tf.float32, shape=(), name='keep_prob')

dec_in_channels = 1
n_latent = 8

reshaped_dim = [-1, 7, 7, dec_in_channels]
# Integer division keeps the layer sizes whole: 49 // 2 = 24, and 24 * 2 + 1 = 49 = 7 * 7.
inputs_decoder = 49 * dec_in_channels // 2


def lrelu(x, alpha=0.3):
    return tf.maximum(x, tf.multiply(x, alpha))
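
To see what _lrelu_ does to concrete numbers, here is a small NumPy sketch of the same formula. The name `lrelu_np` is made up for this illustration; the default `alpha=0.3` is taken from the cell above.

```python
import numpy as np


def lrelu_np(x, alpha=0.3):
    """NumPy counterpart of lrelu: the element-wise maximum of x and alpha * x.

    For 0 < alpha < 1 this leaves positive inputs unchanged and scales
    negative inputs by alpha.
    """
    return np.maximum(x, alpha * x)


x = np.array([-2.0, -0.5, 0.0, 1.0, 3.0])
print(lrelu_np(x))
```

Negative inputs are shrunk rather than set to zero, so some gradient still flows for them during training.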

## Defining the encoder
As our inputs are images, it's most reasonable to apply some convolutional transformations to them. What's most noteworthy is that our encoder creates two vectors, as it is supposed to describe a Gaussian distribution over the latent values:
* A vector of means
* A vector of log standard deviations (named _sd_ in the code, which exponentiates it to obtain the actual standard deviations)

You will see later how we "force" the encoder to make sure these values really follow a normal distribution. The values that will be fed to the decoder are the _z_-values, obtained by scaling standard normal noise with the standard deviations and shifting it by the means. We will need _mn_ and _sd_ later, when computing losses.

In [4]:
def encoder(X_in, keep_prob):
    activation = lrelu
    with tf.compat.v1.variable_scope("encoder", reuse=None):
        # Add the single color channel expected by conv2d.
        X = tf.reshape(X_in, shape=[-1, 28, 28, 1])
        x = tf.compat.v1.layers.conv2d(X, filters=64, kernel_size=4, strides=2, padding='same', activation=activation)
        # Pass the drop rate by keyword: the second positional argument of
        # tf.nn.dropout means keep_prob in TF 1.x but rate in TF 2.x.
        x = tf.nn.dropout(x, rate=1.0 - keep_prob)
        x = tf.compat.v1.layers.conv2d(x, filters=64, kernel_size=4, strides=2, padding='same', activation=activation)
        x = tf.nn.dropout(x, rate=1.0 - keep_prob)
        x = tf.compat.v1.layers.conv2d(x, filters=64, kernel_size=4, strides=1, padding='same', activation=activation)
        x = tf.nn.dropout(x, rate=1.0 - keep_prob)
        x = tf.compat.v1.layers.flatten(x)
        # Mean and log standard deviation of the latent Gaussian.
        mn = tf.compat.v1.layers.dense(x, units=n_latent)
        sd = 0.5 * tf.compat.v1.layers.dense(x, units=n_latent)
        # Reparameterization: z = mean + noise * std, with noise ~ N(0, 1).
        epsilon = tf.compat.v1.random_normal(tf.stack([tf.shape(x)[0], n_latent]))
        z = mn + tf.multiply(epsilon, tf.exp(sd))

        return z, mn, sd
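
The last lines of the encoder can be tried out in isolation. The following NumPy sketch repeats the same computation with fixed, made-up values for _mn_ and _sd_ (a mean of 1.5 and a standard deviation of 0.5) and a hypothetical sample count of 10000, then checks that the resulting _z_-values have roughly that mean and spread.

```python
import numpy as np

rng = np.random.default_rng(42)
n_latent = 8
n_samples = 10000  # hypothetical sample count, chosen only for this check

# Made-up encoder outputs: every latent entry has mean 1.5 and std 0.5.
mn = np.full((n_samples, n_latent), 1.5)
sd = np.full((n_samples, n_latent), np.log(0.5))  # log standard deviation

# Same computation as in the encoder: scale standard normal noise by the
# standard deviation exp(sd) and shift it by the mean.
epsilon = rng.standard_normal((n_samples, n_latent))
z = mn + epsilon * np.exp(sd)

print(z.mean(), z.std())
```

Because the randomness lives entirely in `epsilon`, the values of _z_ depend on _mn_ and _sd_ through ordinary arithmetic, which is what lets gradients reach the encoder weights.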

## Defining the decoder
The decoder does not care about whether the input values are sampled from some specific distribution that has been defined by us. It simply will try to reconstruct the input images. To this end, we use a series of transpose convolutions.

In [5]:
def decoder(sampled_z, keep_prob):
    with tf.compat.v1.variable_scope("decoder", reuse=None):
        x = tf.compat.v1.layers.dense(sampled_z, units=inputs_decoder, activation=lrelu)
        # This layer must produce 7 * 7 * dec_in_channels values per image for the reshape below.
        x = tf.compat.v1.layers.dense(x, units=inputs_decoder * 2 + 1, activation=lrelu)
        x = tf.reshape(x, reshaped_dim)
        x = tf.compat.v1.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=2, padding='same', activation=tf.nn.relu)
        # Pass the drop rate by keyword: the second positional argument of
        # tf.nn.dropout means keep_prob in TF 1.x but rate in TF 2.x.
        x = tf.nn.dropout(x, rate=1.0 - keep_prob)
        x = tf.compat.v1.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=1, padding='same', activation=tf.nn.relu)
        x = tf.nn.dropout(x, rate=1.0 - keep_prob)
        x = tf.compat.v1.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=1, padding='same', activation=tf.nn.relu)

        x = tf.compat.v1.layers.flatten(x)
        # The sigmoid keeps every output pixel in [0, 1].
        x = tf.compat.v1.layers.dense(x, units=28*28, activation=tf.nn.sigmoid)
        img = tf.reshape(x, shape=[-1, 28, 28])
        return img
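
It can help to trace the spatial sizes through both networks by hand. The sketch below does this in plain Python, using the usual rule for `padding='same'`: a convolution yields ceil(size / stride), and a transposed convolution yields size * stride. The helper names are made up for this illustration.

```python
import math


def conv_same_out(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)


def conv_transpose_same_out(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride


# Encoder: three convolutions with strides 2, 2, 1 on 28 x 28 inputs.
size = 28
enc_sizes = [size]
for stride in (2, 2, 1):
    size = conv_same_out(size, stride)
    enc_sizes.append(size)

# Decoder: dense layers sized so the second one can be reshaped to 7 x 7 x 1.
dec_in_channels = 1
inputs_decoder = 49 * dec_in_channels // 2  # 24
dense_units = inputs_decoder * 2 + 1        # 49 = 7 * 7 * dec_in_channels

# Decoder: three transposed convolutions with strides 2, 1, 1 on 7 x 7 inputs.
size = 7
dec_sizes = [size]
for stride in (2, 1, 1):
    size = conv_transpose_same_out(size, stride)
    dec_sizes.append(size)

# Number of values per image that reach the final dense layer (64 filters).
flat_units = dec_sizes[-1] * dec_sizes[-1] * 64

print(enc_sizes, dense_units, dec_sizes, flat_units)
```

The transposed convolutions only bring the image up to 14 x 14, so it is the final dense layer with 28 * 28 units that restores the full resolution.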

Now, we'll wire together both parts:

In [6]:
sampled, mn, sd = encoder(X_in, keep_prob)
dec = decoder(sampled, keep_prob)



## Computing losses and enforcing a Gaussian latent distribution
To compute the image reconstruction loss, we simply use the squared difference (which can lead to images that sometimes look a bit fuzzy). This loss is combined with the _Kullback-Leibler divergence_ between the encoder's Gaussian and a standard normal distribution, which pushes our latent values toward a standard normal distribution. For more on this topic, please take a look at Jaan Altosaar's great article on VAEs.

In [7]:
unreshaped = tf.reshape(dec, [-1, 28*28])
img_loss = tf.reduce_sum(tf.compat.v1.squared_difference(unreshaped, Y_flat), 1)
latent_loss = -0.5 * tf.reduce_sum(1.0 + 2.0 * sd - tf.square(mn) - tf.exp(2.0 * sd), 1)
loss = tf.reduce_mean(img_loss + latent_loss)
optimizer = tf.compat.v1.train.AdamOptimizer(0.0005).minimize(loss)
sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())
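
The expression used for _latent_loss_ is the closed-form Kullback-Leibler divergence between a diagonal Gaussian with mean _mn_ and standard deviation exp(_sd_) and a standard normal distribution. The following NumPy sketch evaluates the same formula on a few hand-picked values; the function name `kl_to_standard_normal` is made up for this illustration.

```python
import numpy as np


def kl_to_standard_normal(mn, sd):
    """KL divergence from N(mn, exp(sd)^2) to N(0, 1), summed over latent dims.

    sd is the log standard deviation, matching the convention of the encoder.
    """
    return -0.5 * np.sum(1.0 + 2.0 * sd - np.square(mn) - np.exp(2.0 * sd), axis=1)


mn = np.array([[0.0, 0.0],    # already standard normal
               [1.0, -1.0]])  # means shifted away from zero
sd = np.zeros((2, 2))         # log std of 0 means a std of 1

print(kl_to_standard_normal(mn, sd))
```

The first row gives a divergence of zero because it already matches the target distribution, while the second row is penalized for its shifted means. This penalty is what pulls the encoder outputs toward a standard normal distribution during training.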

## Training the network
Now, we can finally train our VAE! Every 200 steps, we'll take a look at what the current reconstructions look like. After having processed about 2000 batches, most reconstructions will look reasonable.

In [8]:
for i in range(30000):
    # Draw a random batch and rescale the uint8 pixels to [0, 1], the range of the sigmoid output.
    batch = x_train[np.random.randint(x_train.shape[0], size=batch_size)].astype(np.float32) / 255.0
    sess.run(optimizer, feed_dict = {X_in: batch, Y: batch, keep_prob: 0.8})
        
    if not i % 200:
        ls, d, i_ls, d_ls, mu, sigm = sess.run([loss, dec, img_loss, latent_loss, mn, sd], feed_dict = {X_in: batch, Y: batch, keep_prob: 1.0})
        plt.imshow(np.reshape(batch[0], [28, 28]), cmap='gray')
        plt.show()
        plt.imshow(d[0], cmap='gray')
        plt.show()
        print(i, ls, np.mean(i_ls), np.mean(d_ls))

Note that this loop only runs if _inputs_decoder_ is an integer. With true division (`49 * dec_in_channels / 2`, which is 24.5 in Python 3), the second dense layer of the decoder ends up with 50 units instead of 49, and the first training step fails with:

InvalidArgumentError: Input to reshape is a tensor with 3200 values, but the requested shape requires a multiple of 49
	 [[node decoder/Reshape]]

The 3200 values are 64 images times 50 units, whereas the reshape to 7 * 7 * 1 in the decoder needs exactly 49 values per image.

## Generating new data
The most awesome part is that we are now able to create new characters. To this end, we simply sample values from a unit normal distribution and feed them to our decoder. Most of the created characters look just like they've been written by humans.  

In [None]:
randoms = [np.random.normal(0, 1, n_latent) for _ in range(10)]
imgs = sess.run(dec, feed_dict = {sampled: randoms, keep_prob: 1.0})
imgs = [np.reshape(imgs[i], [28, 28]) for i in range(len(imgs))]

for img in imgs:
    plt.figure(figsize=(1,1))
    plt.axis('off')
    plt.imshow(img, cmap='gray')
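
Instead of one small figure per character, the generated images can also be arranged into a single array and shown with one `plt.imshow` call. The NumPy sketch below shows one way to do this under a few assumptions: `tile_images` is a made-up helper, and random arrays stand in for the decoder output.

```python
import numpy as np


def tile_images(imgs, rows, cols):
    """Arrange rows * cols images of shape (h, w) into one (rows * h, cols * w) array."""
    imgs = np.asarray(imgs)
    n, h, w = imgs.shape
    if n != rows * cols:
        raise ValueError("expected rows * cols images, got %d" % n)
    # (rows, cols, h, w) -> (rows, h, cols, w) -> (rows * h, cols * w)
    return imgs.reshape(rows, cols, h, w).transpose(0, 2, 1, 3).reshape(rows * h, cols * w)


# Random stand-ins for ten generated 28 x 28 images.
rng = np.random.default_rng(0)
fake_imgs = rng.random((10, 28, 28))

grid = tile_images(fake_imgs, rows=2, cols=5)
print(grid.shape)
```

With the real output, `tile_images(imgs, rows=2, cols=5)` could be passed to `plt.imshow(..., cmap='gray')` to view all ten characters side by side.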

## Conclusion
Now, this obviously is a relatively simple example of an application of VAEs. But just think about what could be possible! Neural networks could learn to compose music. They could automatically create illustrations for books, games, and more. With a bit of creativity, VAEs will open up space for some awesome projects.
