This is a re-implementation of the "vanilla_autoencoder" using a CNN architecture.

Follows the blog: http://machinelearninguru.com/deep_learning/tensorflow/neural_networks/autoencoder/autoencoder.html
(and its [github-page](https://github.com/Machinelearninguru/Deep_Learning/blob/master/TensorFlow/neural_networks/autoencoder/simple_autoencoder.py))

In [9]:
%matplotlib inline

from IPython.display import display
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from skimage import transform

In [31]:
def encoder(observable, name='encoder', reuse=None):
  """Convolutional encoder mapping images to a small latent feature map.

  Three strided 5x5 convolutions (no activation) are stacked inside the
  variable scope `name`.

  Args:
    observable: Tensor of the shape `[batch_size, 32, 32, 1]`.
    name: Name of the variable scope holding the layer variables.
    reuse: Forwarded to `tf.variable_scope` as its `reuse` argument.

  Returns:
    Tensor of the shape `[batch_size, 2, 2, 8]`.
  """
  # (filters, stride) of each layer; spatial size goes
  # 32 x 32 x 1 -> 16 x 16 x 32 -> 8 x 8 x 16 -> 2 x 2 x 8.
  layer_specs = [(32, 2), (16, 2), (8, 4)]
  with tf.variable_scope(name, reuse=reuse):
    features = observable
    for n_filters, stride in layer_specs:
      features = tf.layers.conv2d(features, n_filters, [5, 5],
                                  strides=stride, padding='SAME')
    return features

In [28]:
def decoder(latent, name='decoder', reuse=None):
  """Transposed-convolutional decoder mapping latents back to images.

  Three strided 5x5 transposed convolutions are stacked inside the variable
  scope `name`; only the last one has an activation (`tanh`).

  Args:
    latent: Tensor of the shape `[batch_size, 2, 2, 8]`.
    name: Name of the variable scope holding the layer variables.
    reuse: Forwarded to `tf.variable_scope` as its `reuse` argument.

  Returns:
    Tensor of the shape `[batch_size, 32, 32, 1]`.
  """
  # (filters, stride, activation) of each layer; spatial size goes
  # 2 x 2 x 8 -> 8 x 8 x 16 -> 16 x 16 x 32 -> 32 x 32 x 1.
  layer_specs = [(16, 4, None), (32, 2, None), (1, 2, tf.nn.tanh)]
  with tf.variable_scope(name, reuse=reuse):
    features = latent
    for n_filters, stride, activation in layer_specs:
      features = tf.layers.conv2d_transpose(features, n_filters, [5, 5],
                                            strides=stride, padding='SAME',
                                            activation=activation)
    return features

In [7]:
def resize_batch(images):
    """Upscale a batch of 28x28 MNIST images to 32x32.

    Args:
        images: Numpy array that can be reshaped to
            `[batch_size, 28, 28, 1]`, e.g. `[batch_size, 28 * 28]`.

    Returns:
        Numpy array of the shape `[batch_size, 32, 32, 1]`.
    """
    batch = images.reshape((-1, 28, 28, 1))
    n_images = batch.shape[0]
    # Pre-allocate the output so an empty batch yields an empty result.
    resized = np.zeros((n_images, 32, 32, 1))
    for idx, image in enumerate(batch):
        # Resize the single channel as a 2-D image and store it back.
        resized[idx, :, :, 0] = transform.resize(image[:, :, 0], (32, 32))
    return resized

In [33]:
def get_loss(observable, encoder, decoder, regularizer=None, reuse=None):
  """Builds the autoencoder loss: reconstruction error plus a latent penalty.

  Args:
    observable: Tensor of input images, `[batch_size] + observable_dims`.
    encoder: Callable `encoder(observable, reuse=...)` returning the latent.
    decoder: Callable `decoder(latent, reuse=...)` returning the
      reconstruction.
    regularizer: Optional callable mapping the latent tensor to a scalar
      penalty. When `None`, a constant `0.0` is added instead.
    reuse: Forwarded to both `encoder` and `decoder`.

  Returns:
    Scalar tensor: the batch mean of the per-example summed squared error,
    plus the regularization term.
  """
  with tf.name_scope('loss'):
    # shape: [batch_size] + latent_dims
    code = encoder(observable, reuse=reuse)
    # shape: [batch_size] + observable_dims
    reconstruction = decoder(code, reuse=reuse)
    residual = reconstruction - observable
    # shape: [batch_size]; squared error summed over all non-batch axes.
    per_example_error = tf.reduce_sum(tf.layers.flatten(residual ** 2),
                                      axis=1)
    reconstruction_loss = tf.reduce_mean(per_example_error)
    penalty = 0.0 if regularizer is None else regularizer(code)
    return reconstruction_loss + penalty

In [29]:
# Placeholder for the input images fed to the autoencoder.
observable = tf.placeholder(tf.float32, [None, 32, 32, 1], name='observable')
# Placeholder for latent codes supplied directly to the decoder.
latent_samples = tf.placeholder(tf.float32, [None, 2, 2, 8],
                                name='latent_samples')
# Decoder output for externally supplied latents; AUTO_REUSE lets this call
# share the 'decoder' variables with any other call to `decoder`.
generated = decoder(latent_samples, reuse=tf.AUTO_REUSE)

In [10]:
# Directory holding the MNIST archives (relative to the notebook).
data_path = '../../dat/MNIST/'
mnist = input_data.read_data_sets(data_path,
                                  source_url='http://yann.lecun.com/exdb/mnist/',
                                  one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../dat/MNIST/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../dat/MNIST/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../../dat/MNIST/t10k-images-idx3-ubyte.gz
Extracting ../../dat/MNIST/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [19]:
def get_X_batch(batch_size, mnist=mnist, source='train'):
    """Returns the next X-batch from the chosen MNIST split.

    Args:
        batch_size: Number of images to draw.
        mnist: The MNIST datasets object exposing `.train` and `.test`.
        source: Which split to draw from, either "train" or "test".

    Returns:
        Numpy array of the shape `[batch_size, 32, 32, 1]`.

    Raises:
        ValueError: If `source` is neither "train" nor "test".
    """
    if source == 'train':
        dataset = mnist.train
    elif source == 'test':
        dataset = mnist.test
    else:
        raise ValueError('Argument `source` can either be "train" or '
                         '"test". But now it is {}.'.format(source))

    # Draw from the *selected* split (previously always `mnist.train`, so
    # `source='test'` silently returned training data). Labels are unused.
    X_batch, _ = dataset.next_batch(batch_size)
    # `resize_batch` reshapes to [-1, 28, 28, 1] itself before resizing.
    return resize_batch(X_batch)

In [34]:
def regularizer(latent, name='regularizer'):
  """Mean squared L2 norm of the latent codes.

  Args:
    latent: Latent tensor whose first axis is the batch axis.
    name: Name scope under which the ops are created.

  Returns:
    Scalar tensor: per-example sum of squared latent entries, averaged
    over the batch.
  """
  with tf.name_scope(name):
    flat_latent = tf.layers.flatten(latent)
    # shape: [batch_size]
    squared_norms = tf.reduce_sum(flat_latent ** 2, axis=1)
    penalty = tf.reduce_mean(squared_norms)
  return penalty

In [35]:
# Training objective: reconstruction error plus the latent-norm penalty.
# AUTO_REUSE shares the encoder/decoder variables with any earlier calls.
loss = get_loss(observable=observable,
                encoder=encoder,
                decoder=decoder,
                regularizer=regularizer,
                reuse=tf.AUTO_REUSE)

In [36]:
# Adam optimizer with epsilon set explicitly to 1e-3.
optimizer = tf.train.AdamOptimizer(epsilon=1e-3)
# Single op that applies one gradient step on `loss`.
train_op = optimizer.minimize(loss=loss)

In [37]:
# Open a session and initialize every variable created so far.
sess = tf.Session()
sess.run(fetches=tf.global_variables_initializer())

In [39]:
batch_size = 128
n_steps = 100000  # number of optimization steps

# Loss value recorded at every step, for the summary and the plot below.
loss_vals = []
for step in tqdm(range(n_steps)):
  X_batch = get_X_batch(batch_size)
  _, loss_val = sess.run([train_op, loss], {observable: X_batch})
  # Check the value fetched in *this* step. The previous check read an
  # undefined name (`loss_Xy_val`) and failed on a fresh kernel.
  if np.isnan(loss_val):
    raise ValueError('Loss has been NaN.')
  loss_vals.append(loss_val)

# Average over the last 100 steps to smooth out per-batch noise.
print('Final loss:', np.mean(loss_vals[-100:]))

plt.plot(loss_vals)
plt.xlabel('steps')
plt.ylabel('loss')
plt.show()

  0%|          | 345/100000 [01:19<6:21:12,  4.36it/s]

KeyboardInterrupt: 

In [None]:
def get_image(array):
  """Converts a single-channel float array into a grayscale image.

  Args:
    array: Numpy array with shape `[32, 32, 1]`. Intensities are treated as
      lying in `[0, 1]`; values outside that range are clipped.

  Returns:
    An image (`PIL.Image`) built from the 8-bit pixel values.
  """
  # Clip before scaling: casting negative or >255 floats to uint8 wraps
  # around and produces speckled artifacts (the decoder ends in tanh, so
  # negative outputs are possible).
  array = np.clip(array, 0.0, 1.0)
  array = 255 * array
  # Drop the trailing channel axis to get a 2-D array.
  array = np.squeeze(array, axis=-1)
  array = array.astype(np.uint8)
  return Image.fromarray(array)

In [None]:
# Draw latent codes from a standard normal and decode them.
latent_sample_vals = np.random.normal(size=[128, 2, 2, 8])
generated_vals = sess.run(generated, {latent_samples: latent_sample_vals})

# Display the results
n_display = 5
for generated_val in generated_vals[:n_display]:
  print('Generated:')
  display(get_image(generated_val))
  print()