## Trying BEGAN on face images

https://www.reddit.com/r/MachineLearning/comments/633jal/r170310717_began_boundary_equilibrium_generative/

https://arxiv.org/pdf/1703.10717.pdf

1. The generator wins by generating images that the discriminator can successfully autoencode with small loss.

2. The discriminator wins by autoencoding real images well and by autoencoding the generated images poorly.


Auto-Encoder

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

An awesome Python library (tqdm):
https://pypi.python.org/pypi/tqdm

https://github.com/carpedm20/BEGAN-tensorflow/blob/master/models.py

https://arxiv.org/pdf/1703.10717.pdf


### Library and Utils / Helper functions

In [1]:
#Import the libraries we will need.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.contrib.slim as slim
import os
import re
import scipy.misc
import scipy
from functools import partial
from collections import OrderedDict
import glob
import traceback

In [2]:
def mkdir(paths):
    """Create the parent directory of every given path if it is missing.

    Only the directory part of each path (everything before the last path
    separator, as returned by ``os.path.split``) is created, so pass a
    trailing separator (``'ckpt/run1/'``) to create the directory itself.

    paths: a single path string, or a list/tuple of path strings
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        path_dir, _ = os.path.split(path)
        # A bare file name has an empty directory part, i.e. the current
        # directory: nothing to create (os.makedirs('') would raise).
        if path_dir and not os.path.isdir(path_dir):
            os.makedirs(path_dir)


def session(graph=None, allow_soft_placement=True,
            log_device_placement=False, allow_growth=True):
    """
    Create a tf.Session from a small set of configuration switches.

    graph: graph the session is bound to (None is passed through to tf.Session)
    allow_soft_placement, log_device_placement: forwarded to tf.ConfigProto
    allow_growth: value stored in the config's gpu_options.allow_growth
    """
    sess_config = tf.ConfigProto(
        log_device_placement=log_device_placement,
        allow_soft_placement=allow_soft_placement,
    )
    sess_config.gpu_options.allow_growth = allow_growth
    return tf.Session(config=sess_config, graph=graph)


def tensors_filter(tensors, filters, combine_type='or'):
    """
    Keep the tensors whose ``name`` contains the given substring(s).

    tensors: list or tuple of objects exposing a ``name`` attribute
    filters: one substring, or a list/tuple of substrings
    combine_type: 'or' keeps a tensor matching at least one substring,
                  'and' keeps a tensor matching every substring
    returns: a new list, in the original order
    """
    assert isinstance(tensors, (list, tuple)), '`tensors` shoule be a list or tuple!'
    assert isinstance(filters, (str, list, tuple)), \
        '`filters` should be a string or a list(tuple) of strings!'
    assert combine_type in ('or', 'and'), "`combine_type` should be 'or' or 'and'!"

    patterns = [filters] if isinstance(filters, str) else filters

    if combine_type == 'or':
        keep = any
    elif combine_type == 'and':
        keep = all
    else:
        # only reachable when the asserts are stripped (python -O)
        return []

    return [ten for ten in tensors if keep(pat in ten.name for pat in patterns)]


def trainable_variables(filters=None, combine_type='or'):
    """
    Return the trainable variables, optionally narrowed down by name.

    filters: None for every trainable variable, otherwise the substring(s)
             handed to tensors_filter
    combine_type: 'or' / 'and', forwarded to tensors_filter
    """
    candidates = tf.trainable_variables()
    if filters is not None:
        candidates = tensors_filter(candidates, filters, combine_type)
    return candidates


def summary(tensor_collection, summary_type=('mean', 'stddev', 'max', 'min', 'sparsity', 'histogram')):
    """
    Build one merged summary op for a tensor or a collection of tensors.

    usage:
    1. summary(tensor)
    2. summary([tensor_a, tensor_b])
    3. summary({tensor_a: 'a', tensor_b: 'b'})

    tensor_collection: a single tensor, a list/tuple of tensors (summary names
        are derived from the tensor names), or a dict mapping tensor -> name
    summary_type: names of the statistics recorded for non-scalar tensors;
        a scalar tensor always gets a single scalar summary
    returns: the op produced by tf.summary.merge over all per-tensor summaries
    """

    def _summary(tensor, name, summary_type=('mean', 'stddev', 'max', 'min', 'sparsity', 'histogram')):
        """ Attach a lot of summaries to a Tensor. """

        if name is None:
            # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
            # session. This helps the clarity of presentation on tensorboard.
            name = re.sub('%s_[0-9]*/' % 'tower', '', tensor.name)
            name = re.sub(':', '-', name)

        with tf.name_scope('summary_' + name):
            summaries = []
            # Use the public shape API (as flatten_fully_connected does) instead
            # of the private `_shape` attribute; an unknown rank (ndims is None)
            # falls through to the statistics branch.
            if tensor.shape.ndims == 0:
                summaries.append(tf.summary.scalar(name, tensor))
            else:
                if 'mean' in summary_type:
                    mean = tf.reduce_mean(tensor)
                    summaries.append(tf.summary.scalar(name + '/mean', mean))
                if 'stddev' in summary_type:
                    mean = tf.reduce_mean(tensor)
                    stddev = tf.sqrt(tf.reduce_mean(tf.square(tensor - mean)))
                    summaries.append(tf.summary.scalar(name + '/stddev', stddev))
                if 'max' in summary_type:
                    summaries.append(tf.summary.scalar(name + '/max', tf.reduce_max(tensor)))
                if 'min' in summary_type:
                    summaries.append(tf.summary.scalar(name + '/min', tf.reduce_min(tensor)))
                if 'sparsity' in summary_type:
                    summaries.append(tf.summary.scalar(name + '/sparsity', tf.nn.zero_fraction(tensor)))
                if 'histogram' in summary_type:
                    summaries.append(tf.summary.histogram(name, tensor))
            return tf.summary.merge(summaries)

    if not isinstance(tensor_collection, (list, tuple, dict)):
        tensor_collection = [tensor_collection]
    with tf.name_scope('summaries'):
        summaries = []
        if isinstance(tensor_collection, (list, tuple)):
            for tensor in tensor_collection:
                summaries.append(_summary(tensor, None, summary_type))
        else:
            for tensor, name in tensor_collection.items():
                summaries.append(_summary(tensor, name, summary_type))
        return tf.summary.merge(summaries)



def load_checkpoint(checkpoint_dir, session, var_list=None):
    """
    Restore the latest checkpoint found in `checkpoint_dir` into `session`.

    checkpoint_dir: directory inspected with tf.train.get_checkpoint_state
    session: session the variables are restored into
    var_list: forwarded to tf.train.Saver (None uses the Saver default)
    returns: True when a checkpoint was restored, False otherwise
    """
    print(' [*] Loading checkpoint...')
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        # No checkpoint recorded: report it directly rather than relying on a
        # NameError for the never-assigned checkpoint path.
        print(' [*] No suitable checkpoint!')
        return False

    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    ckpt_path = os.path.join(checkpoint_dir, ckpt_name)
    try:
        restorer = tf.train.Saver(var_list)
        restorer.restore(session, ckpt_path)
    except Exception:
        # Best effort: an unreadable or incompatible checkpoint means "start
        # fresh". KeyboardInterrupt/SystemExit are no longer swallowed.
        print(' [*] No suitable checkpoint!')
        return False

    print(' [*] Loading successful! Copy variables from % s' % ckpt_path)
    return True



def disk_image_batch(image_paths, batch_size, shape, preprocess_fn=None, shuffle=True, num_threads=16,
                     min_after_dequeue=100, allow_smaller_final_batch=False, scope=None):
    """
    Build queue-based ops that read image files and group them into batches.

    image_paths: list of image path strings
        NOTE(review): len(image_paths) is taken, so a tf tensor is presumably
        not accepted here - confirm before passing one
    batch_size: number of images per batch
    shape: static shape assigned to every decoded image with set_shape
    preprocess_fn: optional function applied to each single decoded image
    shuffle: shuffles the file-name queue and selects tf.train.shuffle_batch
    num_threads, min_after_dequeue: only used by the shuffling batcher
    allow_smaller_final_batch: forwarded to the batching op
    returns: (batch op, number of image paths)
    """

    with tf.name_scope(scope, 'disk_image_batch'):
        data_num = len(image_paths)

        # file-name queue -> whole-file reader -> decoded image
        reader = tf.WholeFileReader()
        path_queue = tf.train.string_input_producer(image_paths, shuffle=shuffle, capacity=data_num)
        _, raw_bytes = reader.read(path_queue)
        image = tf.image.decode_image(raw_bytes)

        # the batching ops need a static shape
        image.set_shape(shape)
        if preprocess_fn is not None:
            image = preprocess_fn(image)

        if not shuffle:
            ordered_batch = tf.train.batch([image],
                                           batch_size=batch_size,
                                           allow_smaller_final_batch=allow_smaller_final_batch)
            return ordered_batch, data_num

        queue_capacity = min_after_dequeue + (num_threads + 1) * batch_size
        shuffled_batch = tf.train.shuffle_batch([image],
                                                batch_size=batch_size,
                                                capacity=queue_capacity,
                                                min_after_dequeue=min_after_dequeue,
                                                num_threads=num_threads,
                                                allow_smaller_final_batch=allow_smaller_final_batch)
        return shuffled_batch, data_num


class DiskImageData:
    """
    Serve image batches from disk through a private graph and session.

    The input ops are declared in their own tf.Graph, run by a dedicated
    session, and queue-runner threads are started on construction. Call
    close() (or let the object be garbage collected) to stop the threads and
    close the session.
    """

    def __init__(self, image_paths, batch_size, shape, preprocess_fn=None, shuffle=True, num_threads=16,
                 min_after_dequeue=100, allow_smaller_final_batch=False, scope=None):
        """
        All arguments are forwarded unchanged to disk_image_batch.

        image_paths: list of image path strings
        preprocess_fn: single image preprocessing function
        """

        self._closed = False
        self.graph = tf.Graph()  # declare ops in a separated graph
        with self.graph.as_default():
            # @TODO
            # There are some strange errors if the gpu device is the
            # same with the main graph, but cpu device is ok. I don't know why...
            with tf.device('/cpu:0'):
                self._batch_ops, self._data_num = disk_image_batch(image_paths, batch_size, shape, preprocess_fn, shuffle, num_threads,
                                                                   min_after_dequeue, allow_smaller_final_batch, scope)

        print(' [*] DiskImageData: create session!')
        self.sess = session(graph=self.graph)
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)

    def __len__(self):
        """Number of image paths given to the constructor."""
        return self._data_num

    def batch(self):
        """Run the batch op once and return its value."""
        return self.sess.run(self._batch_ops)

    def close(self):
        """
        Stop the queue-runner threads and close the private session.

        Safe to call more than once, and safe on an instance whose __init__
        did not finish (attributes that were never set are skipped).
        """
        if getattr(self, '_closed', False):
            return
        self._closed = True

        coord = getattr(self, 'coord', None)
        sess = getattr(self, 'sess', None)
        if coord is None and sess is None:
            return  # nothing was ever started

        print(' [*] DiskImageData: stop threads and close session!')
        if coord is not None:
            coord.request_stop()
            coord.join(getattr(self, 'threads', []))
        if sess is not None:
            sess.close()

    def __del__(self):
        # Delegating to close() keeps garbage collection from raising
        # AttributeError when construction failed part-way.
        self.close()


def to_range(images, min_value=0.0, max_value=1.0, dtype=None):
    """
    Rescale images linearly from [-1.0, 1.0] to [min_value, max_value].

    images: float32/float64 array with values in [-1.0, 1.0] (a tolerance of
            1e-5 is allowed on both ends); anything else fails the assertion
    dtype: dtype of the result; None keeps the dtype of `images`
    """
    tolerance = 1e-5
    assert (
        np.min(images) >= -1.0 - tolerance
        and np.max(images) <= 1.0 + tolerance
        and images.dtype in (np.float32, np.float64)
    ), 'The input images should be float64(32) and in the range of [-1.0, 1.0]!'

    out_dtype = images.dtype if dtype is None else dtype
    # [-1, 1] -> [0, 1] first, then stretch and shift to the target range
    unit = (images + 1.) / 2.
    return (unit * (max_value - min_value) + min_value).astype(out_dtype)


def imwrite(image, path):
    """
    Save an image whose values lie in [-1.0, 1.0] to `path`.

    A trailing single-channel axis (H x W x 1) is dropped first; the values
    are then mapped to uint8 in [0, 255] with to_range.

    NOTE(review): scipy.misc.imsave is no longer shipped by recent SciPy
    releases - confirm the installed version still provides it.
    """
    is_single_channel = image.ndim == 3 and image.shape[2] == 1
    if is_single_channel:  # for gray image
        image = image[:, :, 0]
    pixels = to_range(image, 0, 255, np.uint8)
    return scipy.misc.imsave(path, pixels)


def immerge(images, row, col):
    """
    merge images into an image with (row * h) * (col * w)
    `images` is in shape of N * H * W(* C=1 or 3)

    Images fill the grid row by row; cells without an image stay zero.

    images: array of rank 3 (N, H, W) or rank 4 (N, H, W, C)
    row, col: number of grid rows / columns
    returns: float array of shape (row * H, col * W) or (row * H, col * W, C)
    raises: ValueError for an unsupported rank or when N exceeds row * col
    """
    if images.ndim not in (3, 4):
        raise ValueError('`images` should be of shape N * H * W(* C), got %d dimension(s)!' % images.ndim)
    if len(images) > row * col:
        raise ValueError('%d images do not fit into a %d * %d grid!' % (len(images), row, col))

    h, w = images.shape[1], images.shape[2]
    # an empty tail for gray stacks, (C,) for stacks with a channel axis
    img = np.zeros((h * row, w * col) + tuple(images.shape[3:]))
    for idx, image in enumerate(images):
        j, i = divmod(idx, col)  # grid row, grid column
        img[j * h:(j + 1) * h, i * w:(i + 1) * w, ...] = image

    return img

In [3]:
def flatten_fully_connected(inputs,
                            num_outputs,
                            activation_fn=tf.nn.relu,
                            normalizer_fn=None,
                            normalizer_params=None,
                            weights_initializer=slim.xavier_initializer(),
                            weights_regularizer=None,
                            biases_initializer=tf.zeros_initializer(),
                            biases_regularizer=None,
                            reuse=None,
                            variables_collections=None,
                            outputs_collections=None,
                            trainable=True,
                            scope=None):
    """
    slim.fully_connected that first flattens inputs of rank greater than 2.

    Every argument after `inputs` is handed on positionally to
    slim.fully_connected. `scope` is used for the enclosing variable scope
    and is also passed on to the inner layer.
    """
    with tf.variable_scope(scope, 'flatten_fully_connected', [inputs]):
        needs_flatten = inputs.shape.ndims > 2
        flat_inputs = slim.flatten(inputs) if needs_flatten else inputs
        # kept positional on purpose, in slim.fully_connected's parameter order
        layer_args = (flat_inputs,
                      num_outputs,
                      activation_fn,
                      normalizer_fn,
                      normalizer_params,
                      weights_initializer,
                      weights_regularizer,
                      biases_initializer,
                      biases_regularizer,
                      reuse,
                      variables_collections,
                      outputs_collections,
                      trainable,
                      scope)
        return slim.fully_connected(*layer_args)


def leak_relu(x, leak, scope=None):
    """
    Leaky ReLU: combine `x` with `leak * x` element-wise.

    For leak < 1 the element-wise maximum is returned, otherwise the
    element-wise minimum.
    """
    with tf.name_scope(scope, 'leak_relu', [x, leak]):
        combine = tf.maximum if leak < 1 else tf.minimum
        return combine(x, leak * x)
    
def l1_loss(x, y):
    """Mean of the absolute element-wise difference between `x` and `y`."""
    abs_diff = tf.abs(x - y)
    return tf.reduce_mean(abs_diff)

### Layer wrappers

In [4]:
'''
slim layers:

    conv2d(args):
        inputs: 
            A 4-D tensor with dimensions [batch_size, height, width, channels]
        num_outputs:
            Integer, the number of output filters.
        kernel_size:
            Can be an int if both values are the same.
        stride: 
            default=1
        ...
    
    conv2d_transpose(args):
        inputs:
        num_outputs: 
            Integer, the number of output filters.
        kernel_size:
        stride: 
            default=1
        ...
        
    batch_norm(args):
        is_training: 
            Whether or not the layer is in training mode. In training mode
            it would accumulate the statistics of the moments into `moving_mean` and
            `moving_variance` using an exponential moving average with the given
            `decay`. When it is not in training mode then it would use the values of
            the `moving_mean` and the `moving_variance`.
            
'''

# Short aliases over slim / tf layers so the network definitions stay compact.
# conv: slim.conv2d without activation, truncated-normal (stddev 0.02) weight init.
conv = partial(slim.conv2d, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=0.02))
# dconv: slim.conv2d_transpose without activation, normal (stddev 0.02) weight init.
dconv = partial(slim.conv2d_transpose, activation_fn=None, weights_initializer=tf.random_normal_initializer(stddev=0.02))
# fc: flatten_fully_connected (flattens inputs of rank > 2 first) without activation.
fc = partial(flatten_fully_connected, activation_fn=None, weights_initializer=tf.random_normal_initializer(stddev=0.02))
relu = tf.nn.relu
elu = tf.nn.elu
# lrelu: leak_relu with leak fixed to 0.2, i.e. max(x, 0.2 * x).
lrelu = partial(leak_relu, leak=0.2)
# batch_norm: slim.batch_norm preset; callers still supply is_training.
# NOTE(review): updates_collections=None presumably makes slim apply the
# moving-average updates in place - confirm against the slim documentation.
batch_norm = partial(slim.batch_norm, decay=0.9, scale=True, epsilon=1e-5, updates_collections=None)
# ln: alias for slim.layer_norm.
ln = slim.layer_norm

In [5]:
def generator(z, dim=128, reuse=True, training=True):
    """
    Map latent vectors `z` to 3-channel images squashed by tanh.

    z: batch of latent vectors
    dim: base channel width; the stack uses dim * 8, 4, 2 and 1 channels
    reuse: passed to the 'generator' variable scope; the call that creates
           the variables has to pass reuse=False
    training: is_training flag of the batch-norm layers

    NOTE(review): four stride-2 transposed convolutions starting from 4 x 4
    presumably yield 64 x 64 outputs with slim's default padding - confirm.
    """
    bn = partial(batch_norm, is_training=training)
    up_block = partial(dconv, normalizer_fn=bn, activation_fn=elu, biases_initializer=None)

    with tf.variable_scope('generator', reuse=reuse):
        # project the latent vector onto a 4 x 4 feature map
        seed = fc(z, 4 * 4 * dim * 8)
        feat = tf.reshape(seed, [-1, 4, 4, dim * 8])
        # three stride-2 upsampling blocks with shrinking channel count
        for width in (dim * 4, dim * 2, dim * 1):
            feat = up_block(feat, width, 3, 2)
        # final stride-2 transposed conv to 3 output channels
        return tf.tanh(dconv(feat, 3, 3, 2))


def discriminator(img, dim=128, reuse=True, training=True):
    """
    Auto-encoder discriminator: encode `img`, then decode a reconstruction.

    img: batch of images
    dim: base channel width of the encoder / decoder stacks
    reuse: passed to the 'discriminator' variable scope; the call that
           creates the variables has to pass reuse=False
    training: is_training flag of the batch-norm layers
    returns: the tanh-squashed reconstruction with 3 channels
    """
    bn = partial(batch_norm, is_training=training)
    shared = dict(normalizer_fn=bn, activation_fn=elu, biases_initializer=None)
    down_block = partial(conv, **shared)
    up_block = partial(dconv, **shared)
    latent_dim = 64  # or 128

    with tf.variable_scope('discriminator', reuse=reuse):
        # encoder: the first conv has no batch norm, then three stride-2 blocks
        feat = elu(conv(img, dim, 3, 2))
        for width in (dim * 2, dim * 4, dim * 8):
            feat = down_block(feat, width, 3, 2)
        latent_space = fc(feat, latent_dim)

        # decoder: same layout as the generator
        feat = fc(latent_space, 4 * 4 * dim * 8)
        feat = tf.reshape(feat, [-1, 4, 4, dim * 8])
        for width in (dim * 4, dim * 2, dim * 1):
            feat = up_block(feat, width, 3, 2)
        return tf.tanh(dconv(feat, 3, 3, 2))

### Connecting them together

In [None]:
""" param """
# Hyper-parameters shared by the graph-construction and training cells.
epoch = 50000  # number of epochs; multiplied by the iterations per epoch to get the total iteration count
batch_size = 32 # the paper uses 16
lr = 0.0001  # learning rate of both Adam optimizers
z_dim = 64  # length of the latent vector fed to the generator
clip = 0.01  # NOTE(review): not referenced by the code visible in this notebook
n_critic = 5  # only enters the iterations-per-epoch computation of the training cell
gpu_id = 0  # index used to build the '/gpu:%d' device string

lamda = 0.001  # step size of the k_t update
gamma = 0.5  # balance term used in the k_t update and in M_global
beta1 = 0.9 # Adam beta1; alternative value noted by the author: 0.5

# The number of run
# (used to build the summary / checkpoint / sample directory names)
run_num = 3

''' data '''
# you should prepare your own data in ./Face_data/Celeba/dataset_64x64
# (the CelebA original size is [218, 178, 3]; the pipeline below expects 64x64x3 images)


def preprocess_fn(img):
    """
    Cast `img` to float and rescale it with x / 127.5 - 1.

    For 8-bit pixel values (0..255) the result lies between -1 and 1.
    """
    return tf.to_float(img) / 127.5 - 1

file_paths = glob.glob("./Face_data/Celeba/dataset_64x64/*")
# Fail fast with a readable message: with no files the input pipeline has
# nothing to enqueue, so data_pool.batch() could never deliver a batch.
if not file_paths:
    raise IOError("No images found in ./Face_data/Celeba/dataset_64x64/ - prepare the dataset first!")
# every image is expected to be 64 x 64 x 3; preprocess_fn rescales it
data_pool = DiskImageData(file_paths, batch_size, shape=[64, 64, 3], preprocess_fn=preprocess_fn)


""" graphs """
# Build the whole training graph on the selected GPU.
with tf.device('/gpu:%d' % gpu_id):

    ''' graph '''
    # inputs
    real = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])  # batch of real images
    z = tf.placeholder(tf.float32, shape=[None, z_dim])  # batch of latent vectors
    # k_t is a placeholder: its value is maintained in Python by the training
    # loop and fed on every step
    k_t = tf.placeholder(tf.float32, shape=(), name='kt') # shape of a scalar = ()

    # generate
    # first call creates the generator variables (reuse=False)
    fake = generator(z, reuse=False)

    # discriminate
    # the same auto-encoder is applied to real and generated images; the second
    # call reuses the variables created by the first
    reconst_r = discriminator(real, reuse=False)
    reconst_f = discriminator(fake)

    # losses
    # auto-encoder (L1 reconstruction) losses on real and generated images
    AE_loss_r = l1_loss(real, reconst_r)
    AE_loss_f = l1_loss(fake, reconst_f)
    # discriminator: reconstruct real images well, generated ones poorly (weighted by k_t)
    d_loss = AE_loss_r - k_t * AE_loss_f
    # generator: make generated images easy to reconstruct
    g_loss = AE_loss_f
    # global convergence measure
    M_global = AE_loss_r + tf.abs(gamma * AE_loss_r - AE_loss_f)

    # vars
    # split the trainable variables by the scope name contained in their names
    d_vars = trainable_variables('discriminator')
    g_vars = trainable_variables('generator')
    
    # optimizers
    # one Adam optimizer per network, each restricted to its own variables
    d_opt = tf.train.AdamOptimizer(lr, beta1).minimize(d_loss, var_list=d_vars)
    g_opt = tf.train.AdamOptimizer(lr, beta1).minimize(g_loss, var_list=g_vars)
    
    # summaries
    d_summary = summary({AE_loss_r: 'AE_real_reconst_loss', d_loss: 'd_loss'})
    g_summary = summary({g_loss: 'g_loss'})
    m_summary = summary({M_global: 'M_global'})

    # sample
    # generator with shared weights and batch norm in inference mode, used for sample grids
    f_sample = generator(z, training=False)


 [*] DiskImageData: create session!


### Training the network

In [None]:
""" train """
''' init '''
# session
sess = session()
# saver
saver = tf.train.Saver(max_to_keep=5)
# summary writer
summary_writer = tf.summary.FileWriter('./BEGAN/summaries/run%s' % run_num, sess.graph)

''' initialization '''
ckpt_dir = './BEGAN/checkpoints/run%s' % run_num
mkdir(ckpt_dir + '/')
# restore the latest checkpoint of this run if there is one, otherwise start fresh
if not load_checkpoint(ckpt_dir, sess):
    sess.run(tf.global_variables_initializer())

''' train '''
try:
    # fixed latent codes, so the sample grids are comparable across iterations
    #z_ipt_sample = np.random.normal(size=[100, z_dim])
    z_ipt_sample = np.random.uniform(-1., 1., size=[100, z_dim])
    # NOTE(review): kt and the iteration counter restart from 0 even when a
    # checkpoint was restored above.
    kt = np.float32(0.)

    # NOTE(review): every iteration consumes a single batch, yet the epoch
    # length is divided by n_critic as well - confirm this is intended.
    batch_epoch = len(data_pool) // (batch_size * n_critic)
    max_it = epoch * batch_epoch

    # terms to be calculated (loop invariant, so built once)
    g_opt_list = [g_opt, g_loss, AE_loss_r, AE_loss_f]
    d_opt_list = [d_opt, d_loss, d_summary, g_summary, m_summary]

    # directory for the sample grids written during training
    save_dir = './BEGAN/sample_images_while_training/run%s' % run_num
    mkdir(save_dir + '/')

    for it in range(max_it):

        # which epoch (a separate name keeps the `epoch` hyper-parameter
        # intact, so re-running this cell trains for the configured length)
        cur_epoch = it // batch_epoch
        it_epoch = it % batch_epoch + 1

        # batch data
        real_ipt = data_pool.batch()
        #z_ipt = np.random.normal(size=[batch_size, z_dim])
        z_ipt = np.random.uniform(-1., 1., size=[batch_size, z_dim])
        feed_dict = {real: real_ipt, z: z_ipt, k_t: kt}

        # run: generator step first, then discriminator step on the same batch
        _, g_loss_calculated, AE_loss_r_calculated, AE_loss_f_calculated = sess.run(g_opt_list, feed_dict=feed_dict)
        _, d_loss_calculated, summary_d, summary_g, summary_m = sess.run(d_opt_list, feed_dict=feed_dict)

        # update kt, m_global -- (The range of kt is [0,1])
        kt = kt + lamda * (gamma * AE_loss_r_calculated - AE_loss_f_calculated)
        kt = np.maximum(np.minimum(1., kt), 0.)
        m_global = AE_loss_r_calculated + np.abs(gamma * AE_loss_r_calculated - AE_loss_f_calculated)

        # write train summary
        summary_writer.add_summary(summary_d, it)
        summary_writer.add_summary(summary_g, it)
        summary_writer.add_summary(summary_m, it)

        # display
        if it % 100 == 0:
            print("Epoch: (%3d) (%5d/%5d) -- g_loss: %.4f, d_loss: %.4f AE_r: %.4f, AE_f: %.4f, kt: %.8f, M: %.8f"
                  % (cur_epoch, it_epoch, batch_epoch,
                     g_loss_calculated, d_loss_calculated,
                     AE_loss_r_calculated, AE_loss_f_calculated,
                     kt, m_global))

        # save
        if (it + 1) % 1000 == 0:
            save_path = saver.save(sess, '%s/Epoch_(%d)_(%dof%d).ckpt' % (ckpt_dir, cur_epoch, it_epoch, batch_epoch))
            print('Model saved in file: % s' % save_path)

        # sample
        if (it + 1) % 100 == 0:
            f_sample_opt = sess.run(f_sample, feed_dict={z: z_ipt_sample})
            imwrite(immerge(f_sample_opt, 10, 10), '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, cur_epoch, it_epoch, batch_epoch))

except Exception:
    # notebook-level boundary: report the failure, then still run the cleanup below
    traceback.print_exc()
finally:
    print(" [*] Close main session!")
    try:
        # flush pending events to disk and release the event file
        summary_writer.close()
    finally:
        sess.close()

 [*] Loading checkpoint...
 [*] No suitable checkpoint!
Epoch: (  0) (    1/ 1208) -- g_loss: 0.2654, d_loss: 0.4872 AE_r: 0.4872, AE_f: 0.2654, kt: 0.00000000, M: 0.50900128
Epoch: (  0) (  101/ 1208) -- g_loss: 0.2617, d_loss: 0.1873 AE_r: 0.1873, AE_f: 0.2617, kt: 0.00000000, M: 0.35536502


### Using a trained network