In [1]:
import os
import sys

import matplotlib
# matplotlib.use('TKAgg')
from matplotlib import pyplot as plt
import numpy as np
import tensorflow
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
# --- Data / batching configuration ---
mnist_image_shape = [28, 28, 1]  # height, width, channels of a single image
batch_size = 100
# Shape of one mini-batch: the batch dimension followed by the image shape.
batch_shape = tuple([batch_size] + mnist_image_shape)
num_visualize = 10  # how many (original, reconstruction) pairs get saved

# --- Optimisation hyper-parameters ---
lr = 0.01  # learning rate handed to the optimizer
num_epochs = 50

In [2]:
def get_deconv2d_output_dims(input_dims, filter_dims, stride_dims, padding):
    """Return the output shape of a 2-D deconvolution (transposed conv) layer.

    Args:
        input_dims: [batch_size, height, width, num_channels_in] of the input.
        filter_dims: [filter_height, filter_width, num_channels_out].
        stride_dims: [stride_height, stride_width].
        padding: either 'SAME' or 'VALID'.

    Returns:
        A list [batch_size, out_height, out_width, num_channels_out].

    Raises:
        ValueError: if `padding` is neither 'SAME' nor 'VALID'.
    """
    # The input channel count does not influence the output shape.
    batch_size, input_h, input_w, _ = input_dims
    filter_h, filter_w, num_channels_out = filter_dims
    stride_h, stride_w = stride_dims

    if padding == 'SAME':
        out_h = input_h * stride_h
        out_w = input_w * stride_w
    elif padding == 'VALID':
        out_h = (input_h - 1) * stride_h + filter_h
        out_w = (input_w - 1) * stride_w + filter_w
    else:
        # Fail loudly instead of hitting an unbound `out_h` further down.
        raise ValueError(
            "padding must be 'SAME' or 'VALID', got %r" % (padding,))

    return [batch_size, out_h, out_w, num_channels_out]

In [None]:
def load_dataset():
    """Read the MNIST data sets through `input_data`, using 'MNIST_data'."""
    data_dir = 'MNIST_data'
    mnist = input_data.read_data_sets(data_dir)
    return mnist

def get_next_batch(dataset, batch_size):
    """Fetch the next mini-batch from `dataset`, reshaped to image form.

    `dataset` should be one split of the MNIST data, i.e.
    mnist.(train/val/test). Only the first element returned by
    `dataset.next_batch` is used; the second is discarded.

    Returns:
        An array of shape [batch_size] + mnist_image_shape.
    """
    images, _unused = dataset.next_batch(batch_size)
    target_shape = [batch_size] + mnist_image_shape
    return np.reshape(images, target_shape)

def visualize(_original, _reconstructions, num_visualize):
    """Save side-by-side images of inputs and their reconstructions.

    The first `num_visualize` entries of `_original` and `_reconstructions`
    are paired up; each pair is drawn as a 1x2 grayscale figure (original on
    the left, reconstruction on the right) and written to
    './vis/test_<k>.png', with k starting at 1. The output folder is created
    when it does not exist.

    Args:
        _original: sequence of input images.
        _reconstructions: sequence of reconstructed images, in the same order.
        num_visualize: maximum number of pairs to save.
    """
    vis_folder = './vis/'
    if not os.path.exists(vis_folder):
        os.makedirs(vis_folder)

    # Each image is reshaped to 2-D (height, width) before plotting.
    image_hw = (mnist_image_shape[0], mnist_image_shape[1])
    pairs = zip(_original[:num_visualize], _reconstructions[:num_visualize])

    for count, (orig, rec) in enumerate(pairs, 1):
        fig, axes = plt.subplots(1, 2)
        axes[0].imshow(np.reshape(orig, image_hw), cmap='gray')
        axes[1].imshow(np.reshape(rec, image_hw), cmap='gray')
        fig.savefig(os.path.join(vis_folder, "test_%d.png" % count))
        # Close the figure once it is on disk; otherwise every figure created
        # in this loop stays registered with pyplot and keeps its memory.
        plt.close(fig)

In [3]:
def conv(input, name, filter_dims, stride_dims, padding='SAME',
         non_linear_fn=tf.nn.relu):
    """Build a 2-D convolution layer and return its output tensor.

    Args:
        input: 4-D tensor [batch_size, height, width, num_channels_in].
        name: variable scope name for this layer's weights and biases.
        filter_dims: [filter_height, filter_width, num_channels_out].
        stride_dims: [stride_height, stride_width].
        padding: 'SAME' or 'VALID', forwarded to `tf.nn.conv2d`.
        non_linear_fn: activation applied to the biased output; pass None
            for a purely linear layer.

    Returns:
        The layer output tensor (after bias and optional non-linearity).
    """
    input_dims = input.get_shape().as_list()
    assert(len(input_dims) == 4) # batch_size, height, width, num_channels_in
    assert(len(filter_dims) == 3) # height, width and num_channels out
    assert(len(stride_dims) == 2) # stride height and width

    num_channels_in = input_dims[-1]
    filter_h, filter_w, num_channels_out = filter_dims
    stride_h, stride_w = stride_dims

    # Define a variable scope for the conv layer
    with tf.variable_scope(name):
        # Create filter weight variable
        weights = tf.get_variable(
            'weights',
            [filter_h, filter_w, num_channels_in, num_channels_out],
            initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Create bias variable
        biases = tf.get_variable(
            'biases',
            [num_channels_out],
            initializer=tf.constant_initializer(0.0))

        # Define the convolution flow graph
        # (strides are given per NHWC dimension; batch and channel stay 1)
        conv_out = tf.nn.conv2d(input, weights,
                                strides=[1, stride_h, stride_w, 1],
                                padding=padding)

        # Add bias to conv output
        output = tf.nn.bias_add(conv_out, biases)

        # Apply non-linearity (if asked) and return output
        if non_linear_fn is not None:
            output = non_linear_fn(output)
        return output

def deconv(input, name, filter_dims, stride_dims, padding='SAME',
           non_linear_fn=tf.nn.relu):
    """Build a 2-D deconvolution (transposed conv) layer and return its output.

    Args:
        input: 4-D tensor [batch_size, height, width, num_channels_in].
        name: variable scope name for this layer's weights and biases.
        filter_dims: [filter_height, filter_width, num_channels_out].
        stride_dims: [stride_height, stride_width].
        padding: 'SAME' or 'VALID'.
        non_linear_fn: activation applied to the biased output; pass None
            for a purely linear layer.

    Returns:
        The layer output tensor, whose shape is the one computed by
        `get_deconv2d_output_dims`.
    """
    input_dims = input.get_shape().as_list()
    assert(len(input_dims) == 4) # batch_size, height, width, num_channels_in
    assert(len(filter_dims) == 3) # height, width and num_channels out
    assert(len(stride_dims) == 2) # stride height and width

    num_channels_in = input_dims[-1]
    filter_h, filter_w, num_channels_out = filter_dims
    stride_h, stride_w = stride_dims
    # The transposed convolution needs its output shape spelled out
    # explicitly, so compute it from the static input shape.
    output_dims = get_deconv2d_output_dims(input_dims,
                                           filter_dims,
                                           stride_dims,
                                           padding)

    # Define a variable scope for the deconv layer
    with tf.variable_scope(name):
        # Create filter weight variable
        # Note that num_channels_out and in positions are flipped for deconv.
        weights = tf.get_variable(
            'weights',
            [filter_h, filter_w, num_channels_out, num_channels_in],
            initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Create bias variable
        biases = tf.get_variable(
            'biases',
            [num_channels_out],
            initializer=tf.constant_initializer(0.0))

        # Define the deconv flow graph
        deconv_out = tf.nn.conv2d_transpose(input, weights, output_dims,
                                            strides=[1, stride_h, stride_w, 1],
                                            padding=padding)

        # Add bias to deconv output
        output = tf.nn.bias_add(deconv_out, biases)

        # Apply non-linearity (if asked) and return output
        if non_linear_fn is not None:
            output = non_linear_fn(output)
        return output

def max_pool(input, name, filter_dims, stride_dims, padding='SAME'):
    """Build a 2-D max-pooling op and return its output tensor.

    Args:
        input: tensor to pool; passed straight to `tf.nn.max_pool`.
        name: name given to the pooling op.
        filter_dims: [window_height, window_width].
        stride_dims: [stride_height, stride_width].
        padding: 'SAME' or 'VALID', forwarded to `tf.nn.max_pool`.

    Returns:
        The pooled tensor.
    """
    assert(len(filter_dims) == 2) # filter height and width
    assert(len(stride_dims) == 2) # stride height and width

    filter_h, filter_w = filter_dims
    stride_h, stride_w = stride_dims

    # Define the max pool flow graph and return output
    # (window and strides only act on the two middle dimensions)
    return tf.nn.max_pool(input,
                          ksize=[1, filter_h, filter_w, 1],
                          strides=[1, stride_h, stride_w, 1],
                          padding=padding,
                          name=name)

def fc(input, name, out_dim, non_linear_fn=tf.nn.relu):
    """Build a fully connected layer and return its output tensor.

    Args:
        input: either a 4-D tensor [batch_size, height, width, num_channels]
            (flattened here) or a tensor whose last dimension is the feature
            dimension.
        name: variable scope name for this layer's weights and biases.
        out_dim: number of output units (int).
        non_linear_fn: activation applied to the biased output; pass None
            for a purely linear layer.

    Returns:
        The layer output tensor with `out_dim` features per example.
    """
    assert(isinstance(out_dim, int))

    # Define a variable scope for the FC layer
    with tf.variable_scope(name):
        input_dims = input.get_shape().as_list()
        # the input to the fc layer should be flattened
        if len(input_dims) == 4:
            # for eg. the output of a conv layer
            _, input_h, input_w, num_channels = input_dims
            # ignore the batch dimension
            in_dim = input_h * input_w * num_channels
            # -1 lets the batch dimension be inferred, so this also works
            # when the static batch size is unknown.
            flat_input = tf.reshape(input, [-1, in_dim])
        else:
            in_dim = input_dims[-1]
            flat_input = input

        # Create weight variable
        weights = tf.get_variable(
            'weights',
            [in_dim, out_dim],
            initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Create bias variable
        biases = tf.get_variable(
            'biases',
            [out_dim],
            initializer=tf.constant_initializer(0.0))

        # Define FC flow graph
        output = tf.nn.bias_add(tf.matmul(flat_input, weights), biases)

        # Apply non-linearity (if asked) and return output
        if non_linear_fn is not None:
            output = non_linear_fn(output)
        return output

In [None]:
def calculate_loss(original, reconstructed):
    """Return the sum of squared reconstruction errors divided by batch_size.

    Args:
        original: tensor holding the input images.
        reconstructed: tensor holding the reconstructions, same shape.

    Returns:
        A scalar tensor: sum((reconstructed - original)^2) / batch_size,
        where `batch_size` is the module-level constant.
    """
    # Tensor operators are used instead of tf.sub / tf.div because tf.sub was
    # removed in TensorFlow 1.0; the operators work across versions.
    squared_error = tf.square(reconstructed - original)
    return tf.reduce_sum(squared_error) / float(batch_size)

def train(dataset):
    """Train the autoencoder on `dataset.train` and save sample reconstructions.

    Builds the graph, minimises the reconstruction loss with gradient descent
    for `num_epochs` passes over the training images, printing the loss every
    1000 steps. Afterwards one batch from `dataset.test` is reconstructed and
    passed to `visualize`.

    Args:
        dataset: object exposing `train` and `test` splits, as returned by
            `load_dataset`.
    """
    # NOTE(review): `autoencoder` is not defined in the visible cells; it is
    # assumed to return (input placeholder, reconstruction tensor) - confirm.
    input_image, reconstructed_image = autoencoder(batch_shape)
    loss = calculate_loss(input_image, reconstructed_image)
    train_op = tf.train.GradientDescentOptimizer(lr).minimize(loss)

    init = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init)

        dataset_size = len(dataset.train.images)
        print("Dataset size: %d" % dataset_size)
        # Floor division keeps the iteration count an int on Python 3 too.
        num_iters = (num_epochs * dataset_size) // batch_size
        print("Num iters: %d" % num_iters)
        for step in range(num_iters):
            input_batch = get_next_batch(dataset.train, batch_size)
            loss_val, _ = session.run([loss, train_op],
                                      feed_dict={input_image: input_batch})
            if step % 1000 == 0:
                print("Loss at step %d : %s" % (step, loss_val))

        # Reconstruct one held-out batch and write the comparison images.
        test_batch = get_next_batch(dataset.test, batch_size)
        reconstruction = session.run(reconstructed_image,
                                     feed_dict={input_image: test_batch})
        visualize(test_batch, reconstruction, num_visualize)

# Entry point: load the MNIST data sets and run training end to end.
if __name__ == '__main__':
    dataset = load_dataset()
    train(dataset)