# Generating new Objects w/ `Generative Adversarial Network (GAN)`

### Import dependencies

In [1]:
import os
import sys
import datetime as dt

import numpy as np
import tensorflow as tf

%matplotlib inline

In [2]:
from tensorflow.examples.tutorials.mnist import input_data

data_dir = 'datasets/MNIST/'
data = input_data.read_data_sets(data_dir, one_hot=True)

# from dataset import ImageDataset

# data_dir = 'datasets/101_ObjectCategories/'
# save_file = 'saved/data.pkl'

# data = ImageDataset(data_dir=data_dir, size=24, grayscale=True, flatten=True)
# data.create()
# data.save(save_file=save_file, force=True)
# # data = data.load(save_file=save_file)

Extracting datasets/MNIST/train-images-idx3-ubyte.gz
Extracting datasets/MNIST/train-labels-idx1-ubyte.gz
Extracting datasets/MNIST/t10k-images-idx3-ubyte.gz
Extracting datasets/MNIST/t10k-labels-idx1-ubyte.gz


### Hyperparameters

In [3]:
# Inputs
img_size = 28    # data.size
img_channel = 1  # data.channel
img_size_flat = img_size * img_size * img_channel
print(f'Images »»» Size: {img_size:,}\tChannel: {img_channel:,}\tFlattened: {img_size_flat:,}')

Images »»» Size: 28	Channel: 1	Flattened: 784


In [4]:
# Network
kernel_size = 5
n_noise     = 64
keep_prob   = 0.8

In [5]:
# Training
batch_size    = 24
learning_rate = .01
save_interval = 100
log_interval  = 1000
iterations    = 10000

### Helpers

In [6]:
import matplotlib.pyplot as plt

def visualize(imgs, name=None, smooth=False, **kwargs):
    # Plot images in grid
    grid = int(np.sqrt(len(imgs)))
    # Create figure with sub-plots.
    fig, axes = plt.subplots(grid, grid)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for i, ax in enumerate(axes.flat):
        # Interpolation type.
        interpolation = 'spline16' if smooth else 'nearest'
        shape = [img_size, img_size]
        ax.imshow(imgs[i].reshape(shape), interpolation=interpolation, **kwargs)
        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])
    if name:
        plt.suptitle(name)
    plt.show()

In [7]:
def lrelu(x, alpha=0.2):
    return tf.maximum(x, tf.multiply(x, alpha))

In [8]:
def binary_cross_entropy(x, z, eps=1e-12):
    return (-(x * tf.log(z + eps) + (1. - x) * tf.log(1. - z + eps)))

### The Discriminator

In [9]:
def discriminator(X, reuse=None):
    with tf.variable_scope('discriminator', reuse=reuse):
        net = X
        # reshape
        net = tf.reshape(net, [-1, img_size, img_size, img_channel])
        # conv + dropout
        net = tf.layers.conv2d(net, filters=64, kernel_size=5, strides=2, padding='SAME', activation=lrelu)
        net = tf.nn.dropout(net, keep_prob=keep_prob)
        # conv + dropout
        net = tf.layers.conv2d(net, filters=64, kernel_size=5, strides=1, padding='SAME', activation=lrelu)
        net = tf.nn.dropout(net, keep_prob=keep_prob)
        # conv + dropout
        net = tf.layers.conv2d(net, filters=64, kernel_size=5, strides=1, padding='SAME', activation=lrelu)
        net = tf.nn.dropout(net, keep_prob=keep_prob)
        # flatten
        net = tf.contrib.layers.flatten(net)
        # 2 fully connected layers
        net = tf.layers.dense(net, units=128, activation=lrelu)
        net = tf.layers.dense(net, units=1, activation=tf.nn.sigmoid)
        return net

### The Generator

In [10]:
def generator(noise, reuse=None, is_training=False):
    decay = 0.99
    with tf.variable_scope('generator', reuse=reuse):
        net = noise
        d1 = 4
        d2 = 1
        # fully connected + dropout + batch norm
        net = tf.layers.dense(net, units=d1*d1*d2, activation=lrelu)
        net = tf.nn.dropout(net, keep_prob=keep_prob)
        net = tf.contrib.layers.batch_norm(net, decay=decay, is_training=is_training)
        # reshape + resize
        net = tf.reshape(net, shape=[-1, d1, d1, d2])
        net = tf.image.resize_images(net, size=[7, 7])
        # conv transpose + dropout + batch_norm
        net = tf.layers.conv2d_transpose(net, filters=64, kernel_size=5, strides=2, padding='SAME', activation=lrelu)
        net = tf.nn.dropout(net, keep_prob=keep_prob)
        net = tf.contrib.layers.batch_norm(net, decay=decay, is_training=is_training)
        # conv transpose + dropout + batch_norm
        net = tf.layers.conv2d_transpose(net, filters=64, kernel_size=5, strides=2, padding='SAME', activation=lrelu)
        net = tf.nn.dropout(net, keep_prob=keep_prob)
        net = tf.contrib.layers.batch_norm(net, decay=decay, is_training=is_training)
        # conv transpose + dropout + batch_norm
        net = tf.layers.conv2d_transpose(net, filters=64, kernel_size=5, strides=1, padding='SAME', activation=lrelu)
        net = tf.nn.dropout(net, keep_prob=keep_prob)
        net = tf.contrib.layers.batch_norm(net, decay=decay, is_training=is_training)
        # conv transpose
        net = tf.layers.conv2d_transpose(net, filters=64, kernel_size=5, strides=1, padding='SAME', activation=tf.nn.sigmoid)
        return net

In [20]:
tf.reset_default_graph()

X = tf.placeholder(tf.float32, shape=[None, img_size_flat])
noise = tf.placeholder(tf.float32, shape=[None, n_noise])

In [21]:
G = generator(noise, is_training=True)
Dx = discriminator(X, reuse=None)
Dg = discriminator(G, reuse=True)
print(G)
print(Dx)
print(Dg)

Tensor("generator/conv2d_transpose_4/Sigmoid:0", shape=(?, 28, 28, 64), dtype=float32)
Tensor("discriminator/dense_2/Sigmoid:0", shape=(?, 1), dtype=float32)
Tensor("discriminator_1/dense_2/Sigmoid:0", shape=(?, 1), dtype=float32)


### Loss function

In [13]:
# Discriminator's loss (Real->rated highly, Fake->rated poorly)
loss_d_real = tf.nn.softmax_cross_entropy_with_logits(logits=Dx, labels=tf.ones_like(Dx))
loss_d_fake = tf.nn.softmax_cross_entropy_with_logits(logits=Dg, labels=tf.zeros_like(Dg))
loss_d = tf.reduce_mean(0.5 * (loss_d_real + loss_d_fake), name='loss_d')
# Generator's loss (Generator->rated highly)
loss_g = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=G, labels=tf.ones_like(G)), name='loss_g')

### Optimizer & Regularizer

In [14]:
# Trainable variables for generator & discriminator
d_vars = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]
g_vars = [var for var in tf.trainable_variables() if var.name.startswith('generator')]


# Regularizer for generator & discriminator
regularizer = tf.contrib.layers.l2_regularizer(scale=1e-6)
d_reg = tf.contrib.layers.apply_regularization(regularizer=regularizer, weights_list=d_vars)
g_reg = tf.contrib.layers.apply_regularization(regularizer=regularizer, weights_list=g_vars)


# We have to provide the update_ops to our optimizers when applying batch normalization
update_ops = tf.get_collection(key=tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(control_inputs=update_ops):
    d_global_step = tf.Variable(0, trainable=False, name='d_global_step')
    g_global_step = tf.Variable(0, trainable=False, name='g_global_step')
    # Optimizer for Discriminator
    optimizer_d = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
    optimizer_d = optimizer_d.minimize(loss_d + d_reg, global_step=d_global_step, var_list=d_vars)
    # Optimizer for Generator
    optimizer_g = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
    optimizer_g = optimizer_g.minimize(loss_g + g_reg, global_step=g_global_step, var_list=g_vars)

## Running the Computational Graph

In [15]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

### Tensorboard

In [16]:
# Tensorboard & Model's directory
tensorboard_dir = 'tensorboard/generate/gan/'
logdir = os.path.join(tensorboard_dir, 'log')
save_path = 'models/generate/gan/'
save_model = os.path.join(save_path, 'model.ckpt')

# Summary
tf.summary.scalar('loss_d_real', loss_d_real)
tf.summary.scalar('loss_d_fake', loss_d_fake)
tf.summary.scalar('loss_d', loss_d)
tf.summary.scalar('loss_g', loss_g)
gen_img = generator(noise, reuse=True, is_training=False)
tf.summary.image('gen_img', gen_img, max_outputs=6)
merged = tf.summary.merge_all()

# Saver & Writer
saver = tf.train.Saver()
writer = tf.summary.FileWriter(logdir=logdir, graph=sess.graph)

In [17]:
if tf.gfile.Exists(save_path):
    try:
        sys.stdout.write('INFO: Attempting to restore last checkpoint.\n')
        last_ckpt = tf.train.latest_checkpoint(save_path)
        saver.restore(sess=sess, save_path=last_ckpt)
        sys.stdout.write(f'INFO: Restored last checkpoint from {last_ckpt}\n')
        sys.stdout.flush()
    except Exception as e:
        sys.stderr.write(f'ERR: Could not restore checkpoint. {e}')
        sys.stderr.flush()
else:
    tf.gfile.MakeDirs(save_path)
    sys.stdout.write(f'INFO: Created checkpoint directory: {save_path}\n')
    sys.stdout.flush()

ERR: Could not restore checkpoint. Can't load save_path when it is None.

INFO: Attempting to restore last checkpoint.


### Training

In [None]:
start_time = dt.datetime.now()

for i in range(iterations):
    train_d = True
    train_g = True
    
    X_batch = data.train.next_batch(batch_size=batch_size)[0]
    n = np.random.uniform(low=0.0, high=1.0, size=[batch_size, n_noise])
    feed_dict = {X: X_batch, noise: n}
    # Run the losses
    _d_real, _d_fake, _loss_d, _loss_g, _d_global, _g_global = sess.run([loss_d_real, loss_d_fake, 
                                                                         loss_d, loss_g, 
                                                                         d_global_step, 
                                                                         g_global_step], 
                                                                        feed_dict=feed_dict)
    _d_real, _d_fake = np.mean(_d_real), np.mean(_d_fake)
    
    # Stop training discriminator
    if _loss_g * 1.5 < _loss_d:
        train_d = False
        sys.stderr.write(f'Discriminator stopped training!'
                         f'Real: {_d_real:.2f}\tFake: {_d_fake:.2f}\t'
                         f'Loss: {_loss_d:.4f}')
        sys.stderr.flush()
    # Stop training generator
    if _loss_d * 2 < _loss_g:
        train_g = False
        sys.stderr.write(f'Generator stopped training!'
                         f'Loss: {_loss_g:.4f}')
        sys.stderr.flush()
    
    # Train discriminator
    if train_d:
        sess.run(optimizer_d, feed_dict=feed_dict)
    # Train generator
    if train_g:
        sess.run(optimizer_g, feed_dict=feed_dict)
    
    # Save model & Graph summary
    if i%save_interval == 0:
        saver.save(sess=sess, save_path=save_model, global_step=g_global_step)
        summary = sess.run(merged, feed_dict=feed_dict)
        writer.add_summary(summary=summary, global_step=_g_global)
    # Log generated images @ intervals
    if i%log_interval == 0:
        randoms = np.random.uniform(low=0.0, hight=1.0, size=[9, n_noise])
        gen_imgs = generator(randoms, is_training=False)
        imgs = sess.run(gen_imgs, feed_dict={noise: randoms})
        visualize(imgs, name=f'Iteration: {i+1}', smooth=True, cmap='gray')
    sys.stdout.write(f'\rIter: {i+1:,}\tg_Global: {_g_global:,}\td_Global: {_d_global:,}\t'
                     f'Discriminator »»» Real: {_d_real:.2f}\tFake: {_d_fake:.2f}\tLoss: {_loss_d:.2f}'
                     f'Generator »»» Loss: {_loss_g:.2f}')
    sys.stdout.flush()