StarGAN implementation, v2 (TensorFlow 1.x)

Paper: https://arxiv.org/abs/1711.09020
Related TensorFlow implementations:
https://github.com/goldkim92/StarGAN-tensorflow/blob/master/
https://github.com/ly-atdawn/StarGAN-Tensorflow
WGAN-GP reference code: https://github.com/igul222/improved_wgan_training/blob/master/

In [1]:
import os
import scipy.misc as scm
import numpy as np

def make_project_dir(project_dir):
    """Create *project_dir* and its ``models``, ``result`` and ``result_test`` subdirectories.

    Safe to call repeatedly: each directory is created only if missing, so a
    pre-existing *project_dir* that lacks some subdirectories is completed
    rather than skipped.
    """
    for sub_dir in ('models', 'result', 'result_test'):
        # makedirs also creates project_dir itself (and any parents) if needed.
        os.makedirs(os.path.join(project_dir, sub_dir), exist_ok=True)


def get_image(img_path, flip=False): # [0,255] to [-1,1]
    """Load an image file and rescale it from [0, 255] to [-1, 1].

    Args:
        img_path: path readable by ``scm.imread``.
        flip: when True, mirror the image left-right before scaling.

    Returns:
        Float array with the last (channel) axis reversed. The original
        comment calls this RGB -> BGR; assumes a (H, W, 3) colour image — TODO confirm.
    """
    pixels = scm.imread(img_path)
    pixels = np.fliplr(pixels) if flip else pixels
    scaled = pixels * 2. / 255. - 1.
    return scaled[..., ::-1]  # rgb to bgr

def get_label(path, size):
    """Return a one-hot float vector of length *size* for the label encoded in *path*.

    The label is the single digit at ``path[-5]``, i.e. the character just
    before a 4-character suffix such as ``.png`` (e.g. ``s0001_3.png`` -> 3).
    Assumes labels are single digits below *size*; other names will raise
    or select the wrong class.
    """
    label = int(path[-5])
    one_hot = np.zeros(size)
    one_hot[label] = 1.0
    return one_hot

def inverse_image(img): # [-1,1] to [0,255]
    """Map an image with values in [-1, 1] back to the [0, 255] range.

    Out-of-range values are clipped. The result is a new float array (the
    input is not modified) with the last axis reversed, undoing the
    RGB -> BGR swap done when loading.
    """
    pixels = np.clip((img + 1.) / 2. * 255., 0, 255)
    return pixels[..., ::-1]  # bgr to rgb

def pair_expressions(paths):
    """Pair every expression image of a subject with every other (and itself).

    The subject id is the 4-digit number at ``path[-10:-6]`` (e.g.
    ``s0001_3.png`` -> 1). Paths are grouped into *consecutive* runs with the
    same subject id, so *paths* should be sorted (the callers sort the glob
    result). For each run, all ordered pairs ``(x, y)`` including ``x == y``
    are emitted, in run order.

    Returns:
        List of ``(source_path, target_path)`` tuples; empty for empty input.
    """
    import itertools

    all_pairs = []
    # groupby replaces the old "last_subject != 0" sentinel, which merged
    # subject id 0 with whichever subject followed it.
    for _, group in itertools.groupby(paths, key=lambda path: int(path[-10:-6])):
        subject_exprs = list(group)
        all_pairs.extend((x, y) for x in subject_exprs for y in subject_exprs)
    return all_pairs

def get_shape_c(tensor): # static shape
    """Return the static (graph-construction-time) shape of *tensor* as a list.

    Unknown dimensions appear as ``None``.
    """
    static_shape = tensor.get_shape()
    return static_shape.as_list()


In [2]:
import tensorflow as tf
import numpy as np

# Shorthand for TF's element-wise sigmoid cross-entropy, used by the domain
# classification losses (each label dimension is an independent binary target).
cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits

def conv(x, filter_shape, bias=True, stride=1, padding="VALID", name="conv2d"):
    """2-D convolution (``tf.layers.conv2d``) wrapped in its own variable scope.

    Args:
        x: input tensor, presumably NHWC (tf.layers default) — confirm for new callers.
        filter_shape: ``[kw, kh, nin, nout]``. Only kw, kh and nout are used;
            the input depth is inferred by tf.layers.
        bias: whether to add a learned bias.
        stride: stride applied to both spatial axes.
        padding: "VALID" or "SAME".
        name: variable-scope name (also determines checkpoint variable names).

    Returns:
        The convolved tensor.
    """
    kw, kh, _, nout = filter_shape

    # Every layer uses a fixed stddev of 0.02, independent of its fan-in/out.
    k_initializer = tf.truncated_normal_initializer(stddev=0.02)

    with tf.variable_scope(name):
        # tf.layers reads kernel_size as (height, width); all kernels in this
        # file are square, so passing (kw, kh) is equivalent.
        x = tf.layers.conv2d(x, filters=nout, kernel_size=(kw, kh), strides=(stride, stride), padding=padding,
                             use_bias=bias, kernel_initializer=k_initializer)
    return x

def deconv(x, filter_shape, bias=True, stride=1, padding="VALID", name="conv2d_transpose"):
    """Transposed 2-D convolution (``tf.layers.conv2d_transpose``) in its own variable scope.

    Args:
        x: input tensor, presumably NHWC (tf.layers default) — confirm for new callers.
        filter_shape: ``[kw, kh, nin, nout]``. Only kw, kh and nout are used;
            the input depth is inferred by tf.layers.
        bias: whether to add a learned bias.
        stride: up-sampling stride applied to both spatial axes.
        padding: "VALID" or "SAME".
        name: variable-scope name (also determines checkpoint variable names).

    Returns:
        The up-sampled tensor.
    """
    kw, kh, _, nout = filter_shape

    # Every layer uses a fixed stddev of 0.02, independent of its fan-in/out.
    k_initializer = tf.truncated_normal_initializer(stddev=0.02)
    with tf.variable_scope(name):
        x = tf.layers.conv2d_transpose(x, filters=nout, kernel_size=(kw, kh), strides=(stride, stride), padding=padding,
                                       use_bias=bias, kernel_initializer=k_initializer)
    return x

def fc(x, output_shape, bias=True, name='fc'):
    """Fully-connected layer applied after flattening all non-batch axes of *x*.

    Args:
        x: tensor whose non-batch dimensions must be statically known.
        output_shape: number of output units (an int, despite the name).
        bias: whether to add a zero-initialised bias.
        name: variable-scope name holding ``weight`` (and ``bias``).

    Returns:
        Tensor of shape ``[batch, output_shape]``.
    """
    flat_dim = np.prod(x.get_shape().as_list()[1:])
    flat_x = tf.reshape(x, [-1, flat_dim])

    # Normal init whose stddev shrinks with the geometric mean of fan-in and fan-out.
    init_std = np.sqrt(1.0 / (np.sqrt(flat_dim * output_shape)))
    weight_init = tf.random_normal_initializer(stddev=init_std)
    with tf.variable_scope(name):
        weight = tf.get_variable("weight", shape=[flat_dim, output_shape], initializer=weight_init)
        out = tf.matmul(flat_x, weight)

        if bias:
            bias_var = tf.get_variable("bias", shape=[output_shape], initializer=tf.constant_initializer(0.))
            out = tf.nn.bias_add(out, bias_var)
    return out


def pool(x, r=2, s=1):
    """Average-pool *x* with an r x r window and stride s, using 'SAME' padding.

    The ksize/strides layout ``[1, r, r, 1]`` assumes NHWC input.
    """
    window = [1, r, r, 1]
    step = [1, s, s, 1]
    return tf.nn.avg_pool(x, ksize=window, strides=step, padding="SAME")

def instance_norm(input, name='instance_norm'):
    """Instance normalisation with a learned per-channel scale and offset.

    Mean and variance are computed per sample over axes 1 and 2, presumably
    the spatial axes of an NHWC tensor. The channel count is read from axis 3.

    Args:
        input: 4-D tensor.
        name: variable scope holding ``scale`` (init N(1, 0.02)) and ``offset`` (init 0).
    """
    eps = 1e-5
    with tf.variable_scope(name):
        channels = input.get_shape()[3]
        gamma = tf.get_variable('scale', [channels], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        beta = tf.get_variable('offset', [channels], initializer=tf.constant_initializer(0.0))
        mu, var = tf.nn.moments(input, axes=[1,2], keep_dims=True)
        inv_std = tf.rsqrt(var + eps)
        return gamma * ((input - mu) * inv_std) + beta

def l1_loss(x, y):
    """Mean absolute difference between *x* and *y*, averaged over all elements."""
    abs_diff = tf.abs(x - y)
    return tf.reduce_mean(abs_diff)

def l2_loss(x, y):
    """Mean squared difference between *x* and *y*, averaged over all elements."""
    sq_diff = tf.square(x - y)
    return tf.reduce_mean(sq_diff)

def resize_nn(x, size):
    """Nearest-neighbour resize of image batch *x* to a square ``size`` x ``size``."""
    side = int(size)
    return tf.image.resize_nearest_neighbor(x, size=(side, side))

def lrelu(x, leak=0.01, name='lrelu'): #lrelu(x, leak=0.2, name='lrelu'):
    """Leaky ReLU computed as ``max(x, leak * x)``.

    This equals the usual leaky ReLU only for 0 <= leak <= 1.
    *name* is accepted for API symmetry but not used.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)

def gradient_penalty(real, fake, f):
    """WGAN-GP gradient penalty: mean of (||grad_x f(x_hat)||_2 - 1)^2.

    ``x_hat`` is a random per-sample interpolate between *real* and *fake*
    (one uniform weight per sample, broadcast over the remaining axes).

    Args:
        real, fake: tensors of identical shape; axis 0 is the batch.
        f: critic callable invoked as ``f(x, reuse=True)`` that returns a
            ``(score, aux)`` tuple; only the score is differentiated.

    Returns:
        Scalar penalty tensor.
    """
    def interpolate(a, b):
        # Alpha has shape [batch, 1, ..., 1] so it broadcasts over all non-batch axes.
        shape = tf.concat((tf.shape(a)[0:1], tf.tile([1], [a.shape.ndims - 1])), axis=0)
        alpha = tf.random_uniform(shape=shape, minval=0., maxval=1.)
        inter = a + alpha * (b - a)
        inter.set_shape(a.get_shape().as_list())
        return inter

    x = interpolate(real, fake)
    pred, _ = f(x, reuse=True)
    gradients = tf.gradients(pred, x)[0]
    # The norm must be taken per sample over EVERY non-batch axis (previously
    # only axis 3 was reduced, giving a per-pixel norm instead).
    reduce_axes = list(range(1, x.shape.ndims))
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=reduce_axes))
    gp = tf.reduce_mean((slopes - 1.) ** 2)
    return gp

In [3]:
import glob
import time
import numpy as np
import tensorflow as tf

class op_base:
    """Hard-coded configuration and filesystem setup shared by the model classes.

    Stores the session, data and optimisation hyper-parameters, and the derived
    checkpoint/result paths. It also creates the project directories on
    construction. The ``# args.*`` comments presumably name command-line
    options these constants replaced.
    """

    def __init__(self, sess, project_name):
        self.sess = sess

        # Train
        self.flag = True  # args.flag
        self.gpu_number = 0  # args.gpu_number
        self.project = project_name  # "test_began" # args.project

        # Train Data
        self.data_dir = "./Face_data/Faces_with_expression_label/dataset_64x64"  # args.data_dir # ./Data
        self.dataset = "expr"  # args.dataset  # celeba
        self.data_size = 64  # args.data_size  # 64 or 128
        self.data_opt = "crop"  # args.data_opt  # raw or crop
        self.data_label_vector_size = 7  # size of one-hot-encoded label vector

        # Train Iteration
        self.niter = 200  # 50 # args.niter
        self.niter_snapshot = 500  # args.nsnapshot
        self.max_to_keep = 5  # args.max_to_keep

        # Train Parameter
        self.batch_size = 16  # args.batch_size
        self.learning_rate = 1e-4  # args.learning_rate
        self.mm = 0.5  # args.momentum
        self.mm2 = 0.999  # args.momentum2
        self.lamda = 0.001  # args.lamda
        self.gamma = 0.5  # args.gamma
        self.filter_number = 64  # args.filter_number
        self.input_size = 64  # args.input_size
        self.embedding = 128  # 64 # args.embedding

        # Loss weights
        self.lambda_cls = 1.
        self.lambda_recon = 10.
        self.lambda_gp = 10.

        # Result Dir & File
        self.project_dir = 'assets_ae/{0}_{1}_{2}_{3}/'.format(self.project, self.dataset, self.data_opt, self.data_size)
        self.ckpt_dir = os.path.join(self.project_dir, 'models')
        self.model_name = "{0}.model".format(self.project)
        self.ckpt_model_name = os.path.join(self.ckpt_dir, self.model_name)

        # etc.
        if not os.path.exists('assets_ae'):
            os.makedirs('assets_ae')
        make_project_dir(self.project_dir)

    def load(self, sess, saver, ckpt_dir):
        """Restore the latest checkpoint recorded in *ckpt_dir* into *sess*.

        The checkpoint file is looked up by basename inside *ckpt_dir*, not at
        the path stored in the checkpoint state. Presumably this lets a moved
        project directory still restore.

        Raises:
            FileNotFoundError: if *ckpt_dir* holds no checkpoint state
                (previously this surfaced as an AttributeError on ``None``).
        """
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError('No checkpoint found in {0}'.format(ckpt_dir))
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(ckpt_dir, ckpt_name))

In [4]:
import glob
import time
import datetime
import random


class Operator(op_base):
    """Builds the StarGAN training graph and drives training and sample generation.

    Subclasses must provide ``generator(x, c, reuse=None)``, returning an
    image tensor, and ``discriminator(x, reuse=None)``, returning
    ``(source_logits, class_logits)``. Both are called from ``build_model``.
    """

    def __init__(self, sess, project_name):
        op_base.__init__(self, sess, project_name)
        self.build_model()

    @staticmethod
    def _load_path_list(data_path, pair=False):
        """Return the file list for directory *data_path*, cached in ``data_path + '.npy'``.

        Args:
            data_path: directory whose ``*.*`` entries are listed (sorted).
            pair: when True, convert the list into (source, target) path pairs
                with ``pair_expressions`` before caching. A cached 1-D
                (unpaired) list is treated as stale and rebuilt.
        """
        cache_path = data_path + '.npy'
        if os.path.exists(cache_path):
            data = np.load(cache_path)
            if not pair or data.ndim == 2:
                return data
        data = sorted(glob.glob(os.path.join(data_path, "*.*")))
        if pair:
            data = pair_expressions(data)
        np.save(cache_path, data)
        return data

    def build_model(self):
        """Create placeholders, losses, optimizers, saver and (optionally) summaries.

        Initialises all variables, then tries to restore the latest checkpoint
        from ``self.ckpt_dir``. If that fails, the freshly initialised graph is
        saved instead.
        """
        # Input placeholder
        self.x = tf.placeholder(tf.float32, shape=[None, self.data_size, self.data_size, 3], name='x')
        self.x_c = tf.placeholder(tf.float32, shape=[None, self.data_label_vector_size], name='x_c')
        self.target_c = tf.placeholder(tf.float32, shape=[None, self.data_label_vector_size], name='target_c')
        self.alpha = tf.placeholder(tf.float32, shape=[None, 1], name='alpha')
        self.lr = tf.placeholder(tf.float32, name='lr')

        # Generator: translate x to target_c, then back to x_c for the reconstruction loss.
        self.G_f = self.generator(self.x, self.target_c)
        self.G_recon = self.generator(self.G_f, self.x_c, reuse=True)

        self.G_test = self.generator(self.x, self.target_c, reuse=True)

        # Discriminator
        self.D_f, self.D_f_cls = self.discriminator(self.G_f)
        self.D_x, self.D_x_cls = self.discriminator(self.x, reuse=True)

        # Gradient penalty on interpolates between real and fake images;
        # the per-sample mixing weight is fed through self.alpha.
        self.real_data = tf.reshape(self.x, [-1, self.data_size*self.data_size*3])
        self.fake_data = tf.reshape(self.G_f, [-1, self.data_size*self.data_size*3])
        self.diff = self.fake_data - self.real_data
        self.interpolate = self.real_data + self.alpha*self.diff

        self.inter_reshape = tf.reshape(self.interpolate, [-1, self.data_size, self.data_size, 3])
        self.G_inter, _ = self.discriminator(self.inter_reshape, reuse=True)

        self.grad = tf.gradients(self.G_inter,
                                 xs=[self.inter_reshape])[0]
        # Per-sample L2 norm of the critic gradient over all pixels and channels.
        self.slopes = tf.sqrt(tf.reduce_sum(tf.square(self.grad), axis=[1,2,3]))
        self.gp = tf.reduce_mean(tf.square(self.slopes - 1.))

        # Wasserstein loss
        self.wd = tf.reduce_mean(self.D_x) - tf.reduce_mean(self.D_f)
        self.L_adv_D = -self.wd + self.gp * self.lambda_gp
        self.L_adv_G = -tf.reduce_mean(self.D_f)

        # Domain classification: D learns real images' own labels; G is pushed
        # to make fakes classified as the target labels.
        self.L_D_cls = tf.reduce_mean(cross_entropy(labels=self.x_c, logits=self.D_x_cls))
        self.L_G_cls = tf.reduce_mean(cross_entropy(labels=self.target_c, logits=self.D_f_cls))
        self.L_G_recon = l1_loss(self.x, self.G_recon)

        self.L_D = self.L_adv_D + self.lambda_cls * self.L_D_cls
        self.L_G = self.L_adv_G + self.lambda_cls * self.L_G_cls + self.lambda_recon * self.L_G_recon

        # Variables
        D_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "discriminator")
        G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "generator")

        # Optimizer: Adam with beta1 = self.mm (beta2 left at TF's default).
        self.opt_D = tf.train.AdamOptimizer(self.lr, self.mm).minimize(self.L_D, var_list=D_vars)
        self.opt_G = tf.train.AdamOptimizer(self.lr, self.mm).minimize(self.L_G, var_list=G_vars)

        # initializer
        self.sess.run(tf.global_variables_initializer())

        # tf saver
        self.saver = tf.train.Saver(max_to_keep=(self.max_to_keep))

        try:
            self.load(self.sess, self.saver, self.ckpt_dir)
        except Exception:
            # No usable checkpoint yet: save the full graph so later runs can restore it.
            self.saver.save(self.sess, self.ckpt_model_name, write_meta_graph=True)

        # Summary
        if self.flag:
            tf.summary.scalar('loss/d_loss', self.L_D)
            tf.summary.scalar('loss/g_loss', self.L_G)
            tf.summary.scalar('loss/d_loss_cls', self.L_D_cls)
            tf.summary.scalar('loss/g_loss_recon', self.L_G_recon)
            tf.summary.scalar('loss/g_loss_cls', self.L_G_cls)
            tf.summary.scalar('loss/wd', self.wd)

            self.merged = tf.summary.merge_all()
            self.writer = tf.summary.FileWriter(self.project_dir, self.sess.graph)

    def train(self, train_flag):
        """Run the training loop for ``self.niter`` epochs.

        Each step runs five discriminator updates, then one generator update,
        all on the same batch. It then writes a summary. Every
        ``self.niter_snapshot`` steps it saves a checkpoint and writes sample grids.
        """
        # Elements must be (source_path, target_path) pairs: batch_file[0] and
        # batch_file[1] are indexed below. The old code cached an unpaired
        # list here when no cache existed, which made those indices select
        # single characters of a path.
        data = self._load_path_list(self.data_dir, pair=True)

        print('Shuffle ....')
        random_order = np.random.permutation(len(data))
        data = [data[i] for i in random_order]
        print('Shuffle Done')

        # initial parameter
        start_time = time.time()
        self.count = 0

        # The fetch lists are the same for every step.
        D_fetches = [self.opt_D, self.L_D, self.L_D_cls]
        G_fetches = [self.opt_G, self.L_G, self.L_G_cls, self.L_G_recon, self.wd]

        for epoch in range(self.niter):
            batch_idxs = len(data) // self.batch_size
            lr = np.float32(self.learning_rate)

            # Linear learning-rate decay over the second half of training.
            if epoch >= self.niter / 2.0:
                lr_decay = (self.niter - epoch) / (self.niter / 2.0)
                lr = lr * lr_decay

            for idx in range(0, batch_idxs):
                self.count += 1

                # Flip the whole batch with 0.5 prob to increase training data.
                flip = random.uniform(0, 1) < 0.5

                batch_files = data[idx * self.batch_size: (idx + 1) * self.batch_size]
                batch_inputs = [get_image(batch_file[0], flip) for batch_file in batch_files]
                batch_inputs_labels = [get_label(batch_file[0], self.data_label_vector_size) for batch_file in batch_files]
                batch_target_labels = [get_label(batch_file[1], self.data_label_vector_size) for batch_file in batch_files]
                batch_alpha = np.random.uniform(low=0., high=1.0, size=[self.batch_size, 1]).astype(np.float32)

                feed_dict = {self.x: batch_inputs, self.x_c: batch_inputs_labels,
                             self.target_c: batch_target_labels,
                             self.alpha: batch_alpha,
                             self.lr: lr}

                # Critic is updated 5 times per generator update.
                for i in range(5):
                    _, d_loss, d_loss_cls = self.sess.run(D_fetches, feed_dict=feed_dict)

                _, g_loss, g_loss_cls, g_loss_recon, wd = self.sess.run(G_fetches, feed_dict=feed_dict)

                if self.count % 100 == 1:
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, "
                          "d_loss: %.6f, d_loss_cls: %.6f, g_loss: %.6f, g_loss_cls: %.6f, g_loss_recon: %.6f, "
                          "wd: %.6f"
                          % (epoch, idx, batch_idxs, time.time() - start_time,
                             d_loss, d_loss_cls, g_loss, g_loss_cls, g_loss_recon, wd))

                # write train summary
                summary = self.sess.run(self.merged, feed_dict=feed_dict)
                self.writer.add_summary(summary, self.count)

                # Test during Training
                if self.count % self.niter_snapshot == (self.niter_snapshot - 1):
                    # save & test
                    self.saver.save(self.sess, self.ckpt_model_name, global_step=self.count, write_meta_graph=False)
                    self.test_expr(train_flag)
                    self.test_celebra(train_flag)

    def test_celebra(self, train_flag=True):
        """Write an 8x8 grid: each row is a CelebA input followed by its translations to every label.

        The layout relies on label_size == output_f - 1 (7 labels, 8 columns).
        The file name uses ``self.count``.
        """
        print('Test Sample Generation...')
        # generate output
        img_num = 8*8
        output_f = int(np.sqrt(img_num))
        in_img_num = output_f
        img_size = self.data_size
        gen_img_num = img_num - output_f
        label_size = self.data_label_vector_size

        # load data test
        data = self._load_path_list("./Face_data/Celeba/dataset_64x64")

        # shuffle test data
        random_order = np.random.permutation(len(data))
        data = [data[i] for i in random_order]

        im_output_gen = np.zeros([img_size * output_f, img_size * output_f, 3])

        test_files = data[0: output_f]
        test_data = [get_image(test_file) for test_file in test_files]
        # Repeat each input once per label so every image is translated to all labels.
        test_data = np.repeat(test_data, [label_size]*in_img_num, axis=0)
        test_data_o = [scm.imread(test_file) for test_file in test_files]

        # get one-hot labels
        int_labels = list(range(label_size))
        one_hot = np.zeros((label_size, label_size))
        one_hot[np.arange(label_size), int_labels] = 1
        target_labels = np.tile(one_hot, (output_f, 1))

        output_gen = (self.sess.run(self.G_test, feed_dict={self.x: test_data,
                                                            self.target_c: target_labels}))  # generator output

        output_gen = [inverse_image(output_gen[i]) for i in range(gen_img_num)]

        for i in range(output_f):
            for j in range(output_f):
                if j == 0:
                    im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] \
                        = test_data_o[i]
                else:
                    im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] \
                        = output_gen[(j-1) + (i * int(output_f-1))]

        # output save
        scm.imsave(self.project_dir + '/result/' + str(self.count) + '_celebra_output.bmp', im_output_gen)

    def test_expr(self, train_flag=True):
        """Write a 6x6 grid of (input, generated, target) triples from the training pairs.

        The file name encodes ``self.count`` and the target label of each triple.
        """
        print('Train Sample Generation...')
        # generate output
        img_num = 36  # self.batch_size
        display_img_num = int(img_num / 3)
        img_size = self.data_size

        output_f = int(np.sqrt(img_num))
        im_output_gen = np.zeros([img_size * output_f, img_size * output_f, 3])

        # load data (paired like in train)
        data = self._load_path_list(self.data_dir, pair=True)

        # Test data shuffle
        random_order = np.random.permutation(len(data))
        data = [data[i] for i in random_order]

        batch_files = data[0: display_img_num]
        test_inputs = [get_image(batch_file[0]) for batch_file in batch_files]
        test_inputs_o = [scm.imread((batch_file[0])) for batch_file in batch_files]
        test_targets = [scm.imread((batch_file[1])) for batch_file in batch_files]
        test_target_labels = [get_label(batch_file[1], self.data_label_vector_size) for batch_file in batch_files]

        output_gen = (self.sess.run(self.G_test, feed_dict={self.x: test_inputs,
                                                            self.target_c: test_target_labels}))  # generator output

        output_gen = [inverse_image(output_gen[i]) for i in range(display_img_num)]

        for i in range(output_f): # row
            for j in range(output_f): # col
                if j % 3 == 0: # input img
                    im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] \
                        = test_inputs_o[int(j / 3) + (i * int(output_f / 3))]
                elif j % 3 == 1: # output img
                    im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] \
                        = output_gen[int(j / 3) + (i * int(output_f / 3))]
                else: # target img
                    im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] \
                        = test_targets[int(j / 3) + (i * int(output_f / 3))]

        labels = np.argmax(test_target_labels, axis=1)
        label_string = ''.join(str(int(l)) for l in labels)
        # output save
        scm.imsave(self.project_dir + '/result/' + str(self.count) + '_' + label_string
                   + '_expr_output.bmp', im_output_gen)
        

In [None]:

class StarGAN(Operator):
    """StarGAN generator and discriminator networks plugged into ``Operator``."""

    def __init__(self, sess, project_name):
        Operator.__init__(self, sess, project_name)

    def generator(self, x, c, reuse=None):
        """Translate image batch *x* towards the domain labels *c*.

        Args:
            x: image batch, presumably [B, H, W, 3] in [-1, 1]; H and W must be
                statically known (they are used to tile the labels).
            c: label batch [B, data_label_vector_size].
            reuse: reuse the existing ``generator`` variables.

        Returns:
            tanh output with the same spatial size as *x* and 3 channels.
        """
        with tf.variable_scope('generator') as scope:
            if reuse:
                scope.reuse_variables()

            f = 64
            c_num = self.data_label_vector_size
            p = "SAME"

            # Broadcast the label vector over the spatial grid and append it as extra channels.
            x = tf.concat([x, tf.tile(tf.reshape(c, [-1, 1, 1, get_shape_c(c)[-1]]),
                                      [1, x.get_shape().as_list()[1], x.get_shape().as_list()[2], 1])],
                          axis=3)

            # Down-sampling
            x = conv(x, [7, 7, 3+c_num, f], stride=1, padding=p, name='ds_1')
            x = instance_norm(x, 'in_ds_1')
            x = tf.nn.relu(x)
            x = conv(x, [4, 4, f, f*2], stride=2, padding=p, name='ds_2')
            x = instance_norm(x, 'in_ds_2')
            x = tf.nn.relu(x)
            x = conv(x, [4, 4, f*2, f*4], stride=2, padding=p, name='ds_3')
            x = instance_norm(x, 'in_ds_3')
            x = tf.nn.relu(x)

            # Bottleneck: six residual blocks. Scope names bneck_{k}a/b and
            # in_bneck_{k}a/b match the previous unrolled code, so existing
            # checkpoints still restore.
            for k in range(1, 7):
                x_r = conv(x, [3, 3, f*4, f*4], stride=1, padding=p, name='bneck_{0}a'.format(k))
                x_r = instance_norm(x_r, 'in_bneck_{0}a'.format(k))
                x_r = tf.nn.relu(x_r)
                x_r = conv(x_r, [3, 3, f*4, f*4], stride=1, padding=p, name='bneck_{0}b'.format(k))
                x_r = instance_norm(x_r, 'in_bneck_{0}b'.format(k))
                x = x + x_r
                x = tf.nn.relu(x)

            # Up-sampling
            x = deconv(x, [4, 4, f*4, f*2], stride=2, padding=p, name='us_1')
            x = instance_norm(x, 'in_us_1')
            x = tf.nn.relu(x)
            x = deconv(x, [4, 4, f*2, f], stride=2, padding=p, name='us_2')
            x = instance_norm(x, 'in_us_2')
            x = tf.nn.relu(x)
            x = conv(x, [7, 7, f, 3], stride=1, padding=p, name='us_3')

            x = tf.nn.tanh(x)

        return x

    def discriminator(self, x, reuse=None):
        """Score *x* as real/fake and predict its domain label.

        Returns:
            ``(out_src, out_cls)``, where ``out_src`` is a patch-wise real/fake
            map and ``out_cls`` is [B, data_label_vector_size] logits.
        """
        with tf.variable_scope('discriminator') as scope:
            if reuse:
                scope.reuse_variables()

            f = 64
            image_size = self.data_size
            # Spatial size after five stride-2 convs; the class head uses a
            # kernel this large with VALID padding, yielding a 1x1 map.
            k_size = int(image_size / np.power(2, 5))
            c_num = self.data_label_vector_size
            p = "SAME"

            x = conv(x, [4, 4, 3, f], stride=2, padding=p, name='conv_1')
            x = lrelu(x)
            x = conv(x, [4, 4, f, f*2], stride=2, padding=p, name='conv_2')
            x = lrelu(x)
            x = conv(x, [4, 4, f*2, f*4], stride=2, padding=p, name='conv_3')
            x = lrelu(x)
            x = conv(x, [4, 4, f*4, f*8], stride=2, padding=p, name='conv_4')
            x = lrelu(x)
            x = conv(x, [4, 4, f*8, f*16], stride=2, padding=p, name='conv_5')
            x = lrelu(x)
            # Channel count of x at this point (informational: conv infers input depth).
            f_max = f*16

            if image_size == 128:
                x = conv(x, [4, 4, f*16, f*32], stride=2, padding=p, name='conv_6')
                x = lrelu(x)
                f_max = f_max * 2
                k_size = int(k_size / 2)

            out_src = conv(x, [3, 3, f_max, 1], stride=1, padding=p, name='conv_out_src')
            out_cls = conv(x, [k_size, k_size, f_max, c_num], stride=1, name='conv_out_cls')

        # Squeeze only the 1x1 spatial axes. A bare squeeze would also drop the
        # batch axis when the batch size is 1.
        return out_src, tf.squeeze(out_cls, axis=[1, 2])


In [None]:
import distutils.util
import os
import tensorflow as tf

# ---- config settings ----

project_name = "StarGAN_Face_2_"
train_flag = True  # True: run model.train, False: run model.test

# -------------------------

# Physical GPU to run on. CUDA_VISIBLE_DEVICES is only honored if it is set
# before TensorFlow initializes the GPUs (i.e. before the first session).
# Once it is set, the selected card is the only one visible to this process
# and is addressed as '/gpu:0', whatever its physical index is.
gpu_number = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_number

with tf.device('/gpu:0'):
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.90)
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        model = StarGAN(sess, project_name)

        # TRAIN / TEST
        if train_flag:
            model.train(train_flag)
        else:
            model.test(train_flag)

Shuffle ....
Shuffle Done
Epoch: [ 0] [   0/ 633] time: 1.6636, d_loss: 9.216977, d_loss_cls: 0.677631, g_loss: 6.748394, g_loss_cls: 0.685549, g_loss_recon: 0.603723, wd: 0.238556
Epoch: [ 0] [ 100/ 633] time: 69.5239, d_loss: -3.223270, d_loss_cls: 0.326211, g_loss: -1.857672, g_loss_cls: 0.446668, g_loss_recon: 0.223447, wd: 4.101604
Epoch: [ 0] [ 200/ 633] time: 138.2426, d_loss: -1.324638, d_loss_cls: 0.228790, g_loss: -6.735639, g_loss_cls: 0.446470, g_loss_recon: 0.171162, wd: 1.742344
Epoch: [ 0] [ 300/ 633] time: 208.0725, d_loss: -1.905260, d_loss_cls: 0.233375, g_loss: 1.875639, g_loss_cls: 0.506886, g_loss_recon: 0.178054, wd: 2.684456
Epoch: [ 0] [ 400/ 633] time: 278.3170, d_loss: -2.144655, d_loss_cls: 0.187462, g_loss: -3.451761, g_loss_cls: 0.388351, g_loss_recon: 0.149240, wd: 2.638731
Train Sample Generation...
Test Sample Generation...
Epoch: [ 0] [ 500/ 633] time: 349.7861, d_loss: -1.948005, d_loss_cls: 0.216839, g_loss: -2.537516, g_loss_cls: 0.294575, g_loss_rec

Epoch: [ 7] [ 469/ 633] time: 3489.8782, d_loss: -0.751620, d_loss_cls: 0.000721, g_loss: 0.075063, g_loss_cls: 0.007105, g_loss_recon: 0.065785, wd: 0.812229
Train Sample Generation...
Test Sample Generation...
Epoch: [ 7] [ 569/ 633] time: 3561.6275, d_loss: -0.663467, d_loss_cls: 0.003339, g_loss: 0.573195, g_loss_cls: 0.037032, g_loss_recon: 0.061590, wd: 0.730252
Epoch: [ 8] [  36/ 633] time: 3633.0089, d_loss: -0.778060, d_loss_cls: 0.001247, g_loss: 0.216340, g_loss_cls: 0.022349, g_loss_recon: 0.062511, wd: 0.840270
Epoch: [ 8] [ 136/ 633] time: 3704.3671, d_loss: -0.794903, d_loss_cls: 0.000813, g_loss: 0.236769, g_loss_cls: 0.096726, g_loss_recon: 0.083524, wd: 0.870831
Epoch: [ 8] [ 236/ 633] time: 3775.7089, d_loss: -0.750074, d_loss_cls: 0.001045, g_loss: 0.016070, g_loss_cls: 0.138480, g_loss_recon: 0.064118, wd: 0.828074
Epoch: [ 8] [ 336/ 633] time: 3847.0479, d_loss: -0.796430, d_loss_cls: 0.001051, g_loss: 0.177955, g_loss_cls: 0.065712, g_loss_recon: 0.070919, wd: 0.

Epoch: [15] [ 305/ 633] time: 6988.8046, d_loss: -0.720207, d_loss_cls: 0.000473, g_loss: 0.514406, g_loss_cls: 0.283872, g_loss_recon: 0.055150, wd: 0.794222
Epoch: [15] [ 405/ 633] time: 7060.0945, d_loss: -0.651262, d_loss_cls: 0.000349, g_loss: 0.328554, g_loss_cls: 0.045126, g_loss_recon: 0.056447, wd: 0.700616
Train Sample Generation...
Test Sample Generation...
Epoch: [15] [ 505/ 633] time: 7131.7476, d_loss: -0.655299, d_loss_cls: 0.000094, g_loss: 0.482694, g_loss_cls: 0.020003, g_loss_recon: 0.059862, wd: 0.714029
Epoch: [15] [ 605/ 633] time: 7203.0083, d_loss: -0.764139, d_loss_cls: 0.000436, g_loss: 0.298786, g_loss_cls: 0.027204, g_loss_recon: 0.066728, wd: 0.842988
Epoch: [16] [  72/ 633] time: 7274.3245, d_loss: -0.818930, d_loss_cls: 0.000171, g_loss: 0.543716, g_loss_cls: 0.046466, g_loss_recon: 0.060684, wd: 0.899318
Epoch: [16] [ 172/ 633] time: 7345.6415, d_loss: -0.726933, d_loss_cls: 0.000127, g_loss: 0.648175, g_loss_cls: 0.073371, g_loss_recon: 0.054539, wd: 0.

Epoch: [23] [ 141/ 633] time: 10486.5069, d_loss: -0.554723, d_loss_cls: 0.000145, g_loss: 0.375178, g_loss_cls: 0.096520, g_loss_recon: 0.053114, wd: 0.620772
Epoch: [23] [ 241/ 633] time: 10557.7906, d_loss: -0.691044, d_loss_cls: 0.000276, g_loss: 0.682766, g_loss_cls: 0.055527, g_loss_recon: 0.051551, wd: 0.768960
Epoch: [23] [ 341/ 633] time: 10629.9346, d_loss: -0.704865, d_loss_cls: 0.000071, g_loss: 0.525024, g_loss_cls: 0.057202, g_loss_recon: 0.050633, wd: 0.775781
Train Sample Generation...
Test Sample Generation...
Epoch: [23] [ 441/ 633] time: 10702.3272, d_loss: -0.656609, d_loss_cls: 0.000154, g_loss: 0.964087, g_loss_cls: 0.009092, g_loss_recon: 0.049247, wd: 0.728097
Epoch: [23] [ 541/ 633] time: 10773.5724, d_loss: -0.689732, d_loss_cls: 0.000704, g_loss: 0.630680, g_loss_cls: 0.088717, g_loss_recon: 0.050054, wd: 0.760127
Epoch: [24] [   8/ 633] time: 10845.3375, d_loss: -0.690682, d_loss_cls: 0.000325, g_loss: 0.593533, g_loss_cls: 0.042072, g_loss_recon: 0.050886, 

Epoch: [30] [ 510/ 633] time: 13920.7691, d_loss: -0.818479, d_loss_cls: 0.000125, g_loss: 0.628560, g_loss_cls: 0.066706, g_loss_recon: 0.050236, wd: 0.888302
Epoch: [30] [ 610/ 633] time: 13992.1096, d_loss: -0.640902, d_loss_cls: 0.000200, g_loss: 0.319688, g_loss_cls: 0.102770, g_loss_recon: 0.047963, wd: 0.704227
Epoch: [31] [  77/ 633] time: 14063.4470, d_loss: -0.869649, d_loss_cls: 0.000138, g_loss: 0.570453, g_loss_cls: 0.071619, g_loss_recon: 0.053482, wd: 0.961826
Epoch: [31] [ 177/ 633] time: 14135.4534, d_loss: -0.757420, d_loss_cls: 0.000170, g_loss: 0.695223, g_loss_cls: 0.022591, g_loss_recon: 0.046818, wd: 0.823958
Epoch: [31] [ 277/ 633] time: 14208.0165, d_loss: -0.987414, d_loss_cls: 0.000103, g_loss: 0.569318, g_loss_cls: 0.073372, g_loss_recon: 0.056886, wd: 1.076936
Train Sample Generation...
Test Sample Generation...
Epoch: [31] [ 377/ 633] time: 14279.7202, d_loss: -0.704845, d_loss_cls: 0.000265, g_loss: 0.310036, g_loss_cls: 0.036286, g_loss_recon: 0.047459, 

Epoch: [38] [ 346/ 633] time: 17421.8709, d_loss: -0.626896, d_loss_cls: 0.000165, g_loss: 0.348388, g_loss_cls: 0.002572, g_loss_recon: 0.047634, wd: 0.720900
Train Sample Generation...
Test Sample Generation...
Epoch: [38] [ 446/ 633] time: 17493.5335, d_loss: -0.851633, d_loss_cls: 0.000120, g_loss: 0.720895, g_loss_cls: 0.046895, g_loss_recon: 0.051607, wd: 0.939756
Epoch: [38] [ 546/ 633] time: 17564.8712, d_loss: -0.811227, d_loss_cls: 0.000011, g_loss: 0.148065, g_loss_cls: 0.062119, g_loss_recon: 0.049571, wd: 0.901740
Epoch: [39] [  13/ 633] time: 17636.2308, d_loss: -0.923748, d_loss_cls: 0.000282, g_loss: 0.745920, g_loss_cls: 0.099365, g_loss_recon: 0.047544, wd: 1.017704
Epoch: [39] [ 113/ 633] time: 17707.5503, d_loss: -0.770019, d_loss_cls: 0.000090, g_loss: 0.128398, g_loss_cls: 0.051740, g_loss_recon: 0.047539, wd: 0.833054
Epoch: [39] [ 213/ 633] time: 17778.8850, d_loss: -0.728955, d_loss_cls: 0.000099, g_loss: 0.470209, g_loss_cls: 0.051814, g_loss_recon: 0.044766, 

Epoch: [46] [  82/ 633] time: 20858.9495, d_loss: -0.770226, d_loss_cls: 0.000050, g_loss: 0.475784, g_loss_cls: 0.071999, g_loss_recon: 0.043461, wd: 0.848914
Epoch: [46] [ 182/ 633] time: 20930.3339, d_loss: -0.754061, d_loss_cls: 0.000011, g_loss: 0.436620, g_loss_cls: 0.053821, g_loss_recon: 0.040370, wd: 0.817325
Epoch: [46] [ 282/ 633] time: 21001.6825, d_loss: -0.824370, d_loss_cls: 0.000464, g_loss: 0.834328, g_loss_cls: 0.277228, g_loss_recon: 0.047718, wd: 0.904985
Train Sample Generation...
Test Sample Generation...
Epoch: [46] [ 382/ 633] time: 21073.5030, d_loss: -0.799711, d_loss_cls: 0.000022, g_loss: 0.537857, g_loss_cls: 0.084002, g_loss_recon: 0.045254, wd: 0.904236
Epoch: [46] [ 482/ 633] time: 21144.8601, d_loss: -0.978248, d_loss_cls: 0.000073, g_loss: 0.626482, g_loss_cls: 0.056482, g_loss_recon: 0.042248, wd: 1.072356
Epoch: [46] [ 582/ 633] time: 21216.2235, d_loss: -0.517058, d_loss_cls: 0.000387, g_loss: 0.323595, g_loss_cls: 0.019206, g_loss_recon: 0.040982, 

Epoch: [53] [ 451/ 633] time: 24288.4186, d_loss: -0.728331, d_loss_cls: 0.000161, g_loss: 0.574320, g_loss_cls: 0.112131, g_loss_recon: 0.041865, wd: 0.795331
Epoch: [53] [ 551/ 633] time: 24359.7994, d_loss: -0.923648, d_loss_cls: 0.000029, g_loss: 0.557865, g_loss_cls: 0.014273, g_loss_recon: 0.042842, wd: 0.998853
Epoch: [54] [  18/ 633] time: 24431.1928, d_loss: -0.983073, d_loss_cls: 0.000147, g_loss: 0.334923, g_loss_cls: 0.122911, g_loss_recon: 0.045313, wd: 1.079759
Epoch: [54] [ 118/ 633] time: 24502.5785, d_loss: -0.931143, d_loss_cls: 0.000073, g_loss: 0.285690, g_loss_cls: 0.014162, g_loss_recon: 0.046369, wd: 1.020899
Epoch: [54] [ 218/ 633] time: 24573.9358, d_loss: -0.789939, d_loss_cls: 0.000052, g_loss: 0.563665, g_loss_cls: 0.096254, g_loss_recon: 0.043552, wd: 0.875429
Train Sample Generation...
Test Sample Generation...
Epoch: [54] [ 318/ 633] time: 24645.7182, d_loss: -0.742700, d_loss_cls: 0.000052, g_loss: 0.515248, g_loss_cls: 0.044211, g_loss_recon: 0.043502, 

Epoch: [61] [ 287/ 633] time: 27788.2260, d_loss: -1.130033, d_loss_cls: 0.000036, g_loss: 0.624738, g_loss_cls: 0.110813, g_loss_recon: 0.043272, wd: 1.252769
Train Sample Generation...
Test Sample Generation...
Epoch: [61] [ 387/ 633] time: 27859.9750, d_loss: -0.895096, d_loss_cls: 0.000008, g_loss: 0.506385, g_loss_cls: 0.039572, g_loss_recon: 0.040389, wd: 0.966631
Epoch: [61] [ 487/ 633] time: 27931.3375, d_loss: -1.025282, d_loss_cls: 0.000197, g_loss: 0.579680, g_loss_cls: 0.091016, g_loss_recon: 0.040790, wd: 1.131027
Epoch: [61] [ 587/ 633] time: 28002.6532, d_loss: -0.972321, d_loss_cls: 0.000033, g_loss: 0.540400, g_loss_cls: 0.073192, g_loss_recon: 0.043304, wd: 1.084727
Epoch: [62] [  54/ 633] time: 28074.0011, d_loss: -0.806711, d_loss_cls: 0.000114, g_loss: 0.367350, g_loss_cls: 0.073544, g_loss_recon: 0.040835, wd: 0.893516
Epoch: [62] [ 154/ 633] time: 28145.2786, d_loss: -0.897952, d_loss_cls: 0.000154, g_loss: 0.466573, g_loss_cls: 0.073727, g_loss_recon: 0.039097, 

Epoch: [69] [  23/ 633] time: 31216.5481, d_loss: -0.803981, d_loss_cls: 0.000032, g_loss: 0.074169, g_loss_cls: 0.022433, g_loss_recon: 0.039234, wd: 0.897376
Epoch: [69] [ 123/ 633] time: 31287.9186, d_loss: -0.948729, d_loss_cls: 0.000017, g_loss: 0.599565, g_loss_cls: 0.255035, g_loss_recon: 0.038983, wd: 1.049280
Epoch: [69] [ 223/ 633] time: 31359.2898, d_loss: -0.829658, d_loss_cls: 0.000164, g_loss: 0.732537, g_loss_cls: 0.042386, g_loss_recon: 0.041742, wd: 0.908970
Train Sample Generation...
Test Sample Generation...
Epoch: [69] [ 323/ 633] time: 31431.0681, d_loss: -1.009181, d_loss_cls: 0.000227, g_loss: 0.526793, g_loss_cls: 0.059766, g_loss_recon: 0.037868, wd: 1.091723
Epoch: [69] [ 423/ 633] time: 31502.4073, d_loss: -0.922647, d_loss_cls: 0.000047, g_loss: 0.628157, g_loss_cls: 0.015512, g_loss_recon: 0.039374, wd: 1.006349
Epoch: [69] [ 523/ 633] time: 31573.7735, d_loss: -0.735066, d_loss_cls: 0.000264, g_loss: 0.406183, g_loss_cls: 0.066363, g_loss_recon: 0.037849, 