StarGAN implementation v4

Based on v3, adapted to 128x128 input images.

https://arxiv.org/abs/1711.09020
https://github.com/goldkim92/StarGAN-tensorflow/blob/master/
https://github.com/ly-atdawn/StarGAN-Tensorflow
https://github.com/igul222/improved_wgan_training/blob/master/

In [1]:
import os
import scipy.misc as scm
import numpy as np

def make_project_dir(project_dir):
    """Create ``project_dir`` and its ``models``, ``result`` and ``result_test`` sub-directories.

    Every directory is checked individually, so missing sub-directories are
    created even when ``project_dir`` itself already exists. Calling this
    repeatedly is harmless: existing paths are left untouched.
    """
    for path in (project_dir,
                 os.path.join(project_dir, 'models'),
                 os.path.join(project_dir, 'result'),
                 os.path.join(project_dir, 'result_test')):
        if not os.path.exists(path):
            os.makedirs(path)

def get_image(img_path, flip=False): # [0,255] to [-1,1]
    """Load an image file and prepare it as network input.

    The pixels are optionally mirrored left-right, rescaled from [0, 255] to
    [-1, 1], and the last axis is reversed (RGB -> BGR).
    NOTE(review): assumes a 3-channel HxWx3 image — a grayscale (2-D) file
    would have its columns reversed instead; confirm the dataset is RGB.
    """
    pixels = scm.imread(img_path)
    if flip:
        pixels = np.fliplr(pixels)
    scaled = pixels * 2. / 255. - 1.
    return scaled[..., ::-1]  # rgb to bgr

def get_label(path, size):
    """Return a length-``size`` one-hot float vector for the label encoded in ``path``.

    The label is the last character of the file stem, e.g. ``.../0001_3.png``
    gives label 3. The stem is taken with ``os.path.splitext``, so the
    extension length no longer matters. The original ``path[-5]`` only
    worked for 3-character extensions.

    Raises:
        ValueError: if the last stem character is not a digit.
        IndexError: if the label is >= ``size``.
    """
    stem = os.path.splitext(path)[0]
    label = int(stem[-1])
    one_hot = np.zeros(size)
    one_hot[label] = 1.0
    return one_hot

def inverse_image(img): # [-1,1] to [0,255]
    """Map a network output from [-1, 1] back to a displayable [0, 255] image.

    Values that fall outside [0, 255] after rescaling are clipped, and the
    last axis is reversed (BGR -> RGB). The input array is not modified.
    """
    rescaled = (img + 1.) / 2. * 255.
    bounded = np.clip(rescaled, 0, 255)
    return bounded[..., ::-1]  # bgr to rgb

def pair_expressions(paths):
    """Pair every expression image of a subject with every other (including itself).

    The subject id is read from ``path[-10:-6]`` (e.g. ``.../0001_3.png`` gives 1).
    Paths are grouped by consecutive runs of the same subject, so callers are
    expected to pass them sorted.

    Returns:
        A list of ``(input_path, target_path)`` tuples, subject by subject.
    """
    all_pairs = []
    subject_exprs = []
    # None (not 0) marks "no subject seen yet", so subject id 0000 is flushed correctly.
    last_subject = None

    # Pair all expression of a subject
    for path in paths:
        subject = int(path[-10:-6])

        if last_subject is not None and subject != last_subject:
            all_pairs.extend((x, y) for x in subject_exprs for y in subject_exprs)
            subject_exprs = []

        subject_exprs.append(path)
        last_subject = subject

    # Last subject (no-op for empty input)
    all_pairs.extend((x, y) for x in subject_exprs for y in subject_exprs)
    return all_pairs

def get_shape_c(tensor): # static shape
    """Return the static shape of ``tensor`` as a Python list (via ``TensorShape.as_list``)."""
    static_shape = tensor.get_shape()
    return static_shape.as_list()


In [2]:
import tensorflow as tf
import numpy as np

# Element-wise sigmoid cross-entropy on logits; the losses in build_model
# wrap it in tf.reduce_mean over the one-hot label vectors.
cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits

def conv(x, filter_shape, bias=True, stride=1, padding="VALID", name="conv2d"):
    """2-D convolution (``tf.layers.conv2d``) wrapped in variable scope ``name``.

    Args:
        x: input tensor (tf.layers default channels-last layout).
        filter_shape: ``[kw, kh, nin, nout]``. ``nin`` is accepted for readability
            but unused: the layer infers the input depth from ``x``.
        bias: whether to add a bias term.
        stride: same stride on both spatial axes.
        padding: ``"VALID"`` or ``"SAME"``.
        name: variable scope name.
    Returns:
        The convolution output tensor with ``nout`` channels.
    """
    kw, kh, _, nout = filter_shape
    # Fixed truncated-normal(0.02) init. The fan-in based stddev that used to be
    # computed here was never used, so it has been removed.
    k_initializer = tf.truncated_normal_initializer(stddev=0.02)

    with tf.variable_scope(name):
        # NOTE(review): tf.layers expects kernel_size=(height, width); (kw, kh) only
        # matters for non-square kernels, which this file never uses.
        x = tf.layers.conv2d(x, filters=nout, kernel_size=(kw, kh), strides=(stride, stride), padding=padding,
                             use_bias=bias, kernel_initializer=k_initializer)
    return x

def deconv(x, filter_shape, bias=True, stride=1, padding="VALID", name="conv2d_transpose"):
    """Transposed 2-D convolution (``tf.layers.conv2d_transpose``) in variable scope ``name``.

    Args:
        x: input tensor (tf.layers default channels-last layout).
        filter_shape: ``[kw, kh, nin, nout]``. ``nin`` is unused: the layer infers it from ``x``.
        bias: whether to add a bias term.
        stride: same stride on both spatial axes (2 doubles the spatial size with "SAME").
        padding: ``"VALID"`` or ``"SAME"``.
        name: variable scope name.
    Returns:
        The transposed-convolution output tensor with ``nout`` channels.
    """
    kw, kh, _, nout = filter_shape
    # Fixed truncated-normal(0.02) init. The unused fan-in stddev computation was removed.
    k_initializer = tf.truncated_normal_initializer(stddev=0.02)
    with tf.variable_scope(name):
        x = tf.layers.conv2d_transpose(x, filters=nout, kernel_size=(kw, kh), strides=(stride, stride), padding=padding,
                                       use_bias=bias, kernel_initializer=k_initializer)
    return x

def fc(x, output_shape, bias=True, name='fc'):
    """Fully connected layer in variable scope ``name``.

    All non-batch dimensions of ``x`` are flattened, then ``x @ weight`` (plus
    ``bias`` if requested) is returned. Weights are drawn from
    N(0, stddev) with stddev = (in_dim * output_shape) ** -0.25.
    """
    in_dim = np.prod(x.get_shape().as_list()[1:])
    flat = tf.reshape(x, [-1, in_dim])

    init_std = np.sqrt(1.0 / (np.sqrt(in_dim * output_shape)))
    w_init = tf.random_normal_initializer(stddev=init_std)
    with tf.variable_scope(name):
        weight = tf.get_variable("weight", shape=[in_dim, output_shape], initializer=w_init)
        out = tf.matmul(flat, weight)

        if bias:
            b = tf.get_variable("bias", shape=[output_shape], initializer=tf.constant_initializer(0.))
            out = tf.nn.bias_add(out, b)
    return out


def pool(x, r=2, s=1):
    """Average-pool ``x`` with an ``r`` x ``r`` window, stride ``s`` and "SAME" padding."""
    window = [1, r, r, 1]
    step = [1, s, s, 1]
    return tf.nn.avg_pool(x, ksize=window, strides=step, padding="SAME")

def instance_norm(input, name='instance_norm'):
    """Instance normalisation with learned per-channel ``scale`` and ``offset``.

    Mean and variance are taken over axes 1 and 2 of ``input`` (the spatial
    axes for a channels-last tensor). The channel count is read from axis 3.
    """
    with tf.variable_scope(name):
        channels = input.get_shape()[3]
        scale = tf.get_variable('scale', [channels],
                                initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        offset = tf.get_variable('offset', [channels], initializer=tf.constant_initializer(0.0))
        mean, variance = tf.nn.moments(input, axes=[1, 2], keep_dims=True)
        eps = 1e-5
        inv_std = tf.rsqrt(variance + eps)
        return scale * ((input - mean) * inv_std) + offset

def l1_loss(x, y):
    """Mean absolute error between ``x`` and ``y`` (scalar tensor)."""
    diff = x - y
    return tf.reduce_mean(tf.abs(diff))

def l2_loss(x, y):
    """Mean squared error between ``x`` and ``y`` (scalar tensor)."""
    diff = x - y
    return tf.reduce_mean(tf.square(diff))

def resize_nn(x, size):
    """Nearest-neighbour resize of ``x`` to a square ``size`` x ``size`` spatial grid."""
    side = int(size)
    return tf.image.resize_nearest_neighbor(x, size=(side, side))

def lrelu(x, leak=0.01, name='lrelu'): #lrelu(x, leak=0.2, name='lrelu'):
    """Leaky ReLU computed as ``max(x, leak * x)``.

    ``name`` is accepted for call-site compatibility but is not used.
    """
    leaked = leak * x
    return tf.maximum(x, leaked)

def gradient_penalty(real, fake, f):
    """WGAN-GP gradient penalty ``E[(||grad_x f(x)||_2 - 1)^2]``.

    ``x`` is a per-sample random interpolation between ``real`` and ``fake``.
    ``f`` is called as ``f(x, reuse=True)`` and must return a pair whose
    first element is the critic output.

    The gradient norm is taken over every non-batch axis, matching the
    penalty built in ``Operator.build_model``. The previous version reduced
    only axis 3, i.e. it penalised per-pixel channel norms.
    """
    def interpolate(a, b):
        # One alpha per sample, broadcast over all remaining axes.
        shape = tf.concat((tf.shape(a)[0:1], tf.tile([1], [a.shape.ndims - 1])), axis=0)
        alpha = tf.random_uniform(shape=shape, minval=0., maxval=1.)
        inter = a + alpha * (b - a)
        inter.set_shape(a.get_shape().as_list())
        return inter

    x = interpolate(real, fake)
    pred, _ = f(x, reuse=True)
    gradients = tf.gradients(pred, x)[0]
    reduce_axes = list(range(1, x.shape.ndims))
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=reduce_axes))
    gp = tf.reduce_mean((slopes - 1.) ** 2)
    return gp

In [3]:
import glob
import time
import numpy as np
import tensorflow as tf

class op_base:
    """Run configuration shared by the StarGAN operator.

    Holds the data paths, training schedule, optimiser hyper-parameters and
    output locations. Most values are hard-coded; the trailing ``#args.*``
    comments record the command-line arguments they replaced. Constructing
    an instance also creates the project output directories.
    """

    def __init__(self, sess, project_name):
        self.sess = sess

        # Train
        self.flag = True #args.flag
        self.gpu_number = 0 #args.gpu_number
        self.project = project_name #"test_began" #args.project

        # Train Data
        self.data_dir = "./Face_data/Faces_with_expression_label/dataset_128x128" #args.data_dir #./Data
        self.dataset = "expr" #args.dataset  # celeba
        self.data_size = 128 #args.data_size  # 64 or 128
        self.data_opt = "crop" #args.data_opt  # raw or crop
        self.data_label_vector_size = 7 #size of one-hot-encoded label vector

        # Train Iteration
        self.niter = 200 #50 #args.niter
        self.niter_snapshot = 500 #args.nsnapshot
        self.max_to_keep = 200 #args.max_to_keep models

        # Train Parameter
        self.batch_size = 16 #args.batch_size
        self.learning_rate = 1e-4 #args.learning_rate
        self.mm = 0.5 #args.momentum
        self.mm2 = 0.999 #args.momentum2
        self.lamda = 0.001 #args.lamda
        self.gamma = 0.5 #args.gamma
        self.input_size = 128 #args.input_size
        self.embedding = 128 #64 #args.embedding

        # Loss weights: classification, cycle reconstruction, gradient penalty
        self.lambda_cls = 1.
        self.lambda_recon = 10.
        self.lambda_gp = 10.

        # Result Dir & File
        self.project_dir = 'assets_ae/{0}_{1}_{2}_{3}/'.format(self.project, self.dataset, self.data_opt, self.data_size)
        self.ckpt_dir = os.path.join(self.project_dir, 'models')
        self.model_name = "{0}.model".format(self.project)
        self.ckpt_model_name = os.path.join(self.ckpt_dir, self.model_name)

        # etc.
        if not os.path.exists('assets_ae'):
            os.makedirs('assets_ae')
        make_project_dir(self.project_dir)

    def load(self, sess, saver, ckpt_dir):
        """Restore the most recent checkpoint recorded in ``ckpt_dir`` into ``sess``.

        Raises:
            IOError: if ``ckpt_dir`` has no checkpoint state. Without this check,
                ``None.model_checkpoint_path`` raised an opaque AttributeError.
        """
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise IOError('No checkpoint found in {0}'.format(ckpt_dir))
        # Only the file name is kept and re-joined onto ckpt_dir, presumably so a
        # relocated project directory still restores (the recorded path may be stale).
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(ckpt_dir, ckpt_name))

In [4]:
import glob
import time
import datetime
import random


class Operator(op_base):
    """Builds the StarGAN training graph and drives training and sample generation.

    Subclasses must provide ``generator(x, c, reuse=None)`` and
    ``discriminator(x, reuse=None)``. The discriminator must return a pair
    ``(source_logits, class_logits)``.
    """

    def __init__(self, sess, project_name):
        op_base.__init__(self, sess, project_name)
        self.build_model()

    def build_model(self):
        """Build the graph and prepare the session.

        Creates the placeholders, the G/D graph, the losses (WGAN-GP,
        classification, reconstruction), the Adam optimisers, the saver and
        the summaries. It then initialises variables and restores the latest
        checkpoint when one exists.
        """
        # Input placeholder
        self.x = tf.placeholder(tf.float32, shape=[None, self.data_size, self.data_size, 3], name='x')
        self.x_c = tf.placeholder(tf.float32, shape=[None, self.data_label_vector_size], name='x_c')
        self.target = tf.placeholder(tf.float32, shape=[None, self.data_size, self.data_size, 3], name='target')
        self.target_c = tf.placeholder(tf.float32, shape=[None, self.data_label_vector_size], name='target_c')
        self.alpha = tf.placeholder(tf.float32, shape=[None, 1], name='alpha')
        self.lr = tf.placeholder(tf.float32, name='lr')

        # Generator: translate x to the target label, then back to x's label (cycle reconstruction)
        self.G_f = self.generator(self.x, self.target_c)
        self.G_recon = self.generator(self.G_f, self.x_c, reuse=True)

        self.G_test = self.generator(self.x, self.target_c, reuse=True)

        # Discriminator
        self.D_f, self.D_f_cls = self.discriminator(self.G_f)
        self.D_target, self.D_target_cls = self.discriminator(self.target, reuse=True) # discriminate with the target

        # Gradient Penalty: per-sample interpolation between target (real) and generated images
        self.real_data = tf.reshape(self.target, [-1, self.data_size*self.data_size*3]) # interpolate with target
        self.fake_data = tf.reshape(self.G_f, [-1, self.data_size*self.data_size*3])
        self.diff = self.fake_data - self.real_data
        self.interpolate = self.real_data + self.alpha*self.diff

        self.inter_reshape = tf.reshape(self.interpolate, [-1, self.data_size, self.data_size, 3])
        self.G_inter, _ = self.discriminator(self.inter_reshape, reuse=True)

        self.grad = tf.gradients(self.G_inter,
                                 xs=[self.inter_reshape])[0]
        self.slopes = tf.sqrt(tf.reduce_sum(tf.square(self.grad), axis=[1,2,3]))
        self.gp = tf.reduce_mean(tf.square(self.slopes - 1.))

        # Wasserstein loss
        self.wd = tf.reduce_mean(self.D_target) - tf.reduce_mean(self.D_f)
        self.L_adv_D = -self.wd + self.gp * self.lambda_gp
        self.L_adv_G = -tf.reduce_mean(self.D_f)

        self.L_D_cls = tf.reduce_mean(cross_entropy(labels=self.target_c, logits=self.D_target_cls))# discriminate with the target
        self.L_G_cls = tf.reduce_mean(cross_entropy(labels=self.target_c, logits=self.D_f_cls))
        self.L_G_recon = l1_loss(self.x, self.G_recon)

        self.L_D = self.L_adv_D + self.lambda_cls * self.L_D_cls
        self.L_G = self.L_adv_G + self.lambda_cls * self.L_G_cls + self.lambda_recon * self.L_G_recon

        # Variables
        D_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "discriminator")
        G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "generator")

        # Optimizer: both Adam betas come from the config (mm2 equals Adam's default beta2)
        self.opt_D = tf.train.AdamOptimizer(self.lr, beta1=self.mm, beta2=self.mm2).minimize(self.L_D, var_list=D_vars)
        self.opt_G = tf.train.AdamOptimizer(self.lr, beta1=self.mm, beta2=self.mm2).minimize(self.L_G, var_list=G_vars)

        # initializer
        self.sess.run(tf.global_variables_initializer())

        # tf saver
        self.saver = tf.train.Saver(max_to_keep=(self.max_to_keep))

        try:
            self.load(self.sess, self.saver, self.ckpt_dir)
        except Exception:
            # No usable checkpoint: save the freshly initialised model with its full graph.
            # (Exception, not a bare except, so Ctrl-C is not swallowed here.)
            self.saver.save(self.sess, self.ckpt_model_name, write_meta_graph=True)

        # Summary
        if self.flag:
            tf.summary.scalar('loss/d_loss', self.L_D)
            tf.summary.scalar('loss/g_loss', self.L_G)
            tf.summary.scalar('loss/d_loss_cls', self.L_D_cls)
            tf.summary.scalar('loss/g_loss_recon', self.L_G_recon)
            tf.summary.scalar('loss/g_loss_cls', self.L_G_cls)
            tf.summary.scalar('loss/wd', self.wd)

            self.merged = tf.summary.merge_all()
            self.writer = tf.summary.FileWriter(self.project_dir, self.sess.graph)

    @staticmethod
    def _load_cached_paths(data_path, pair=False):
        """Return the file list for ``data_path``, cached in ``data_path + '.npy'``.

        On a cache miss, files matching ``*.*`` are sorted (and grouped into
        expression pairs when ``pair`` is true), then saved to the cache.
        """
        cache_file = data_path + '.npy'
        if os.path.exists(cache_file):
            return np.load(cache_file)
        data = sorted(glob.glob(os.path.join(data_path, "*.*")))
        if pair:
            data = pair_expressions(data)
        np.save(cache_file, data)
        return data

    @staticmethod
    def _shuffled(data):
        """Return the items of ``data`` as a list in a random (np.random) order."""
        random_order = np.random.permutation(len(data))
        return [data[i] for i in random_order]

    def train(self, train_flag):
        """Train for ``self.niter`` epochs on shuffled expression pairs.

        Each step runs five discriminator updates, then one generator update,
        all on the same batch. Every ``self.niter_snapshot`` steps a checkpoint
        is saved and sample grids are written.
        """
        # load data
        data = self._load_cached_paths(self.data_dir, pair=True)

        print('Shuffle ....')
        data = self._shuffled(data)
        print('Shuffle Done')

        # initial parameter
        start_time = time.time()
        self.count = 0

        for epoch in range(self.niter):
            batch_idxs = len(data) // self.batch_size
            lr = np.float32(self.learning_rate)

            # learning rate decay: linear over the second half of training
            if epoch >= self.niter / 2.0:
                lr_decay = (self.niter - epoch) / (self.niter / 2.0)
                lr = lr * lr_decay

            for idx in range(0, batch_idxs):
                self.count += 1

                # Flip the batch with 0.5 prob to increase training data (same flip for input and target)
                flip = random.uniform(0, 1) < 0.5

                batch_files = data[idx * self.batch_size: (idx + 1) * self.batch_size]
                batch_inputs = [get_image(batch_file[0], flip) for batch_file in batch_files]
                batch_target = [get_image(batch_file[1], flip) for batch_file in batch_files]
                batch_inputs_labels = [get_label(batch_file[0], self.data_label_vector_size) for batch_file in batch_files]
                batch_target_labels = [get_label(batch_file[1], self.data_label_vector_size) for batch_file in batch_files]
                batch_alpha = np.random.uniform(low=0., high=1.0, size=[self.batch_size, 1]).astype(np.float32)

                # feed list
                D_fetches = [self.opt_D, self.L_D, self.L_D_cls]
                G_fetches = [self.opt_G, self.L_G, self.L_G_cls, self.L_G_recon, self.wd]
                feed_dict = {self.x: batch_inputs, self.x_c: batch_inputs_labels,
                             self.target: batch_target, self.target_c: batch_target_labels,
                             self.alpha: batch_alpha,
                             self.lr: lr}

                # run tensorflow: 5 critic steps per generator step
                for i in range(5):
                    _, d_loss, d_loss_cls = self.sess.run(D_fetches, feed_dict=feed_dict)

                _, g_loss, g_loss_cls, g_loss_recon, wd = self.sess.run(G_fetches, feed_dict=feed_dict)

                if self.count % 100 == 1:
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, "
                          "d_loss: %.6f, d_loss_cls: %.6f, g_loss: %.6f, g_loss_cls: %.6f, g_loss_recon: %.6f, "
                          "wd: %.6f"
                          % (epoch, idx, batch_idxs, time.time() - start_time,
                             d_loss, d_loss_cls, g_loss, g_loss_cls, g_loss_recon, wd))

                # write train summary
                summary = self.sess.run(self.merged, feed_dict=feed_dict)
                self.writer.add_summary(summary, self.count)

                # Test during Training
                if self.count % self.niter_snapshot == (self.niter_snapshot - 1):
                    # save & test
                    self.saver.save(self.sess, self.ckpt_model_name, global_step=self.count, write_meta_graph=False)
                    self.test_expr(train_flag)
                    self.test_celebra(train_flag)

    def test(self, train_flag=False):
        """Write both sample grids using the current (restored) weights, without training.

        The driver script calls this when ``train_flag`` is false. Output
        filenames use ``self.count``, which only ``train`` sets, so it
        defaults to 0 here.
        """
        if not hasattr(self, 'count'):
            self.count = 0
        self.test_expr(train_flag)
        self.test_celebra(train_flag)

    def test_celebra(self, train_flag=True):
        """Write a grid of CelebA translations to ``result/<count>_celebra_output.bmp``.

        Each row shows one random CelebA image in column 0, followed by its
        translation to every expression label in the remaining columns.
        """
        print('Test Sample Generation...')
        # generate output
        img_num = 8*8
        output_f = int(np.sqrt(img_num))
        in_img_num = output_f
        img_size = self.data_size
        gen_img_num = img_num - output_f
        label_size = self.data_label_vector_size

        # load + shuffle test data
        data_path = "./Face_data/Celeba/dataset_128x128"
        data = self._shuffled(self._load_cached_paths(data_path))

        im_output_gen = np.zeros([img_size * output_f, img_size * output_f, 3])

        test_files = data[0: output_f]
        test_data = [get_image(test_file) for test_file in test_files]
        # each input image repeated once per label, in label order
        test_data = np.repeat(test_data, [label_size]*in_img_num, axis=0)
        test_data_o = [scm.imread(test_file) for test_file in test_files]

        # one-hot labels 0..label_size-1, tiled once per input image
        target_labels = np.tile(np.eye(label_size), (output_f, 1))

        output_gen = (self.sess.run(self.G_test, feed_dict={self.x: test_data,
                                                            self.target_c: target_labels}))  # generator output

        output_gen = [inverse_image(output_gen[i]) for i in range(gen_img_num)]

        for i in range(output_f):
            for j in range(output_f):
                if j == 0:
                    tile = test_data_o[i]
                else:
                    tile = output_gen[(j-1) + (i * int(output_f-1))]
                im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] = tile

        # output save
        scm.imsave(self.project_dir + '/result/' + str(self.count) + '_celebra_output.bmp', im_output_gen)

    def test_expr(self, train_flag=True):
        """Write a 6x6 grid of (input, generated, target) triplets for 12 random expression pairs.

        The file is saved under ``result/`` and its name includes the
        concatenated target label indices.
        """
        print('Train Sample Generation...')
        # generate output
        img_num =  36 #self.batch_size
        display_img_num = int(img_num / 3)
        img_size = self.data_size

        output_f = int(np.sqrt(img_num))
        im_output_gen = np.zeros([img_size * output_f, img_size * output_f, 3])

        # load + shuffle data
        data = self._shuffled(self._load_cached_paths(self.data_dir, pair=True))

        batch_files = data[0: display_img_num]
        test_inputs = [get_image(batch_file[0]) for batch_file in batch_files]
        test_inputs_o = [scm.imread((batch_file[0])) for batch_file in batch_files]
        test_targets = [scm.imread((batch_file[1])) for batch_file in batch_files]
        test_target_labels = [get_label(batch_file[1], self.data_label_vector_size) for batch_file in batch_files]

        output_gen = (self.sess.run(self.G_test, feed_dict={self.x: test_inputs,
                                                            self.target_c: test_target_labels}))  # generator output

        output_gen = [inverse_image(output_gen[i]) for i in range(display_img_num)]

        triplets_per_row = int(output_f / 3)
        for i in range(output_f): # row
            for j in range(output_f): # col
                k = int(j / 3) + (i * triplets_per_row)
                if j % 3 == 0: # input img
                    tile = test_inputs_o[k]
                elif j % 3 == 1: # output img
                    tile = output_gen[k]
                else: # target img
                    tile = test_targets[k]
                im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] = tile

        labels = np.argmax(test_target_labels, axis=1)
        label_string = ''.join(str(int(l)) for l in labels)
        # output save
        scm.imsave(self.project_dir + '/result/' + str(self.count) + '_' + label_string
                   + '_expr_output.bmp', im_output_gen)
        

In [5]:

class StarGAN(Operator):
    """StarGAN generator and discriminator architectures (arXiv:1711.09020)."""

    def __init__(self, sess, project_name):
        Operator.__init__(self, sess, project_name)

    @staticmethod
    def _residual_block(x, depth, padding, prefix):
        """Residual block: ``relu(x + IN(conv(relu(IN(conv(x))))))`` with scopes ``<prefix>a``/``<prefix>b``."""
        x_r = conv(x, [3, 3, depth, depth], stride=1, padding=padding, name=prefix + 'a')
        x_r = instance_norm(x_r, 'in_' + prefix + 'a')
        x_r = tf.nn.relu(x_r)
        x_r = conv(x_r, [3, 3, depth, depth], stride=1, padding=padding, name=prefix + 'b')
        x_r = instance_norm(x_r, 'in_' + prefix + 'b')
        x = x + x_r
        return tf.nn.relu(x)

    def generator(self, x, c, reuse=None):
        """Translate image ``x`` to the domain given by label vector ``c``.

        The label is tiled over every spatial position and concatenated to
        ``x`` as extra channels. The result then passes through 2 stride-2
        down-sampling convs, 6 residual blocks and 2 stride-2 transposed
        convs, and ends in a tanh. Pass ``reuse=True`` to share the weights
        of a previous call.
        """
        with tf.variable_scope('generator') as scope:
            if reuse:
                scope.reuse_variables()

            f = 64
            c_num = self.data_label_vector_size
            p = "SAME"

            x = tf.concat([x, tf.tile(tf.reshape(c, [-1, 1, 1, get_shape_c(c)[-1]]),\
                                      [1, x.get_shape().as_list()[1], x.get_shape().as_list()[2], 1])],\
                          axis=3)

            # Down-sampling
            x = conv(x, [7, 7, 3+c_num, f], stride=1, padding=p, name='ds_1')
            x = instance_norm(x, 'in_ds_1')
            x = tf.nn.relu(x)
            x = conv(x, [4, 4, f, f*2], stride=2, padding=p, name='ds_2')
            x = instance_norm(x, 'in_ds_2')
            x = tf.nn.relu(x)
            x = conv(x, [4, 4, f*2, f*4], stride=2, padding=p, name='ds_3')
            x = instance_norm(x, 'in_ds_3')
            x = tf.nn.relu(x)

            # Bottleneck: six residual blocks; scope names bneck_1a..bneck_6b are kept
            # so existing checkpoints still restore.
            for i in range(1, 7):
                x = self._residual_block(x, f*4, p, 'bneck_{0}'.format(i))

            # Up-sampling
            x = deconv(x, [4, 4, f*4, f*2], stride=2, padding=p, name='us_1')
            x = instance_norm(x, 'in_us_1')
            x = tf.nn.relu(x)
            x = deconv(x, [4, 4, f*2, f], stride=2, padding=p, name='us_2')
            x = instance_norm(x, 'in_us_2')
            x = tf.nn.relu(x)
            x = conv(x, [7, 7, f, 3], stride=1, padding=p, name='us_3')

            x = tf.nn.tanh(x)

        return x

    def discriminator(self, x, reuse=None):
        """PatchGAN-style critic with an auxiliary label classifier.

        Returns:
            ``(out_src, out_cls)``. ``out_src`` holds per-patch critic scores.
            ``out_cls`` holds the class logits, shaped ``(batch, c_num)``.
        """
        with tf.variable_scope('discriminator') as scope:
            if reuse:
                scope.reuse_variables()

            f = 64
            f_max = f*16  # channels produced by conv_5
            image_size = self.data_size
            k_size = int(image_size / np.power(2, 5))  # spatial size after five stride-2 convs
            c_num = self.data_label_vector_size
            p = "SAME"

            x = conv(x, [4, 4, 3, f], stride=2, padding=p, name='conv_1')
            x = lrelu(x)
            x = conv(x, [4, 4, f, f*2], stride=2, padding=p, name='conv_2')
            x = lrelu(x)
            x = conv(x, [4, 4, f*2, f*4], stride=2, padding=p, name='conv_3')
            x = lrelu(x)
            x = conv(x, [4, 4, f*4, f*8], stride=2, padding=p, name='conv_4')
            x = lrelu(x)
            x = conv(x, [4, 4, f*8, f*16], stride=2, padding=p, name='conv_5')
            x = lrelu(x)

            if image_size == 128:
                x = conv(x, [4, 4, f*16, f*32], stride=2, padding=p, name='conv_6')
                x = lrelu(x)
                f_max = f_max * 2  # f*32, the conv_6 output depth
                k_size = int(k_size / 2)

            out_src = conv(x, [3, 3, f_max, 1], stride=1, padding=p, name='conv_out_src')
            # VALID conv with a kernel covering the whole remaining map -> 1x1 spatial output
            out_cls = conv(x, [k_size, k_size, f_max, c_num], stride=1, name='conv_out_cls')

        # Squeeze only the spatial axes: a bare squeeze also dropped the batch axis
        # when batch size was 1, breaking the (batch, c_num) label shape.
        return out_src, tf.squeeze(out_cls, axis=[1, 2])


In [6]:
import distutils.util
import os
import tensorflow as tf

''' config settings '''

project_name = "StarGAN_Face_1_"
train_flag = True

'''-----------------'''

gpu_number = "0"
# Expose only the selected physical GPU. Inside this process it is then
# always visible as /gpu:0, whatever gpu_number is.
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_number #args.gpu_number

with tf.device('/gpu:0'):
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.90)
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        model = StarGAN(sess, project_name)

        # TRAIN / TEST
        if train_flag:
            model.train(train_flag)
        else:
            model.test(train_flag)

Shuffle ....
Shuffle Done
Epoch: [ 0] [   0/1813] time: 4.2972, d_loss: 5.693913, d_loss_cls: 0.638042, g_loss: 9.709621, g_loss_cls: 0.580023, g_loss_recon: 0.655501, wd: 3.102028
Epoch: [ 0] [ 100/1813] time: 316.2945, d_loss: -10.553990, d_loss_cls: 0.155376, g_loss: 1.918552, g_loss_cls: 0.529656, g_loss_recon: 0.236782, wd: 13.315695
Epoch: [ 0] [ 200/1813] time: 632.4336, d_loss: -4.260210, d_loss_cls: 0.390434, g_loss: -8.949572, g_loss_cls: 0.673184, g_loss_recon: 0.175317, wd: 5.793550
Epoch: [ 0] [ 300/1813] time: 949.1301, d_loss: -3.972862, d_loss_cls: 0.124830, g_loss: -7.589349, g_loss_cls: 0.718749, g_loss_recon: 0.158207, wd: 5.206684
Epoch: [ 0] [ 400/1813] time: 1266.0303, d_loss: -4.077964, d_loss_cls: 0.187572, g_loss: 16.166456, g_loss_cls: 0.612723, g_loss_recon: 0.152263, wd: 5.190974
Train Sample Generation...
Test Sample Generation...
Epoch: [ 0] [ 500/1813] time: 1584.8919, d_loss: -7.848335, d_loss_cls: 0.129742, g_loss: 4.666964, g_loss_cls: 1.160315, g_loss

KeyboardInterrupt: 