# DC-GAN on MNIST dataset
- [Github](https://github.com/zsdonghao/dcgan) repo
- [Blog post](http://bamos.github.io/2016/08/09/deep-completion/) tutorial
- [Paper](https://arxiv.org/pdf/1511.06434.pdf)

## Architecture guidelines for stable Deep Convolutional GANs
- Replace any pooling layers with strided convolutions (discriminator) and fractional-strided
convolutions (generator).
- Use batchnorm in both the generator and the discriminator.
- Remove fully connected hidden layers for deeper architectures.
- Use ReLU activation in generator for all layers except for the output, which uses Tanh.
- Use LeakyReLU activation in the discriminator for all layers.

### Notes
The inverse of convolution is transposed convolution.  It can be [shown](https://arxiv.org/abs/1603.07285) that transpose convolutions can be computed using the same kernal as the non-transposed convolution by manipulating the input space (i.e. adding or subtracting zero-padding).  

More specific to this application, the GAN architecture is trying to pretend the noise input is the high-level feature representations and wants to reverse a normal CNN architecture by going backwards through all the convolutions.  Normal CNNs move a kernal across an input to reduce dimensionality and we want our GAN to do the opposite:  to take high-level features and increase dimensionality to recover the image.  This is acheived by "fractional striding".  

A stride of 1 means the convolution moves one-pixel at a time.  A stride of 2 means the kernal skips a pixel and makes jumps of 2 at a time.  A fractional stride adds zeros between the pixels to effectively go slower than a stride of 1.

![title](img/gan-architecture.png)

In [4]:
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline

import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer, batch_norm, flatten, fully_connected
from tensorflow.examples.tutorials.mnist import input_data

In [5]:
N_input = 100

In [6]:
def lrelu(x, leak=0.2, name="lrelu"):
    with tf.variable_scope(name):
        f1 = 0.5 * (1 + leak)
        f2 = 0.5 * (1 - leak)
        return f1 * x + f2 * abs(x)

In [7]:
def generator(z):
    # Generator Net
    # Input noise
    with tf.variable_scope("g_conv1") as scope:
        # Top level Conv Filters
        # Use matrix multiplication and reshape into stack of 7x7x512 filters
        G_W1 = tf.get_variable("G_W1", shape=[N_input, 7*7*512], initializer=xavier_initializer())
        G_b1 = tf.Variable(tf.zeros(shape=[7*7*512]), name='G_b1')
        z_ = tf.matmul(z, G_W1) + G_b1
        h0 = tf.reshape(z_, [-1, 7, 7, 512]) # 1st dimension is number of samples
        a0 = batch_norm(h0, activation_fn=tf.nn.relu)
    
    with tf.variable_scope("g_conv2") as scope:
        # Transpose-conv 1 (14x14x256)
        t_conv1 = tf.layers.conv2d_transpose(
            a0,
            filters=256,
            kernel_size=(5,5),
            strides=(2, 2),
            padding='same',
        )
        t_conv1_a = batch_norm(t_conv1, activation_fn=tf.nn.relu)

    with tf.variable_scope("g_conv3") as scope:
        # Transpose-conv 2 (28x28x128)
        t_conv2 = tf.layers.conv2d_transpose(
            t_conv1_a,
            filters=128,
            kernel_size=(5,5),
            strides=(2, 2),
            padding='same',
        )
        t_conv2_a = batch_norm(t_conv2, activation_fn=tf.nn.relu)

    with tf.variable_scope("g_out") as scope:
        # Transpose-conv 4 (28x28x1)
        g_out = tf.layers.conv2d_transpose(
            t_conv2_a,
            filters=1,
            kernel_size=(5,5),
            strides=(1, 1),
            padding='same',
            activation=tf.nn.tanh,
        )
    
    return g_out

In [8]:
def discriminator(x, reuse=False):
    # Discriminator Net
    # Input noise
    with tf.variable_scope("d_conv1") as scope:
        # Conv 1 (28x28x128)
        conv1 = tf.layers.conv2d(
            x,
            filters=128,
            kernel_size=(5,5),
            strides=(1,1),
            padding='same',
            reuse=reuse,
        )
        conv1_a = batch_norm(conv1, activation_fn=lrelu, reuse=reuse, scope=scope)
        
    with tf.variable_scope("d_conv2") as scope:
        # Conv 2 (14x14x256)
        conv2 = tf.layers.conv2d(
            conv1_a,
            filters=256,
            kernel_size=(5,5),
            strides=(2,2),
            padding='same',
            reuse=reuse,
        )
        conv2_a = batch_norm(conv2, activation_fn=lrelu, reuse=reuse, scope=scope)
    
    with tf.variable_scope("d_conv3") as scope:
        # Conv 3 (7x7x512)
        conv3 = tf.layers.conv2d(
            conv2_a,
            filters=512,
            kernel_size=(5,5),
            strides=(2,2),
            padding='same',
            reuse=reuse,
            #scope=scope
        )
        conv3_a = batch_norm(conv3, activation_fn=lrelu, reuse=reuse, scope=scope)
    
    with tf.variable_scope("d_out") as scope:    
        # Output (1)
        d_out = tf.nn.sigmoid(flatten(conv3_a))
    
    return d_out

In [9]:
tf.reset_default_graph()

Z = tf.placeholder(tf.float32, shape=[None, N_input], name='Z')
X = tf.placeholder(tf.float32, shape=[None,28,28,1], name='X')

Gz = generator(Z) #Generates images from random z vectors
Dx = discriminator(X)
Dg = discriminator(Gz, reuse=True)


#These functions together define the optimization objective of the GAN.
d_loss = -tf.reduce_mean(tf.log(Dx) + tf.log(1.-Dg)) #This optimizes the discriminator.
g_loss = -tf.reduce_mean(tf.log(Dg)) #This optimizes the generator.

tvars = tf.trainable_variables() # First 15 are for generator
theta_G = tvars[:12]
theta_D = tvars[12:]

#The below code is responsible for applying gradient descent to update the GAN.
# Only update D(X)'s parameters, so var_list = theta_D
D_solver = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5).minimize(d_loss, var_list=theta_D)
# Only update G(X)'s parameters, so var_list = theta_G
G_solver = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5).minimize(g_loss, var_list=theta_G)

In [10]:
mb_size = 128

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


In [11]:
dir_name = 'out_v1/'
if not os.path.exists(dir_name):
    os.makedirs(dir_name)

In [12]:
model_dir_name = 'model_v1'
if not os.path.exists(model_dir_name):
    os.makedirs(model_dir_name)

In [13]:
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

In [14]:
def sample_Z(m, n):
    '''Uniform prior for G(Z)'''
    return np.random.uniform(-1., 1., size=[m, n])

In [None]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
i = 0
D_iter = 1 # Number of D iters per loop
G_iter = 1 # Number of G iters per loop

saver = tf.train.Saver()

for it in range(int(1e4)):
    if it % 1000 == 0:
        samples = sess.run(Gz, feed_dict={Z: sample_Z(16, N_input)})

        fig = plot(samples)
        plt.savefig(dir_name + '/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    
    for j in range(D_iter):
        X_mb, _ = mnist.train.next_batch(mb_size)
        X_ = (X_mb - .5) * 2 # scale to tanh range [-1, 1]
        _, D_loss_curr = sess.run(
            [D_solver, d_loss], 
            feed_dict={X: X_.reshape(mb_size,28,28,1), Z: sample_Z(mb_size, N_input)}
        )
    
    for k in range(G_iter):
        _, G_loss_curr = sess.run([G_solver, g_loss], feed_dict={Z: sample_Z(mb_size, N_input)})
        
    if it % 10000 == 0 and it != 0:
            saver.save(sess,model_dir_name+'/model-'+str(i)+'.cptk')
            print("Saved Model")
    

    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Iter: 0
D loss: 1.509
G_loss: 0.5923

