# WGAN-GP implementation in Keras
* paper:[Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028)
* environment:keras==2.0.8 tensorflow==1.2
* dataset:celeba

## Common functions

### Loss functions

In [None]:
import tensorflow as tf
def mae(pred, target, name='mae'):
    """Return the mean absolute difference between ``pred`` and ``target``.

    The mean is taken over every element; the resulting op is created
    under ``name``.
    """
    abs_diff = tf.abs(pred - target)
    return tf.reduce_mean(abs_diff, name=name)

def mse(pred, target, name='mse'):
    """Return the mean squared difference between ``pred`` and ``target``.

    The mean is taken over every element; the resulting op is created
    under ``name``.
    """
    squared_diff = tf.square(pred - target)
    return tf.reduce_mean(squared_diff, name=name)

def pixel_rmse(pred, target, name='rmse'):
    """Return the square root of the mean squared difference.

    Note that ``name`` is attached to the inner mean op, not to the
    square-root op that is actually returned.
    """
    mean_squared = tf.reduce_mean(tf.square(pred - target), name=name)
    return tf.sqrt(mean_squared)

def binary_cross_entropy_with_logits(pred, target):
    """Return the mean sigmoid cross-entropy, with ``pred`` as logits and ``target`` as labels."""
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(logits=pred, labels=target)
    return tf.reduce_mean(per_element)

def wasserstein(pred, target):
    """Return the mean of the element-wise product of ``pred`` and ``target``."""
    weighted_scores = pred * target
    return tf.reduce_mean(weighted_scores)

### logging utils

In [None]:
from colorama import init
from colorama import Fore, Back, Style
import tensorflow as tf
from terminaltables import SingleTable

def print_table(TABLE_DATA):
    """Print ``TABLE_DATA`` as an untitled ``SingleTable`` with column index 2 right-justified."""
    table = SingleTable(TABLE_DATA, "")
    table.justify_columns[2] = 'right'
    rendered = table.table
    print(rendered)

def print_bright(s):
    """Print ``s`` wrapped in colorama's bright style, resetting the style afterwards."""
    # colorama is (re-)initialised on every call, exactly as before.
    init()
    bright_text = Style.BRIGHT + s + Style.RESET_ALL
    print(bright_text)

def print_green(info, value=""):
    """Print ``[info] `` in green followed by ``str(value)`` in the default style."""
    tag = "[%s] " % info
    print(Fore.GREEN + tag + Style.RESET_ALL + str(value))

def print_red(info, value=""):
    """Print ``[info] `` in red followed by ``str(value)`` in the default style."""
    tag = "[%s] " % info
    print(Fore.RED + tag + Style.RESET_ALL + str(value))

def print_session():
    """Report the flag values, then wipe and recreate the output directories.

    Every flag whose name does not contain "dir" is printed. The log, model
    and figure directories are then deleted if they exist, and created again.
    """
    flags = tf.app.flags.FLAGS

    print_bright("\nSetting up TF session:")
    # Flag values are read from the private ``__flags`` dict of the FLAGS object.
    flag_values = flags.__dict__["__flags"]
    for flag_name in flag_values.keys():
        if "dir" in flag_name:
            continue
        print_green(flag_name, flag_values[flag_name])

    print_bright("\nConfiguring directories:")
    output_dirs = [flags.log_dir, flags.model_dir, flags.fig_dir]

    # First pass: clear whatever already exists.
    for path in output_dirs:
        if tf.gfile.Exists(path):
            print_red("Deleting", path)
            tf.gfile.DeleteRecursively(path)

    # Second pass: recreate every directory.
    for path in output_dirs:
        print_green("Creating", path)
        tf.gfile.MakeDirs(path)

def print_initialize():
    """Print the fixed status lines for the initialization step."""
    print_bright("\nInitialization:")
    for message in ("Created session saver", "Ran init ops"):
        print_green(message)

def print_summaries():
    """Print the name of every entry in the graph's SUMMARIES collection."""
    print_bright("\nSummaries:")
    for summary in tf.get_collection(tf.GraphKeys.SUMMARIES):
        print_green(summary.name)

def print_queues():
    """Print the fixed status lines for the queue-runner step."""
    print_bright("\nQueues:")
    for message in ("Created coordinator", "Started queue runner"):
        print_green(message)

def print_check_data(out, list_data):
    """Print a table with the name, shape, min and max of each evaluated tensor.

    Args:
        out: evaluated values, one per entry of ``list_data``; each must
            expose ``.shape``, ``.min()`` and ``.max()``.
        list_data: objects exposing ``.name``, paired positionally with ``out``.
    """
    # A bare ``print`` only emits a newline as a Python 2 statement; under
    # Python 3 it is a no-op expression. ``print("")`` emits the blank line
    # on both.
    print("")
    header = ('Variable Name', 'Shape', "Min value", "Max value")
    rows = tuple(
        (t.name, str(o.shape), "%.3g" % o.min(), "%.3g" % o.max())
        for o, t in zip(out, list_data)
    )
    print_table((header,) + rows)

### visualization_utils

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pylab as plt

# Handle to the TensorFlow flag values; save_image reads FLAGS.fig_dir from it.
FLAGS = tf.app.flags.FLAGS

def save_image(data, data_format, e, suffix=None):
    """Save a picture showing the current progress of the model.

    The first 8 generated and the first 8 real images are stacked and tiled
    into rows of 4, then written as a PNG into ``FLAGS.fig_dir``.

    Args:
        data: pair ``(X_G, X_real)`` of image batches.
        data_format: ``"NHWC"`` or ``"NCHW"``; the layout of both batches.
        e: value interpolated into the file name (the caller passes the epoch).
        suffix: optional extra tag inserted in the file name before ``e``.

    Raises:
        ValueError: if ``data_format`` is neither ``"NHWC"`` nor ``"NCHW"``.
    """
    if data_format not in ("NHWC", "NCHW"):
        # Previously an unknown format skipped both tiling branches and handed
        # a 4-D array to imshow, failing with an unrelated error.
        raise ValueError("data_format must be 'NHWC' or 'NCHW', got %r" % (data_format,))

    X_G, X_real = data
    X = np.concatenate((X_G[:8], X_real[:8]), axis=0)

    # Move channels last once so a single tiling path serves both layouts.
    if data_format == "NCHW":
        X = X.transpose(0, 2, 3, 1)

    # Rows of 4 images side by side, then the rows stacked vertically.
    list_rows = [
        np.concatenate([X[k] for k in range(4 * i, 4 * (i + 1))], axis=1)
        for i in range(X.shape[0] // 4)
    ]
    grid = np.concatenate(list_rows, axis=0)

    if suffix is None:
        filename = "current_batch_%s.png" % e
    else:
        filename = "current_batch_%s_%s.png" % (suffix, e)

    try:
        if grid.shape[-1] == 1:
            # Single-channel images are drawn as grayscale.
            plt.imshow(grid[:, :, 0], cmap="gray")
        else:
            plt.imshow(grid)
        plt.axis("off")
        plt.savefig(os.path.join(FLAGS.fig_dir, filename))
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.clf()
        plt.close()


def get_stacked_tensor(X1, X2):
    """Tile the first 16 entries of ``X1`` and ``X2`` into one image tensor.

    The 32 entries are laid out as 8 rows of 4. Each entry is joined along
    its axis 2, rows are joined along axis 1, and the result is transposed
    with ``(1, 2, 0)`` and given a leading axis of size 1.
    NOTE(review): the axes used suggest channel-first (C, H, W) entries;
    confirm against callers.
    """
    stacked = tf.concat((X1[:16], X2[:16]), axis=0)

    rows = [
        tf.concat([stacked[k] for k in range(4 * r, 4 * (r + 1))], axis=2)
        for r in range(8)
    ]

    grid = tf.concat(rows, axis=1)
    grid = tf.transpose(grid, (1, 2, 0))
    return tf.expand_dims(grid, 0)

### train utils

In [None]:
import tensorflow as tf
import random
import numpy as np
import sys
# Extend the module search path with "../utils" (relative to the current
# working directory) before the logging_utils import below.
sys.path.append("../utils")
import logging_utils as lu


def setup_session():
    """Create the TensorFlow session, reset the output directories and seed the RNGs.

    Returns:
        The newly created ``tf.Session``.
    """
    lu.print_session()

    flags = tf.app.flags.FLAGS

    # Session configuration, with XLA JIT switched on only on request.
    session_config = tf.ConfigProto()
    if flags.use_XLA:
        session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    session = tf.Session(config=session_config)

    # Output directories start empty on every run.
    for out_dir in (flags.log_dir, flags.model_dir, flags.fig_dir):
        if tf.gfile.Exists(out_dir):
            tf.gfile.DeleteRecursively(out_dir)
        tf.gfile.MakeDirs(out_dir)

    # One seed drives the TensorFlow, ``random`` and NumPy generators.
    seed = flags.random_seed
    with session.graph.as_default():
        tf.set_random_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    return session


def initialize_session(sess):
    """Create a saver, run the global and local variable initializers, and return the saver."""
    checkpoint_saver = tf.train.Saver()

    global_init = tf.global_variables_initializer()
    local_init = tf.local_variables_initializer()
    sess.run(tf.group(global_init, local_init))

    lu.print_initialize()
    return checkpoint_saver


def add_gradient_summary(list_gradvar):
    """Add a histogram summary for every non-``None`` gradient.

    Args:
        list_gradvar: iterable of ``(gradient, variable)`` pairs; each summary
            is tagged ``<variable.name>/gradient``.
    """
    for grad, var in list_gradvar:
        if grad is None:
            # Pairs without a gradient get no summary.
            continue
        tf.summary.histogram(var.name + "/gradient", grad)


def manage_queues(sess):
    """Start the queue runners for ``sess`` under a new coordinator and return the coordinator."""
    coordinator = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coordinator)
    lu.print_queues()
    return coordinator


def manage_summaries(sess):
    """Open a summary ``FileWriter`` on ``FLAGS.log_dir`` for the session graph and return it."""
    log_dir = tf.app.flags.FLAGS.log_dir
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
    lu.print_summaries()
    return summary_writer


def check_data(out, list_data):
    """Forward ``out`` and ``list_data`` unchanged to ``lu.print_check_data``."""

    lu.print_check_data(out, list_data)


### data utils

In [None]:
import os
import glob
import numpy as np
import tensorflow as tf


def normalize_image(image):
    """Cast ``image`` to float32 and map it via ``(x / 255 - 0.5) / 0.5``.

    An input in [0, 255] therefore ends up in [-1, 1].
    """
    as_float = tf.cast(image, tf.float32)
    unit_range = as_float / 255.
    return (unit_range - 0.5) / 0.5


def unnormalize_image(image, name=None):
    """Map ``image`` via ``(x * 0.5 + 0.5) * 255`` and cast it to uint8.

    The cast op is created under ``name``.
    """
    rescaled = (image * 0.5 + 0.5) * 255.
    return tf.cast(rescaled, tf.uint8, name=name)


def input_data(sess):
    """Build the queue-based input pipeline that yields batches of images.

    Every ``*.jpg`` under ``FLAGS.celebA_path`` is decoded, randomly flipped
    left/right, centre-cropped, resized to ``FLAGS.img_size`` squared,
    normalized, and batched. All ops are pinned to the CPU.

    Args:
        sess: not used inside this function.

    Returns:
        The batched image tensor, named ``X_real_input``.
    """
    flags = tf.app.flags.FLAGS

    jpeg_paths = glob.glob(os.path.join(flags.celebA_path, "*.jpg"))

    with tf.device('/cpu:0'):

        # Read whole files from a queue of file names.
        file_reader = tf.WholeFileReader()
        path_queue = tf.train.string_input_producer(jpeg_paths)
        _, jpeg_bytes = file_reader.read(path_queue)

        n_channels = flags.channels
        img = tf.image.decode_jpeg(jpeg_bytes, channels=n_channels, name="dataset_image")
        img.set_shape([None, None, n_channels])

        # Random augmentation: only the horizontal flip is enabled.
        img = tf.image.random_flip_left_right(img)
        # img = tf.image.random_saturation(img, .95, 1.05)
        # img = tf.image.random_brightness(img, .05)
        # img = tf.image.random_contrast(img, .95, 1.05)

        # Keep the central part of the image, then resize to a square.
        img = tf.image.central_crop(img, flags.central_fraction)
        side = flags.img_size
        img = tf.image.resize_images(img, (side, side), method=tf.image.ResizeMethod.AREA)

        img = normalize_image(img)

        # Decoded images are channels-last; move channels first for NCHW.
        if flags.data_format == "NCHW":
            img = tf.transpose(img, (2, 0, 1))

        # Batch through an asynchronous queue.
        queue_capacity = 2 * flags.num_threads * flags.batch_size
        return tf.train.batch([img],
                              batch_size=flags.batch_size,
                              num_threads=flags.num_threads,
                              capacity=queue_capacity,
                              name='X_real_input')


def sample_batch(X, batch_size):
    """Return ``batch_size`` entries of ``X`` drawn along axis 0 without replacement."""
    n_samples = X.shape[0]
    picked = np.random.choice(n_samples, batch_size, replace=False)
    return X[picked]


## Dataset

## Create models

In [None]:
import sys
import tensorflow as tf
import collections
import layers

class Model(object):
    """Base class that ties a model to a name used for variable lookup."""

    def __init__(self, name):
        self.name = name

    def get_trainable_variables(self):
        """Return ``{variable name: variable}`` for trainable variables whose name contains ``self.name``."""
        owned = {}
        for var in tf.trainable_variables():
            # Substring match on the variable name, not a prefix match.
            if self.name in var.name:
                owned[var.name] = var
        return owned


class Generator(Model):
    """Generator: dense projection, reshape, batch norm, then a stack of upsampling blocks.

    Args:
        list_filters, list_kernel_size, list_strides, list_padding: per-block
            settings for the upsampling stack; zipped together in ``__call__``.
        output_shape: shape of one output image without the batch axis;
            indexed as (C, H, W) for "NCHW" and (H, W, C) otherwise.
        name: variable scope name.
        batch_size: batch size baked into the reshape after the dense layer.
        filters: number of feature maps right after the dense layer.
        dset: "mnist" starts at 1/4 of the output height; anything else at 1/16.
        data_format: "NCHW" or any other value (treated as channels-last).
    """

    def __init__(self, list_filters, list_kernel_size, list_strides, list_padding, output_shape,
                 name="generator", batch_size=32, filters=512, dset="celebA", data_format="NCHW"):

        super(Generator, self).__init__(name)

        self.data_format = data_format

        if self.data_format == "NCHW":
            self.output_h = output_shape[1]
            self.output_w = output_shape[2]
        else:
            self.output_h = output_shape[0]
            self.output_w = output_shape[1]

        if dset == "mnist":
            self.start_dim = int(self.output_h / 4)
            self.nb_upconv = 2
        else:
            self.start_dim = int(self.output_h / 16)
            self.nb_upconv = 4

        self.output_shape = output_shape
        self.dset = dset
        self.name = name
        self.batch_size = batch_size
        self.filters = filters
        self.list_filters = list_filters
        self.list_kernel_size = list_kernel_size
        self.list_padding = list_padding
        self.list_strides = list_strides

    def __call__(self, x, reuse=False):
        """Build the generator graph on ``x`` and return the tanh output tensor named ``X_G``.

        Args:
            x: input tensor fed to the initial dense layer.
            reuse: when true, reuse the variables already created in this scope.
        """
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            # Initial dense multiplication
            x = layers.linear(x, self.filters * self.start_dim * self.start_dim)

            # Reshape to image format
            if self.data_format == "NCHW":
                target_shape = (self.batch_size, self.filters, self.start_dim, self.start_dim)
                bn_data_format = "NCHW"
            else:
                target_shape = (self.batch_size, self.start_dim, self.start_dim, self.filters)
                bn_data_format = "NHWC"

            x = layers.reshape(x, target_shape)
            # batch_norm assumes channels-last unless told otherwise; without
            # data_format an NCHW tensor would be normalized over its last
            # (width) axis instead of the channel axis.
            x = tf.contrib.layers.batch_norm(x, fused=True, data_format=bn_data_format)
            x = tf.nn.relu(x)

            # Upsampling2D + conv blocks. The last block has neither batch norm
            # nor an activation, since tanh is applied right after the loop.
            # NOTE: earlier revisions kept commented-out alternatives here
            # (layers.conv2d_block + layers.phase_shift, and
            # layers.deconv2d_block transposed convolutions).
            last_idx = len(self.list_filters) - 1
            block_specs = zip(self.list_filters, self.list_kernel_size, self.list_strides, self.list_padding)
            for idx, (f, k, s, p) in enumerate(block_specs):
                name = "upsample2D_%s" % idx
                if idx == last_idx:
                    bn = False
                    activation_fn = None
                else:
                    bn = True
                    activation_fn = tf.nn.relu
                x = layers.upsample2d_block(name, x, f, k, s, p, data_format=self.data_format, bn=bn, activation_fn=activation_fn)

            x = tf.nn.tanh(x, name="X_G")

            return x


class Discriminator(Model):
    """Critic: a stack of conv blocks followed by a flatten and a single linear output.

    The per-block settings (filters, kernel size, stride, padding) are given
    as parallel lists. ``batch_size`` fixes the leading dimension of the
    flatten.
    """

    def __init__(self, list_filters, list_kernel_size, list_strides, list_padding, batch_size,
                 name="discriminator", data_format="NCHW"):
        super(Discriminator, self).__init__(name)

        self.name = name
        self.batch_size = batch_size
        self.data_format = data_format
        self.list_filters = list_filters
        self.list_kernel_size = list_kernel_size
        self.list_strides = list_strides
        self.list_padding = list_padding

    def __call__(self, x, reuse=False):
        """Build the critic graph on ``x`` and return the linear output with one unit per sample."""
        with tf.variable_scope(self.name) as scope:

            if reuse:
                scope.reuse_variables()

            conv_specs = zip(self.list_filters, self.list_kernel_size, self.list_strides, self.list_padding)
            for idx, (f, k, s, p) in enumerate(conv_specs):
                # Every block except the first one uses batch norm.
                x = layers.conv2d_block("conv2D_%s" % idx, x, f, k, s, p=p, stddev=0.02,
                                        data_format=self.data_format, bias=True, bn=(idx != 0),
                                        activation_fn=layers.lrelu)

            flattened = layers.reshape(x, (self.batch_size, -1))

            # A mini-batch discrimination branch (layers.mini_batch_disc
            # concatenated onto the flattened features) is left disabled here.

            return layers.linear(flattened, 1, bias=False)


## train and test the models

In [None]:
import os
import sys
import models
from tqdm import tqdm
import tensorflow as tf

# Handle to the TensorFlow flag values read throughout train_model.
FLAGS = tf.app.flags.FLAGS

def train_model():
    """Build the WGAN-GP graph and run the training loop.

    Creates the session and input queue, builds the generator and critic,
    sets up the Wasserstein losses with a gradient penalty, then runs
    ``FLAGS.nb_epoch`` epochs of ``FLAGS.nb_batch_per_epoch`` steps. Each step
    does 5 critic updates followed by 1 generator update. After every epoch a
    sample image and a checkpoint are written.
    """
    # Number of critic updates per generator update, and the weight of the
    # gradient penalty in the critic loss.
    n_critic = 5
    gp_weight = 10

    # Setup session
    sess = setup_session()
    # Setup async input queue of real images
    X_real = input_data(sess)

    #######################
    # Instantiate generator
    #######################
    list_filters = [256, 128, 64, 3]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    output_shape = X_real.get_shape().as_list()[1:]
    G = models.Generator(list_filters, list_kernel_size, list_strides, list_padding, output_shape,
                         batch_size=FLAGS.batch_size, data_format=FLAGS.data_format)

    ###########################
    # Instantiate discriminator
    ###########################
    list_filters = [32, 64, 128, 256]
    list_strides = [2] * len(list_filters)
    list_kernel_size = [3] * len(list_filters)
    list_padding = ["SAME"] * len(list_filters)
    D = models.Discriminator(list_filters, list_kernel_size, list_strides, list_padding,
                             FLAGS.batch_size, data_format=FLAGS.data_format)

    ###########################
    # Instantiate optimizers
    ###########################
    # NOTE(review): FLAGS.learning_rate is defined but not used here; both
    # optimizers run with a hard-coded 1E-4.
    G_opt = tf.train.AdamOptimizer(learning_rate=1E-4, name='G_opt', beta1=0.5, beta2=0.9)
    D_opt = tf.train.AdamOptimizer(learning_rate=1E-4, name='D_opt', beta1=0.5, beta2=0.9)

    ###########################
    # Instantiate model outputs
    ###########################
    # noise_input = tf.random_normal((FLAGS.batch_size, FLAGS.noise_dim,), stddev=0.1)
    noise_input = tf.random_uniform((FLAGS.batch_size, FLAGS.noise_dim,), minval=-1, maxval=1)
    X_fake = G(noise_input)

    # Output images. The helpers are called unqualified, like the other
    # utilities in this function: no du/tu/vu module alias is imported.
    X_G_output = unnormalize_image(X_fake)
    X_real_output = unnormalize_image(X_real)

    D_real = D(X_real)
    D_fake = D(X_fake, reuse=True)

    ###########################
    # Instantiate losses
    ###########################
    G_loss = -tf.reduce_mean(D_fake)
    D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real)

    # Random interpolates between real and fake samples, one epsilon per sample.
    epsilon = tf.random_uniform(
        shape=[FLAGS.batch_size, 1, 1, 1],
        minval=0.,
        maxval=1.
    )
    X_hat = X_real + epsilon * (X_fake - X_real)
    D_X_hat = D(X_hat, reuse=True)
    grad_D_X_hat = tf.gradients(D_X_hat, [X_hat])[0]

    # The penalty constrains the L2 norm of the whole per-sample gradient, so
    # the squared gradient is summed over every non-batch axis. Summing over
    # the channel axis alone would penalise per-pixel norms instead. Because
    # all three axes are reduced, the result does not depend on data_format.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_X_hat), axis=[1, 2, 3]))
    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
    D_loss += gp_weight * gradient_penalty

    ###########################
    # Compute gradient updates
    ###########################
    G_vars = list(G.get_trainable_variables().values())
    D_vars = list(D.get_trainable_variables().values())

    G_gradvar = G_opt.compute_gradients(G_loss, var_list=G_vars, colocate_gradients_with_ops=True)
    G_update = G_opt.apply_gradients(G_gradvar, name='G_loss_minimize')

    D_gradvar = D_opt.compute_gradients(D_loss, var_list=D_vars, colocate_gradients_with_ops=True)
    D_update = D_opt.apply_gradients(D_gradvar, name='D_loss_minimize')

    ##########################
    # Group training ops
    ##########################
    loss_ops = [G_loss, D_loss]

    ##########################
    # Summary ops
    ##########################
    # Add summary for gradients
    add_gradient_summary(G_gradvar)
    add_gradient_summary(D_gradvar)

    # Add scalar summaries
    tf.summary.scalar("G loss", G_loss)
    tf.summary.scalar("D loss", D_loss)
    tf.summary.scalar("gradient_penalty", gradient_penalty)

    summary_op = tf.summary.merge_all()

    ############################
    # Start training
    ############################

    # Initialize session
    saver = initialize_session(sess)
    # Start queues
    manage_queues(sess)

    # Summaries
    writer = manage_summaries(sess)

    # Write summaries about 20 times per epoch. The max() keeps the modulus
    # non-zero when an epoch has fewer than 20 batches.
    summary_every = max(1, FLAGS.nb_batch_per_epoch // 20)

    for e in tqdm(range(FLAGS.nb_epoch), desc="Training progress"):

        t = tqdm(range(FLAGS.nb_batch_per_epoch), desc="Epoch %i" % e, mininterval=0.5)
        for batch_counter in t:

            for _ in range(n_critic):
                sess.run([D_update])

            output = sess.run([G_update] + loss_ops + [summary_op])

            if batch_counter % summary_every == 0:
                writer.add_summary(output[-1], e * FLAGS.nb_batch_per_epoch + batch_counter)

            t.set_description('Epoch %i' % e)

        # Plot some generated images
        output = sess.run([X_G_output, X_real_output])
        save_image(output, FLAGS.data_format, e)

        # Save session
        saver.save(sess, os.path.join(FLAGS.model_dir, "model"), global_step=e)

    print('Finished training!')


### Parameters

In [None]:

import tensorflow as tf

# Handle to the TensorFlow flag values that define_flags() registers.
FLAGS = tf.app.flags.FLAGS


def define_flags():
    """Register every command-line flag used by the training scripts.

    Flags are declared as ``(definer, name, default, help)`` rows and
    registered in order.
    """
    define = tf.app.flags

    flag_specs = [
        # Run mode
        (define.DEFINE_string, 'run', None, "Which operation to run. [train|inference]"),

        # Training parameters
        (define.DEFINE_integer, 'nb_epoch', 400, "Number of epochs"),
        (define.DEFINE_integer, 'batch_size', 256, "Number of samples per batch."),
        (define.DEFINE_integer, 'nb_batch_per_epoch', 50, "Number of batches per epoch"),
        (define.DEFINE_float, 'learning_rate', 2E-4, "Learning rate used for AdamOptimizer"),
        (define.DEFINE_integer, 'noise_dim', 100, "Noise dimension for GAN generation"),
        (define.DEFINE_integer, 'random_seed', 0, "Seed used to initialize rng."),

        # General tensorflow parameters
        (define.DEFINE_bool, 'use_XLA', False, "Whether to use XLA compiler."),
        (define.DEFINE_integer, 'num_threads', 2, "Number of threads to fetch the data"),
        (define.DEFINE_float, 'capacity_factor', 32, "Nuumber of batches to store in queue"),

        # Datasets
        (define.DEFINE_string, 'data_format', "NHWC", "Tensorflow image data format."),
        (define.DEFINE_string, 'celebA_path', "../../data/raw/img_align_celeba", "Path to celebA images"),
        (define.DEFINE_integer, 'channels', 3, "Number of channels"),
        (define.DEFINE_float, 'central_fraction', 0.8, "Central crop as a fraction of total image"),
        (define.DEFINE_integer, 'img_size', 64, "Image size"),

        # Directories
        (define.DEFINE_string, 'model_dir', '../../models', "Output folder where checkpoints are dumped."),
        (define.DEFINE_string, 'log_dir', '../../logs', "Logs for tensorboard."),
        (define.DEFINE_string, 'fig_dir', '../../figures', "Where to save figures."),
        (define.DEFINE_string, 'raw_dir', '../../data/raw', "Where raw data is saved"),
        (define.DEFINE_string, 'data_dir', '../../data/processed', "Where processed data is saved"),
    ]

    for definer, flag_name, default, help_text in flag_specs:
        definer(flag_name, default, help_text)


In [None]:
import os
# Silence TensorFlow's INFO and WARNING messages, unless the environment
# already chose a log level.
# See http://stackoverflow.com/questions/35911252
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')
import flags
import tensorflow as tf
import train_wgan_GP

# Handle to the TensorFlow flag values; main() reads FLAGS.run from it.
FLAGS = tf.app.flags.FLAGS


def launch_training():
    """Run ``train_wgan_GP.train_model()``."""

    train_wgan_GP.train_model()


def main(argv=None):
    """Dispatch on ``FLAGS.run``.

    Args:
        argv: accepted for the ``tf.app.run`` calling convention; not used.

    Raises:
        ValueError: if ``FLAGS.run`` is neither "train" nor "inference".
    """
    # An explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let an invalid mode pass silently.
    if FLAGS.run not in ("train", "inference"):
        raise ValueError("Choose [train|inference]")

    # "inference" is accepted, but only the training path is wired up here.
    if FLAGS.run == 'train':
        launch_training()


if __name__ == '__main__':
    # Register the flags first, then hand control to tf.app.run().
    flags.define_flags()
    tf.app.run()


# use the models