# Deep Convolutional Generative Adversarial Network (DCGAN)

In [None]:
import os
import sys
import datetime as dt

import numpy as np
import tensorflow as tf
import tflearn
import matplotlib.pyplot as plt

from dataset import ImageDataset

%matplotlib inline

### Loading the dataset

In [None]:
# Location of the cached (pickled) dataset produced by `ImageDataset.save`.
_pickle_path = 'datasets/saved/data.pkl'

data = ImageDataset(data_dir='datasets/flowers/', size=50, flatten=False, grayscale=False)

# One-off preparation: build the dataset from the raw image folder and cache it.
# Uncomment these two lines on the first run, when the pickle does not exist yet.
# data.create()
# data.save(_pickle_path)

# `load` returns the object used for the rest of the notebook, so rebind `data` to it.
data = data.load(_pickle_path)

In [None]:
# Hold out 10% for testing; the two discarded slots (presumably labels) are not needed for a GAN.
X_train, _, X_test, _ = data.train_test_split(test_size=0.1)
# Show both array shapes on one line, separated by a space.
print(*(split.shape for split in (X_train, X_test)))

In [None]:
# Image geometry as reported by the dataset; the flattened length is what
# real images are fed as through the discriminator's input placeholder.
image_size, image_channel = data.size, data.channels
image_size_flat = image_channel * image_size ** 2
geometry_msg = 'size: {}\tchannel: {}\t flattened: {:,}'.format(image_size, image_channel, image_size_flat)
print(geometry_msg)

In [None]:
# Keep probability for the dropout layers in the discriminator's dense head.
keep_prob = 0.8

# --- Training schedule ---
learning_rate = 1e-3    # step size handed to every optimizer
batch_size = 24         # images (and noise vectors) per training step
iterations = 10 * 1000  # number of adversarial training steps
save_interval = 100     # checkpoint + TensorBoard summary every N steps
log_interval = 1000     # plot a grid of generated samples every N steps

In [None]:
def plot_images(imgs, name=None):
    """Display the first ``grid * grid`` images of *imgs* on a square grid.

    ``grid`` is ``floor(sqrt(len(imgs)))``, so any images beyond the largest
    perfect square are not shown.

    Args:
        imgs: sequence/array of images, each either flattened or already
            shaped. Every image is reshaped with the module-level
            ``image_size``/``image_channel``: to ``(image_size, image_size)``
            for single-channel data (drawn with the 'binary' colormap), or to
            ``(image_size, image_size, image_channel)`` otherwise (drawn as
            colour; matplotlib expects float RGB values in [0, 1]).
        name: optional figure title.
    """
    grid = int(np.sqrt(len(imgs)))
    if grid == 0:
        # Nothing to draw; plt.subplots(0, 0) would raise.
        return
    # squeeze=False keeps `axes` a 2-D array even when grid == 1.
    fig, axes = plt.subplots(grid, grid, squeeze=False)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    if name:
        plt.suptitle(name)
    # Grayscale images need a 2-D array + colormap; colour images need the
    # channel axis kept (a colormap is ignored for RGB data).
    if image_channel == 1:
        shape, cmap = [image_size, image_size], 'binary'
    else:
        shape, cmap = [image_size, image_size, image_channel], None
    for img, ax in zip(imgs, axes.flat):
        ax.imshow(np.asarray(img).reshape(shape), cmap=cmap, interpolation='bicubic')
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

### Discriminator `(Deep Convolutional Neural Net)`

In [None]:
def discriminator(image, reuse=False):
    """Score a batch of images as real vs. generated; returns one raw logit per image.

    Args:
        image: batch of images, either flattened (``[batch, image_size_flat]``)
            or already shaped ``[batch, image_size, image_size, image_channel]``;
            it is reshaped to the 4-D form before the convolutions.
        reuse: when True, reuse the variables created by an earlier call in the
            'discriminator' variable scope instead of creating new ones, so the
            real-image and generated-image branches share weights.

    Returns:
        Tensor of shape ``[batch, 1]`` with a linear (un-squashed) activation,
        i.e. logits suitable for a sigmoid/logistic cross-entropy loss.

    NOTE(review): the dropout and batch_normalization layers behave differently
    depending on tflearn's training-mode flag; confirm it is enabled while training.
    NOTE(review): tflearn auto-names each layer, so presumably the layer calls
    below must stay in the same order for ``reuse=True`` to find the variables.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # AlexNet-style layout (large strided first conv, five conv layers,
        # dense head) with much smaller filter counts.
        network = tf.reshape(image, shape=[-1, image_size, image_size, image_channel])
        # 1st convnet: 11x11 kernels with stride 4, then 3x3/2 max-pool
        network = tflearn.conv_2d(network, nb_filter=16, filter_size=11, strides=4, activation='relu',)
        network = tflearn.max_pool_2d(network, kernel_size=3, strides=2)
        network = tflearn.batch_normalization(network)
        # 2nd convnet
        network = tflearn.conv_2d(network, nb_filter=32, filter_size=5, activation='relu')
        network = tflearn.max_pool_2d(network, kernel_size=3, strides=2)
        network = tflearn.batch_normalization(network)
        # 3, 4, 5 convnet
        network = tflearn.conv_2d(network, nb_filter=64, filter_size=3, activation='relu')
        network = tflearn.conv_2d(network, nb_filter=128, filter_size=3, activation='relu')
        network = tflearn.conv_2d(network, nb_filter=64, filter_size=3, activation='relu')
        network = tflearn.max_pool_2d(network, kernel_size=3, strides=2)
        network = tflearn.batch_normalization(network)
        # Flatten
        network = tflearn.flatten(network)
        # 1st fully connected (dropout keeps `keep_prob` of the units)
        network = tflearn.fully_connected(network, n_units=256, activation='tanh')
        network = tflearn.dropout(network, keep_prob)
        # 2nd fully connected
        network = tflearn.fully_connected(network, n_units=128, activation='tanh')
        network = tflearn.dropout(network, keep_prob)
        # 3rd fully connected: single linear unit -> raw real/fake logit
        network = tflearn.fully_connected(network, n_units=1, activation='linear')
    return network

## Generator `(Deep Deconvolutional Neural Net)`

In [None]:
def generator(noise, reuse=False):
    """Map a batch of latent noise vectors to images.

    Args:
        noise: ``[batch, z_dim]`` latent vectors.
        reuse: when True, reuse the variables created by an earlier call in the
            'generator' variable scope instead of creating new ones.

    Returns:
        Tensor of shape ``[batch, image_size, image_size, image_channel]`` with
        values in (0, 1) (sigmoid output).
    """
    with tf.variable_scope('generator', reuse=reuse):
        # Dense projection of the noise to an 8x8x16 seed feature map.
        x = tflearn.fully_connected(noise, n_units=256, activation='tanh')
        x = tflearn.batch_normalization(x)
        x = tflearn.fully_connected(x, n_units=8 * 8 * 16, activation='tanh')
        x = tf.reshape(x, shape=[-1, 8, 8, 16])
        # 8x8 -> 16x16
        x = tflearn.upsample_2d(x, kernel_size=2)
        x = tflearn.conv_2d(x, nb_filter=64, filter_size=5, activation='relu')
        x = tflearn.conv_2d(x, nb_filter=128, filter_size=5, activation='relu')
        x = tflearn.conv_2d(x, nb_filter=64, filter_size=5, activation='relu')
        x = tflearn.batch_normalization(x)
        # 16x16 -> 32x32
        x = tflearn.upsample_2d(x, kernel_size=2)
        x = tflearn.conv_2d(x, nb_filter=32, filter_size=5, activation='relu')
        x = tflearn.batch_normalization(x)
        # 32x32 -> 64x64
        x = tflearn.upsample_2d(x, kernel_size=2)
        x = tflearn.conv_2d(x, nb_filter=16, filter_size=5, activation='relu')
        # Bring the feature map to the dataset's spatial size and project it to
        # the image channels with a convolution. A conv with image_size_flat
        # filters followed by a dense layer from 64*64*16 features to
        # image_size_flat outputs (~491M weights for 50x50x3 images) would not
        # fit in memory.
        x = tf.image.resize_images(x, [image_size, image_size])
        x = tflearn.conv_2d(x, nb_filter=image_channel, filter_size=5, activation='sigmoid')
        return x

In [None]:
# Start from an empty graph so re-running this cell does not duplicate ops.
tf.reset_default_graph()

# Dimensionality of the generator's latent noise vector.
z_dim = 200

# Real images are fed flattened; the discriminator reshapes its input itself.
X_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, image_size_flat], name='X_placeholder')
Z_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='Z_placeholder')

# One generator, plus two discriminator branches (real input, then generated
# input) that share weights through reuse=True on the second call.
Gz = generator(Z_placeholder)
Dx = discriminator(X_placeholder)
Dg = discriminator(Gz, reuse=True)

## Loss functions

In [None]:
# Discriminator's loss.
# `Dx` and `Dg` are single-unit raw logits ([batch, 1]), so the loss must be a
# sigmoid (binary) cross-entropy: a softmax over a single class is identically 1,
# which would make the loss a constant 0 with no gradient.
d_real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=Dx, labels=tf.ones_like(Dx)))   # real -> 1
d_fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg, labels=tf.zeros_like(Dg)))  # fake -> 0
# Generator's loss: the generator is rewarded when the discriminator labels its
# samples as real, so it is scored on the discriminator's logits for generated
# images (`Dg`), not on the generated pixels themselves (`Gz`).
gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg, labels=tf.ones_like(Dg)))

In [None]:
# Get variables for the discriminator & generator so each optimizer only
# updates its own network.
dis_vars = tflearn.get_layer_variables_by_scope('discriminator')
gen_vars = tflearn.get_layer_variables_by_scope('generator')

# Discriminator's optimizer: a single Adam step on the summed real + fake loss.
# Two independent minimize ops (previously Adadelta for the real term and Adam
# for the fake term) fetched in one `sess.run` would update the same variables
# in an unspecified order and with mismatched optimizers.
disc_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
    d_real_loss + d_fake_loss, var_list=dis_vars)
# Aliases kept for code that fetches both ops; fetching the same op twice in
# one `sess.run` executes it only once.
disc_optimizer_r = disc_optimizer
disc_optimizer_f = disc_optimizer

# Generator's Optimizer
gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(gen_loss, var_list=gen_vars)

## TensorFlow `Session`

In [None]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
# tflearn's dropout and batch_normalization layers branch on tflearn's global
# training-mode flag. Training here is driven by raw `sess.run` calls rather
# than tflearn's own trainer, so switch training mode on explicitly after the
# variables have been initialized.
tflearn.is_training(True, session=sess)

## TensorBoard

In [None]:
tensorboard_dir = 'tensorboard/'
logdir = os.path.join(tensorboard_dir, 'log/')
save_path = 'models/'

# Scalar summaries for each loss term and the total discriminator loss.
tf.summary.scalar('discriminator_loss_real', d_real_loss)
tf.summary.scalar('discriminator_loss_fake', d_fake_loss)
tf.summary.scalar('discriminator_add_loss', d_real_loss + d_fake_loss)
tf.summary.scalar('generator_loss', gen_loss)
# Reuse the generator output already in the graph instead of building a second
# copy of the generator just for the image summary.
tb_img = Gz  # Tensorboard image
tf.summary.image('generator_image', tb_img, max_outputs=4)
# One op that evaluates every summary registered above.
merged = tf.summary.merge_all()

saver = tf.train.Saver()
writer = tf.summary.FileWriter(logdir=logdir, graph=sess.graph)

if tf.gfile.Exists(save_path):
    # Resume from the most recent checkpoint recorded in `save_path`, if any;
    # a bare directory listing cannot tell a checkpoint from unrelated files.
    latest_ckpt = tf.train.latest_checkpoint(save_path)
    if latest_ckpt:
        saver.restore(sess=sess, save_path=latest_ckpt)
else:
    tf.gfile.MakeDirs(save_path)

## Training

In [None]:
# Warm up the discriminator on its own before adversarial training starts.
pretrain_iterations = 100
for i in range(1, pretrain_iterations + 1):
    # Latent noise for the generator (its output is the "fake" batch).
    fake_batch = np.random.normal(0, 1, [batch_size, z_dim])
    real_batch = data.next_batch(batch_size=batch_size)[0]
    real_batch = np.reshape(real_batch, [-1, image_size_flat])
    _, _, _fake_loss, _real_loss = sess.run([disc_optimizer_f, disc_optimizer_r, d_fake_loss, d_real_loss],
                                            feed_dict={X_placeholder: real_batch, Z_placeholder: fake_batch})
    # Arguments follow the label order in the message: real first, then fake.
    sys.stdout.write('\rIter: {:,}\tLoss: real = {:.4f}\t fake = {:.4f}'.format(i, _real_loss, _fake_loss))
# Finish the carriage-return progress line so later output starts cleanly.
print()

In [None]:
start_time = dt.datetime.now()
# Single op evaluating every registered summary; built once, outside the loop.
merged = tf.summary.merge_all()

# Training the discriminator & generator together
for i in range(1, iterations + 1):

    # Train discriminator on real & fake images (fake images come from noise via the generator)
    fake_batch = np.random.normal(0, 1, size=[batch_size, z_dim])
    real_batch = data.next_batch(batch_size=batch_size)[0]
    real_batch = np.reshape(real_batch, [-1, image_size_flat])
    _, _, _fake_loss, _real_loss = sess.run([disc_optimizer_f, disc_optimizer_r, d_fake_loss, d_real_loss],
                                            feed_dict={X_placeholder: real_batch, Z_placeholder: fake_batch})

    # Train generator to fool the discriminator
    gen_noise = np.random.normal(0, 1, size=[batch_size, z_dim])
    _, _gen_loss = sess.run([gen_optimizer, gen_loss], feed_dict={Z_placeholder: gen_noise})

    # Saving: checkpoint + TensorBoard summaries
    if i % save_interval == 0:
        saver.save(sess=sess, save_path=save_path)
        summary = sess.run(merged, feed_dict={X_placeholder: real_batch, Z_placeholder: fake_batch})
        writer.add_summary(summary=summary, global_step=i)

    # Logging: Displaying generated images @ intervals
    if i % log_interval == 0:
        test_noise = np.random.normal(0, 1, size=[9, z_dim])
        # Evaluate the existing generator output `Gz`. Calling generator() here
        # (reuse=False) would try to re-create the 'generator' variables and
        # would add a new copy of the network to the graph every interval.
        test_imgs = sess.run(Gz, feed_dict={Z_placeholder: test_noise})
        plot_images(test_imgs, name='Iteration: {:,}\tTest image'.format(i))

    # Log training metrics (arguments in the same order as the labels)
    sys.stdout.write('\rIter: {:,}\tGenerator: {:.4f}\tDiscriminator: real = {:.4f} fake = {:.4f}\tTime taken: {}'.format(
        i, _gen_loss, _real_loss, _fake_loss, dt.datetime.now() - start_time))

# Make sure buffered summaries reach disk once training ends.
writer.flush()