# Deep Convolutional Generative Adversarial Network (DCGAN)

In [None]:
import os
import sys
from datetime import datetime

import numpy as np
import tensorflow as tf
import tflearn
import matplotlib.pyplot as plt

from dataset import ImageDataset

%matplotlib inline

### Loading dataset

In [None]:
# Input image folder and the root folder for everything this notebook saves.
data_dir = 'datasets/flowers/'
save_dir = 'saved/'
save_data = os.path.join(save_dir, 'data.pkl')

# Configure the dataset wrapper (size=50, flattening and grayscale disabled).
data = ImageDataset(data_dir=data_dir, size=50, flatten=False, grayscale=False)

# First run: build the dataset and cache it on disk.
# Later runs: replace `data` with whatever the cached file loads to.
if not os.path.isfile(save_data):
    data.create()
    data.save(save_data)
else:
    data = data.load(save_data)

In [None]:
# Split samples into training and testing sets (test_size=0.1). Only the first
# and third return values are kept; the other two are discarded.
# NOTE(review): presumably the discarded values are labels — confirm the
# return order of ImageDataset.train_test_split.
X_train, _, X_test, _ = data.train_test_split(test_size=0.1)
# Print both shapes as a quick sanity check of the split.
print(X_train.shape, X_test.shape)

In [None]:
# Image geometry as reported by the dataset object.
image_size, image_channel = data.size, data.channels
# Number of values in one flattened image: size * size * channels.
image_size_flat = (image_size * image_size) * image_channel
print(f'size: {image_size}\tchannel: {image_channel}\t flattened: {image_size_flat:,}')

In [None]:
# Value passed to the discriminator's dropout layers.
keep_prob = 0.8

# Optimisation settings.
learning_rate = 0.001
batch_size = 24
iterations = 10_000   # number of joint discriminator/generator steps
save_interval = 100   # checkpoint + TensorBoard summary every N steps
log_interval = 1_000  # plot generated samples every N steps

In [None]:
def plot_images(imgs, name=None):
    """Show images from ``imgs`` in a square grid.

    The grid side is ``int(sqrt(len(imgs)))``; images beyond ``grid * grid``
    are not drawn. Nothing is drawn when ``imgs`` is empty.

    Args:
        imgs: Sequence of images. Each image must hold
            ``image_size * image_size * image_channel`` values (any shape).
        name: Optional title placed above the grid.
    """
    grid = int(np.sqrt(len(imgs)))
    if grid == 0:
        # plt.subplots(0, 0) raises; there is nothing to show anyway.
        return

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat`
    # is always available.
    fig, axes = plt.subplots(grid, grid, squeeze=False)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    if name:
        plt.suptitle(name)

    for ax, img in zip(axes.flat, imgs):
        if image_channel == 1:
            # Single-channel images are drawn as 2-D arrays with a colormap.
            ax.imshow(np.reshape(img, [image_size, image_size]),
                      cmap='binary', interpolation='bicubic')
        else:
            # Multi-channel images keep their channel axis; reshaping them to
            # [image_size, image_size] would fail on the element count.
            ax.imshow(np.reshape(img, [image_size, image_size, image_channel]),
                      interpolation='bicubic')
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

### Discriminator `(Deep Convolutional Neural Net)`

In [None]:
def discriminator(image, reuse=False):
    """Build the discriminator network and return its one-unit linear output.

    Args:
        image: Image tensor. It is reshaped to
            ``[-1, image_size, image_size, image_channel]`` first, so both
            flattened and 4-D inputs with that many values per image work.
        reuse: Forwarded to ``tf.variable_scope('discriminator', ...)``; pass
            ``True`` to share variables created by an earlier call.

    Returns:
        Output tensor of the final ``n_units=1`` linear layer (one logit per
        image).

    NOTE(review): tflearn's dropout and batch_normalization behave differently
    in training and inference mode; no visible code switches tflearn into
    training mode — confirm whether ``tflearn.is_training(True, session=sess)``
    is needed before training.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # AlexNet-inspired stack operating on 4-D image batches.
        net = tf.reshape(image, shape=[-1, image_size, image_size, image_channel])

        # Stage 1: large strided convolution, pool, normalise.
        net = tflearn.conv_2d(net, 16, 11, strides=4, activation='relu')
        net = tflearn.max_pool_2d(net, 3, strides=2)
        net = tflearn.batch_normalization(net)

        # Stage 2: 5x5 convolution, pool, normalise.
        net = tflearn.conv_2d(net, 32, 5, activation='relu')
        net = tflearn.max_pool_2d(net, 3, strides=2)
        net = tflearn.batch_normalization(net)

        # Stage 3: three stacked 3x3 convolutions, pool, normalise.
        for n_filters in (64, 128, 64):
            net = tflearn.conv_2d(net, n_filters, 3, activation='relu')
        net = tflearn.max_pool_2d(net, 3, strides=2)
        net = tflearn.batch_normalization(net)

        # Classifier head: two tanh layers with dropout, then a single logit.
        net = tflearn.flatten(net)
        for n_hidden in (256, 128):
            net = tflearn.fully_connected(net, n_hidden, activation='tanh')
            net = tflearn.dropout(net, keep_prob)
        logit = tflearn.fully_connected(net, 1, activation='linear')
    return logit

## Generator `(Deep Deconvolutional Neural Net)`

In [None]:
def generator(noise, reuse=False):
    """Build the generator network mapping noise vectors to images.

    Args:
        noise: Tensor of noise vectors fed to the first fully connected layer.
        reuse: Forwarded to ``tf.variable_scope('generator', ...)``; pass
            ``True`` to share variables created by an earlier call.

    Returns:
        Tensor reshaped to ``[-1, image_size, image_size, image_channel]``;
        values come from a sigmoid layer, so they lie in (0, 1).
    """
    with tf.variable_scope('generator', reuse=reuse):
        # Project the noise and fold it into an 8x8 map with 16 channels.
        net = tflearn.fully_connected(noise, 256, activation='tanh')
        net = tflearn.batch_normalization(net)
        net = tflearn.fully_connected(net, 8 * 8 * 16, activation='tanh')
        net = tf.reshape(net, shape=[-1, 8, 8, 16])

        # First upsampling stage followed by three 5x5 convolutions.
        net = tflearn.upsample_2d(net, 2)
        for n_filters in (64, 128, 64):
            net = tflearn.conv_2d(net, n_filters, 5, activation='relu')
        net = tflearn.batch_normalization(net)

        # Second upsampling stage.
        net = tflearn.upsample_2d(net, 2)
        net = tflearn.conv_2d(net, 32, 5, activation='relu')
        net = tflearn.batch_normalization(net)
        # NOTE(review): this convolution has image_size_flat output filters,
        # which is far wider than every other layer — confirm it is intended.
        net = tflearn.conv_2d(net, image_size_flat, 5, activation='relu')

        # Third upsampling stage.
        net = tflearn.upsample_2d(net, 2)
        net = tflearn.conv_2d(net, 16, 5, activation='relu')

        # Dense sigmoid layer sized to one flattened image, then restore the
        # image layout.
        net = tflearn.layers.flatten(net)
        net = tflearn.fully_connected(net, image_size_flat, activation='sigmoid')
        images = tf.reshape(net, shape=[-1, image_size, image_size, image_channel])
    return images

In [None]:
# Length of each noise vector fed to the generator.
z_dim = 200

# Clear the default graph so re-running this cell starts from a clean graph.
tf.reset_default_graph()

# Unused tflearn input layers, superseded by the plain placeholders below.
# gen_input = tflearn.input_data(shape=[None, z_dim], name='input_noise')
# disc_input = tflearn.input_data(shape=[None, image_size_flat], name='disc_input')

# Real images are fed flattened (image_size_flat values per row); the
# discriminator reshapes them back to image layout itself.
X_placeholder = tf.placeholder(tf.float32, shape=[None, image_size_flat], name='X_placeholder')
# Noise vectors for the generator.
Z_placeholder = tf.placeholder(tf.float32, shape=[None, z_dim], name='Z_placeholder')

# Gz: generated images. Dx / Dg: discriminator outputs for real / generated
# images. The second discriminator call passes reuse=True so it shares the
# variables created by the first call; keep this order.
Gz = generator(Z_placeholder)
Dx = discriminator(X_placeholder)
Dg = discriminator(Gz, reuse=True)

## Loss functions

In [None]:
# Discriminator's loss.
# The discriminator emits ONE logit per image, so the loss must be sigmoid
# cross-entropy. Softmax over a single logit is always 1, which makes softmax
# cross-entropy identically 0 and leaves the networks without any gradient.
with tf.name_scope('disc_loss'):
    # Real images should be scored as 1.
    d_real_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=Dx, labels=tf.ones_like(Dx))
    d_real_loss = tf.reduce_mean(d_real_loss, name='real_loss')
    # Generated images should be scored as 0.
    d_fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg, labels=tf.zeros_like(Dg))
    d_fake_loss = tf.reduce_mean(d_fake_loss, name='fake_loss')

# Generator's loss.
# The generator is rewarded when the discriminator scores its images as real,
# so the loss is computed on the discriminator's logits for generated images
# (Dg), not on the generated pixels (Gz).
with tf.name_scope('gen_loss'):
    gen_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg, labels=tf.ones_like(Dg))
    gen_loss = tf.reduce_mean(gen_loss, name='gen_loss')

In [None]:
# Collect each network's variables by scope so that every train op below only
# updates the network it belongs to.
dis_vars = tflearn.get_layer_variables_by_scope('discriminator')
gen_vars = tflearn.get_layer_variables_by_scope('generator')


# Discriminator: one train op per loss term, each with its own step counter.
# NOTE(review): the real-image op uses Adadelta while the other two use Adam —
# confirm the mismatch is intentional.
with tf.name_scope('disc_optimizer'):
    # Train op for the loss on real images.
    global_step_real = tf.Variable(0, trainable=False, name='global_step_real')
    disc_opt_real = tf.train.AdadeltaOptimizer(learning_rate=learning_rate).minimize(
        d_real_loss, var_list=dis_vars, global_step=global_step_real)

    # Train op for the loss on generated images.
    global_step_fake = tf.Variable(0, trainable=False, name='global_step_fake')
    disc_opt_fake = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        d_fake_loss, var_list=dis_vars, global_step=global_step_fake)

# Generator: a single train op; its step counter is the one used for
# checkpoints and summaries.
with tf.name_scope('gen_optimizer'):
    gen_global_step = tf.Variable(0, trainable=False, name='global_step')
    gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        gen_loss, var_list=gen_vars, global_step=gen_global_step)

## Tensorflow `Session`

In [None]:
# Create the initializer op for all variables defined so far, then open a
# session on the default graph and run it.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

## Tensorboard

In [None]:
# TensorBoard event files live under <save_dir>/tensorboard/log/.
tensorboard_dir = os.path.join(save_dir, 'tensorboard/')
logdir = os.path.join(tensorboard_dir, 'log/')

# From here on `save_dir` points at the checkpoint directory.
# NOTE: re-running this cell on its own nests another 'models/' level.
save_dir = os.path.join(save_dir, 'models/')
# Checkpoint file prefix inside the checkpoint directory (previously joined
# onto the not-yet-defined `save_path`, which raised NameError).
save_path = os.path.join(save_dir, 'model.ckpt')


# Tensorboard summary.
with tf.name_scope('summary'):
    # Scalar summaries for every loss term.
    tf.summary.scalar('discriminator_loss_real', d_real_loss)
    tf.summary.scalar('discriminator_loss_fake', d_fake_loss)
    tf.summary.scalar('discriminator_add_loss', d_real_loss + d_fake_loss)
    tf.summary.scalar('generator_loss', gen_loss)

    # Image summary: generated images from the shared generator variables.
    tb_img = generator(Z_placeholder, reuse=True)  # Tensorboard image
    tf.summary.image('generator_image', tb_img, max_outputs=4)

# Single op evaluating every summary registered above; the training loop
# runs it under the name `merged`.
merged = tf.summary.merge_all()


# Saver & Writer object.
saver = tf.train.Saver()
writer = tf.summary.FileWriter(logdir=logdir, graph=sess.graph)

### Restoring last checkpoint

In [None]:
# Restore the most recent checkpoint if the checkpoint directory exists;
# otherwise create the directory so later saves succeed.
if tf.gfile.Exists(save_dir):
    try:
        print('INFO: Attempting to restore last checkpoint!')
        last_ckpt = tf.train.latest_checkpoint(save_dir)

        if last_ckpt is None:
            # The directory exists but holds no checkpoint yet. This is a
            # normal first-run state, not a restore failure.
            print('INFO: No checkpoint found in {} - starting fresh.'.format(save_dir))
        else:
            # Restoring last checkpoint into default graph.
            saver.restore(sess=sess, save_path=last_ckpt)
            print('INFO: Checkpoint restored! @{}'.format(last_ckpt))
    except Exception as e:
        # Best effort: an unreadable or incompatible checkpoint must not stop
        # the notebook; training continues from the initialised variables.
        print('WARNING: Could not retrieve last checkpoint. {}'.format(e))
else:
    tf.gfile.MakeDirs(save_dir)
    print('INFO: Created checkpoint directory - {}'.format(save_dir))

## Training

In [None]:
# Reference time for the elapsed-time readout of the warm-up phase.
start_time = datetime.now()


# Warm-up: train only the discriminator for 100 steps. Ctrl-C ends the loop
# cleanly instead of raising out of the cell.
try:
    for i in range(1, 101):
        # Draw a noise batch and a batch of real images (flattened to match
        # X_placeholder).
        fake_batch = np.random.normal(0, 1, [batch_size, z_dim])
        real_batch = data.next_batch(batch_size=batch_size)[0]
        real_batch = np.reshape(real_batch, [-1, image_size_flat])

        # Run both discriminator train ops and read back both loss terms.
        feed = {X_placeholder: real_batch, Z_placeholder: fake_batch}
        _, _, _fake_loss, _real_loss = sess.run(
            [disc_opt_fake, disc_opt_real, d_fake_loss, d_real_loss],
            feed_dict=feed)

        # Single-line progress readout ('\r' rewrites the same line).
        print(f'\rIter: {i:,}\tLoss:-real: {_real_loss:.4f}'
              f'\tfake: {_fake_loss:.4f}\tTime taken = '
              f'{datetime.now() - start_time}', end='')
except KeyboardInterrupt:
    print('Training stopped!')

In [None]:
def _save_and_log(real_batch, fake_batch, step):
    """Checkpoint the model and write one TensorBoard summary.

    Args:
        real_batch: Flattened real images fed to ``X_placeholder``.
        fake_batch: Noise vectors fed to ``Z_placeholder``.
        step: Step value recorded alongside the summary.
    """
    saver.save(sess=sess, save_path=save_path, global_step=gen_global_step)
    summary = sess.run(merged, feed_dict={X_placeholder: real_batch,
                                          Z_placeholder: fake_batch})
    writer.add_summary(summary=summary, global_step=step)


# Start the training clock.
start_time = datetime.now()

# Reset loop state so the interrupt handler can tell whether this run has
# drawn any batch yet.
real_batch = fake_batch = None
i = 0

# Training the discriminator & generator together.
try:
    for i in range(1, iterations + 1):
        # Batches for the discriminator step: noise plus flattened real images.
        fake_batch = np.random.normal(0, 1, size=[batch_size, z_dim])
        real_batch = data.next_batch(batch_size=batch_size)[0]
        real_batch = np.reshape(real_batch, [-1, image_size_flat])

        # Train the discriminator on real & generated images.
        _, _, _fake_loss, _real_loss = sess.run(
            [disc_opt_fake, disc_opt_real, d_fake_loss, d_real_loss],
            feed_dict={X_placeholder: real_batch, Z_placeholder: fake_batch})

        # Train the generator on a fresh noise batch.
        gen_img = np.random.normal(0, 1, size=[batch_size, z_dim])
        _, _gen_loss, _global_gen = sess.run(
            [gen_optimizer, gen_loss, gen_global_step],
            feed_dict={Z_placeholder: gen_img})

        # Saving: checkpoint and TensorBoard summary.
        if i % save_interval == 0:
            _save_and_log(real_batch, fake_batch, _global_gen)

        # Logging: display generated images at intervals.
        if i % log_interval == 0:
            test_noise = np.random.normal(0, 1, size=[9, z_dim])
            # Evaluate the existing generator output. Calling generator()
            # again here without reuse=True raises "variable already exists"
            # and would add new ops to the graph on every log step.
            test_imgs = sess.run(Gz, feed_dict={Z_placeholder: test_noise})
            plot_images(test_imgs, name='Iteration: {:,}\tTest image'.format(i))

        # Log training metrics (every fragment is an f-string so the elapsed
        # time is interpolated rather than printed literally).
        print(f'\rIter: {i:,}\tGen loss: {_gen_loss:.4f}'
              f'\tDiscriminator: real = {_real_loss:.4f}'
              f'\tfake = {_fake_loss:.4f}\tTime taken: '
              f'{datetime.now() - start_time}', end='')

except KeyboardInterrupt:
    # End training interruption smoothly.
    print(f'Training stopped @ iter {i:,}. Saving model.')

    if real_batch is None or fake_batch is None:
        # Interrupted before the first batch was drawn: there is nothing to
        # summarise, so only checkpoint the model.
        saver.save(sess=sess, save_path=save_path, global_step=gen_global_step)
    else:
        # Read the step from the graph; `_global_gen` may not exist yet if the
        # interrupt arrived before the first generator step finished.
        _save_and_log(real_batch, fake_batch, sess.run(gen_global_step))

    print('Model saved! Ending training.')