In [None]:
# When the kernel starts inside models/, move up to the parent directory —
# presumably so the `models.*` / `resources.*` imports in the next cell
# resolve from the repository root.
wd = %pwd
print('Current directory:', wd)
if wd.endswith('models'):
    %cd ..

In [None]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
import numpy as np
import sys, os, time, datetime

from models.base import BaseModel
from resources.flags import FLAGS, define_flags

In [None]:
# Presumably registers this notebook's flags on FLAGS (later cells read
# FLAGS.no_gpu, FLAGS.in_data_dir, FLAGS.out_data_dir and FLAGS.epochs).
# This can only be run once per kernel: restart the kernel rather than
# re-running this cell.
define_flags()

In [None]:
## Helper functions:
# TODO: Move to separate file.

def next_batch(arr, batch_size):
    """Yield consecutive slices of ``arr`` holding ``batch_size`` items each.

    The last slice holds the remainder and may be shorter. An empty slice is
    never yielded, even when ``len(arr)`` is an exact multiple of
    ``batch_size``. An empty batch would make the real-image loss a mean
    over zero elements.

    Args:
        arr: Any sliceable sequence (NumPy array, list, ...).
        batch_size: Positive number of items per batch.

    Yields:
        Slices ``arr[i:i + batch_size]`` in order.

    Raises:
        ValueError: if ``batch_size`` is not positive. Because this is a
            generator, the error is raised on the first iteration.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))
    for start in range(0, len(arr), batch_size):
        yield arr[start:start + batch_size]

def noise(size, dist='uniform'):
    """Sample latent noise vectors for the generator.

    Args:
        size: Output shape passed to NumPy (an int or a tuple). For
            ``dist='linspace'`` it must be a ``(num_samples, dim)`` tuple.
        dist: Which distribution to sample.
            - ``'uniform'`` draws from U(-1, 1).
            - ``'normal'`` draws from a standard normal.
            - ``'linspace'`` returns ``n * n`` rows, where
              ``n = int(sqrt(num_samples))``. For each of ``n`` uniform random
              start points, it emits ``n`` evenly spaced points on the straight
              line to a uniform random end point, start and end included. If
              ``num_samples`` is not a perfect square, fewer rows than
              requested are returned.

    Returns:
        np.ndarray of samples.

    Raises:
        ValueError: if ``dist`` is not one of the supported names.
    """
    if dist == 'uniform':
        return np.random.uniform(-1, 1, size=size)
    if dist == 'normal':
        return np.random.normal(size=size)
    if dist == 'linspace':
        n, dim = int(np.sqrt(size[0])), size[1]
        # Draw starts before ends so that seeded runs stay reproducible.
        starts = noise((n, dim))
        ends = noise((n, dim))
        weights = np.linspace(0, 1, n)
        # Broadcast to (start, weight, dim) and walk start -> end with
        # start + (end - start) * w: w=0 gives start, w=1 gives end, and every
        # point stays inside the [-1, 1] box.
        interpolated = (starts[:, None, :]
                        + (ends - starts)[:, None, :] * weights[None, :, None])
        # Row i * n + k is the k-th point on the i-th line.
        return interpolated.reshape(n * n, dim)
    raise ValueError("Unknown noise distribution: %r" % (dist,))
    
def shuffle(a, b):
    """Shuffle two arrays with one shared random permutation of their first axis.

    Row ``i`` of ``a`` stays paired with row ``i`` of ``b``, so images keep
    their labels. The permutation comes from NumPy's global RNG.

    Returns:
        ``(a_shuffled, b_shuffled)`` as newly indexed arrays. The inputs are
        left untouched.
    """
    order = np.random.permutation(a.shape[0])
    return tuple(arr[order] for arr in (a, b))

def tile_images(images, num_x, num_y, h, w):
    """Arrange the first ``num_x * num_y`` flattened images into one grid image.

    Image ``k`` is placed in grid row ``k // num_x`` and grid column
    ``k % num_x``, i.e. row-major in the order of the input batch.

    Args:
        images: Tensor whose leading dimension indexes images and whose
            remaining ``h * w`` elements are one image's pixels. An example is
            the ``(batch, 784)`` generator output in this notebook. It must
            hold at least ``num_x * num_y`` images when evaluated.
        num_x: Number of images per grid row.
        num_y: Number of grid rows.
        h: Height of a single image.
        w: Width of a single image.

    Returns:
        A ``(1, num_y * h, num_x * w, 1)`` tensor: a batch of one
        single-channel image, the 4-D layout passed to ``tf.summary.image``.
    """
    count = num_x * num_y
    # (count, h*w) -> (num_y, num_x, h, w); cell [r, c] holds image r*num_x + c.
    grid = tf.reshape(images[:count], (num_y, num_x, h, w))
    # Reorder to (num_y, h, num_x, w) so each pixel row of a grid row is
    # contiguous, then flatten into the single large image.
    grid = tf.transpose(grid, (0, 2, 1, 3))
    return tf.reshape(grid, (1, num_y * h, num_x * w, 1))

## Read the input data

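This example uses MNIST. The cell below loads the training images and labels and shuffles them together; the training loop later splits them into mini-batches. Loading our own dataset will look quite different.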

In [None]:
from tensorflow.examples.tutorials.mnist import input_data

# Reset the TensorFlow graph so that re-running the notebook starts clean.
# First close any `session` left over from a previous run. Otherwise it keeps
# holding resources, and TF can refuse to reset the graph while an
# InteractiveSession is still installed as the default.
# No session is created here: the training-setup cell creates the one that is
# actually used, and an extra tf.Session() would only tie up resources.
_previous_session = globals().get('session')
if _previous_session is not None:
    _previous_session.close()
tf.reset_default_graph()

def load_data(data_dir):
    """Read MNIST from ``data_dir`` and return its training and test splits.

    Args:
        data_dir: Directory passed to ``input_data.read_data_sets``.

    Returns:
        ``(train, test)``: the ``.train`` and ``.test`` split objects from
        ``read_data_sets``. These are *not* ``tf.data.Dataset`` objects. Callers
        read ``.images`` (flattened pixel rows) and ``.labels`` from them; the
        labels are one-hot encoded because ``one_hot=True`` is passed.
    """
    data = input_data.read_data_sets(data_dir, one_hot=True)
    return data.train, data.test


# Pick the device and tensor layout from the --no_gpu flag.
if FLAGS.no_gpu:
    device, data_format = '/cpu:0', 'channels_last'
else:
    device, data_format = '/gpu:0', 'channels_first'
print('Using device %s, and data format %s.' % (device, data_format))

# Load MNIST, then shuffle the training images together with their labels.
train_ds, test_ds = load_data(FLAGS.in_data_dir)
train_X, train_Y = shuffle(train_ds.images, train_ds.labels)

In [None]:
# Setup constants
IMAGE_PIXELS = 28 * 28  # one flattened 28x28 image
NOISE_SIZE = 100        # length of each latent noise vector

# Graph inputs:
# - real flattened images;
# - random latent noise for training;
# - latent noise for the interpolation summary grid.
images_input = tf.placeholder(dtype=tf.float32, shape=(None, IMAGE_PIXELS))
noise_input = tf.placeholder(dtype=tf.float32, shape=(None, NOISE_SIZE))
noise_input_interpolated = tf.placeholder(dtype=tf.float32, shape=(None, NOISE_SIZE))

In [None]:
## Discriminator
def discriminator(X, reuse=False):
    """Score flattened images, returning one unnormalized logit per row.

    Args:
        X: ``(batch, features)`` float tensor. In this notebook it is either
            the real-image placeholder or the generator output.
        reuse: ``False`` on the first call, which creates the
            ``Discriminator`` variables; ``True`` on later calls, which share
            them. Defaults to ``False``, matching ``generator``.

    Returns:
        ``(batch, 1)`` logits with no sigmoid applied. The losses apply the
        sigmoid through ``sigmoid_cross_entropy_with_logits``.
    """
    # Both layers use N(0, 0.02) weight initialization.
    init = tf.random_normal_initializer(stddev=0.02)
    with tf.variable_scope("Discriminator", reuse=reuse):
        # A single hidden layer of 1024 ReLU units.
        dx = tf.layers.dense(X, units=1024, kernel_initializer=init,
                             activation=tf.nn.relu, name='fc1')
        d_out = tf.layers.dense(dx, units=1, kernel_initializer=init,
                                name='fc_out')
        return d_out

In [None]:
## Generator
def generator(X, reuse=False):
    """Map latent noise vectors to flattened images with pixel values in (0, 1).

    Args:
        X: ``(batch, NOISE_SIZE)`` float tensor of latent noise.
        reuse: ``False`` on the first call, which creates the ``Generator``
            variables; ``True`` to share them, e.g. for a second sampling path.

    Returns:
        ``(batch, 784)`` tensor. A sigmoid is applied, so every value lies in (0, 1).
    """
    with tf.variable_scope('Generator', reuse=reuse):
        hidden = tf.layers.dense(X, units=128, activation=tf.nn.relu, name='fc1')
        return tf.layers.dense(hidden, units=784, activation=tf.nn.sigmoid,
                               name='fc_out')

In [None]:
## Losses
# Build order matters for variable creation:
# 1. The generator creates the 'Generator' variables.
# 2. The discriminator on real images creates the 'Discriminator' variables.
# 3. The discriminator on generated samples reuses those weights.
g_sample = generator(noise_input)
d_real = discriminator(images_input, reuse=False)
d_fake = discriminator(g_sample, reuse=True)


def _mean_sigmoid_xent(logits, labels):
    """Mean sigmoid cross-entropy between ``logits`` and target ``labels``."""
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))


# The discriminator is trained to label real images 1 and generated images 0.
d_loss_real = _mean_sigmoid_xent(d_real, tf.ones_like(d_real))
d_loss_fake = _mean_sigmoid_xent(d_fake, tf.zeros_like(d_fake))
d_loss = d_loss_real + d_loss_fake
# The generator is rewarded when its samples are labelled as real.
g_loss = _mean_sigmoid_xent(d_fake, tf.ones_like(d_fake))

## Summaries
# Two 6x6 grids of 28x28 samples: one from random noise, one from
# interpolated noise. The second uses a generator path that shares weights.
tiled_image_random = tile_images(g_sample, 6, 6, 28, 28)
interpolated_sample = generator(noise_input_interpolated, reuse=True)
tiled_image_interpolated = tile_images(interpolated_sample, 6, 6, 28, 28)
gen_image_summary_op = tf.summary.image('generated_images', tiled_image_random, max_outputs=1)
gen_image_summary_interpolated_op = tf.summary.image('generated_images_interpolated', tiled_image_interpolated, max_outputs=1)

# Optimizers: each network only updates its own variables.
t_vars = tf.trainable_variables()
d_vars = [v for v in t_vars if 'Discriminator' in v.name]
g_vars = [v for v in t_vars if 'Generator' in v.name]
d_opt = tf.train.AdamOptimizer(2e-4).minimize(d_loss, var_list=d_vars)
g_opt = tf.train.AdamOptimizer(2e-4).minimize(g_loss, var_list=g_vars)

In [None]:
# Fixed latent inputs, fed at every logging step so the image summaries show
# how the generator's output for the same noise evolves. 36 = 6 x 6 tiles.
num_test_samples = 36
test_noise_shape = (num_test_samples, NOISE_SIZE)
test_noise_random = noise(test_noise_shape)
test_noise_interpolated = noise(test_noise_shape, dist='linspace')

In [None]:
num_epochs = 200  # NOTE(review): unused; the training loop iterates over FLAGS.epochs.

# Close any writer or session still open from an earlier cell or an earlier
# run of this cell. Otherwise each run leaks a session, and TensorFlow warns
# about several active InteractiveSessions.
for _stale in (globals().get('writer'), globals().get('session')):
    if _stale is not None:
        _stale.close()

# Start an interactive session. It installs itself as the default session,
# so `.run()` below needs no explicit session argument.
session = tf.InteractiveSession()
# Init Variables
tf.global_variables_initializer().run()
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
# One log directory per run, so runs appear side by side in TensorBoard.
writer = tf.summary.FileWriter(
    logdir=os.path.join(FLAGS.out_data_dir, "gan_mnist_{}".format(timestamp)),
    graph=session.graph)

In [None]:
# Iterate through epochs
BATCH_SIZE = 100
LOG_EVERY = 500  # log losses and image summaries every this many batches
# Monotonic TensorBoard step. The old `epoch*1000 + n_batch` collided across
# epochs as soon as an epoch had more than 1000 batches.
summary_step = 0
for epoch in range(FLAGS.epochs):
    print("Epoch %d" % epoch)
    for n_batch, batch in enumerate(next_batch(train_X, BATCH_SIZE)):
        if len(batch) == 0:
            # Guard against an empty tail batch: the mean loss over zero real
            # images would be NaN.
            continue
        # One noise vector per real image, since the tail batch can be short.
        batch_noise = noise((len(batch), NOISE_SIZE))

        # 1. Train the discriminator on real images and generated samples.
        feed_dict = {images_input: batch, noise_input: batch_noise}
        _, d_error, d_pred_real, d_pred_fake = session.run(
            [d_opt, d_loss, d_real, d_fake], feed_dict=feed_dict)

        # 2. Train the generator on the same noise; it needs no real images.
        _, g_error = session.run(
            [g_opt, g_loss], feed_dict={noise_input: batch_noise})

        if n_batch % LOG_EVERY == 0:
            # Render the fixed test noise into image grids for TensorBoard.
            (test_image_tiled, gen_image_summary,
             test_image_tiled_interpolated, gen_image_summary_interpolated) = session.run(
                [tiled_image_random, gen_image_summary_op,
                 tiled_image_interpolated, gen_image_summary_interpolated_op],
                feed_dict={noise_input: test_noise_random,
                           noise_input_interpolated: test_noise_interpolated})

            writer.add_summary(gen_image_summary, global_step=summary_step)
            writer.add_summary(gen_image_summary_interpolated, global_step=summary_step)

            print("Epoch: {}, Batch: {}, D_Loss: {}, G_Loss: {}".format(epoch, n_batch, d_error, g_error))
        summary_step += 1

# FileWriter only flushes periodically, so force the final summaries to disk.
writer.flush()