In [None]:
from PIL import Image
import numpy as np
import pandas as pd
import glob
from random import randint, shuffle
import matplotlib.pyplot as plt
import datetime
import os
import tensorflow as tf
from functools import partial

#### Keras APIs
from keras.models import Sequential, Model,load_model
from keras.layers.merge import _Merge
from keras.layers import Layer,InputLayer, Input,Reshape, Conv2D, Conv2DTranspose,Embedding, CuDNNGRU,Bidirectional,\
Dense, Flatten,BatchNormalization, Activation, ZeroPadding2D, LeakyReLU, UpSampling2D,MaxPooling2D,Dropout,Concatenate,\
Lambda
from keras import layers
### pip install git+https://www.github.com/keras-team/keras-contrib.git
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization, InputSpec
from keras.optimizers import Adam,RMSprop,Adadelta,SGD
import keras.backend as K
from keras.preprocessing import text, sequence
from keras.initializers import RandomNormal

In [None]:
class DataLoader():
    """Loads CelebA face images together with their binary attribute labels.

    The attribute table is read from ``list_attr_celeba.txt`` in the current
    working directory (the first line of that file is skipped, the image file
    name becomes the row index) and every ``-1`` is mapped to ``0``.

    Attributes:
        img_res: target size handed to ``PIL.Image.resize``.
        dataset_path: folder containing the image files.
        df_labels: label DataFrame indexed by image file name.
        path: list of every entry found directly inside ``dataset_path``.
        num_batches / total_samples: set by ``load_batch`` when iteration starts.
    """

    def __init__(self,img_res,dataset_path):
        self.img_res = img_res
        self.dataset_path = dataset_path
        # sep=r'\s+' is the non-deprecated spelling of delim_whitespace=True.
        labels = pd.read_table('list_attr_celeba.txt', skiprows=1, sep=r'\s+')
        self.df_labels = labels.replace(-1, 0)
        # os.path.join keeps the pattern valid on every platform
        # (the former hard-coded '\\' only matched on Windows).
        self.path = glob.glob(os.path.join(self.dataset_path, '*'))

    # cut the path E:\machine_learning_image_data\CelebA\Img\img_align_celeba\000001.jpg to 000001.jpg
    def cut(self,path_list):
        """Return the file-name part of a single path ('/' and '\\' both count as separators)."""
        return path_list.replace('\\', '/').rsplit('/', 1)[-1]

    def _load_image(self, img_path):
        """Read one image and return it as an array scaled to [-1, 1]."""
        # The context manager releases the file handle; convert() returns an
        # independent in-memory copy, so it stays usable afterwards.
        with Image.open(img_path) as handle:
            img = handle.convert('RGB')
        # crop the 178x218 image to 178x178 (middle part), then resize.
        img = img.crop((0,20,178,198))
        img = img.resize(self.img_res,Image.BILINEAR)
        pixels = np.array(img)
        # 50% chance to mirror the image horizontally as data augmentation
        if np.random.random()>0.5:
            pixels = np.fliplr(pixels)
        # convert 0..255 to the -1..+1 domain
        return (pixels/255 -0.5) *2

    def _load_images(self, paths):
        """Load every readable image in ``paths``.

        Returns:
            (images, names): the stacked float64 image array and the file
            names of exactly those images, in the same order. Unreadable
            files are reported and left out of BOTH, so labels looked up from
            ``names`` always line up with ``images``.
        """
        loaded, names = [], []
        for img_path in paths:
            try:
                pixels = self._load_image(img_path)
            except Exception as e:
                # best effort: skip the file, but say so instead of hiding it
                print('skipping unreadable image %s: %s' % (img_path, e))
                continue
            loaded.append(pixels)
            names.append(self.cut(img_path))
        # np.float was removed from NumPy; float64 is what it used to alias.
        return np.array(loaded).astype(np.float64), names

    # load 1 batch of images, not an iterator
    def load_data(self,batch_size=1):
        """Draw ``batch_size`` random images (with replacement).

        Returns:
            (imgs, origin_label): image array and the matching label rows.
        """
        batch_images = np.random.choice(self.path,size=batch_size)
        imgs, img_batch_names = self._load_images(batch_images)
        origin_label = self.df_labels.loc[img_batch_names].values
        return imgs,origin_label

    # load batches of images and iterate
    def load_batch(self, batch_size = 16):
        """Yield ``(imgs_batch, origin_label, target_label)`` for one epoch.

        Every image path is used at most once per epoch. ``target_label`` is
        a random permutation of the batch's own label rows.
        """
        # the number of full batches in 1 epoch
        self.num_batches = int(len(self.path) / batch_size)
        self.total_samples = self.num_batches * batch_size
        # replace=False so that no image is drawn twice within an epoch
        path_all_images = np.random.choice(self.path,self.total_samples,replace = False)

        # range(num_batches): the previous 'num_batches - 1' never yielded the last full batch
        for i in range(self.num_batches):
            path_batch = path_all_images[i * batch_size: (i+1) * batch_size]
            imgs_batch, img_batch_names = self._load_images(path_batch)

            # original labels come straight from the label table
            origin_label = self.df_labels.loc[img_batch_names].values

            # target labels: the same rows in shuffled order
            target_names = np.random.permutation(img_batch_names)
            target_label = self.df_labels.loc[target_names].values

            yield imgs_batch, origin_label,target_label

In [None]:
# Layer used for WGAN-GP: computes a randomly weighted average of two image batches (image_real and image_fake)
class RandomWeightedAverage(_Merge):
    """Provides a (random) weighted average between real and generated image samples.

    Called with a list of two tensors ``[real, generated]``. One uniform random
    weight ``alpha`` is drawn per sample and the layer returns
    ``alpha * real + (1 - alpha) * generated``.
    """

    def _merge_function(self,inputs):
        # Take the batch dimension from the incoming tensor at run time rather
        # than hard-coding 16, so the layer works for any training batch size.
        n_samples = K.shape(inputs[0])[0]
        # shape (batch, 1, 1, 1): one weight per sample, broadcast over H, W, C
        alpha = K.random_uniform((n_samples, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

In [None]:
class StarGAN():
    
    def __init__(self,continue_training=True,linear_decay=False,dataset_path=None):
        """Build the discriminator, the generator and the two combined training models.

        Args:
            continue_training: when true, previously saved weights are loaded
                from the ``models`` folder into D and G right after each
                network is built.
            linear_decay: stored on the instance as ``self.linear_decay``.
            dataset_path: folder with the CelebA images, passed to
                ``DataLoader``. ``None`` (the default) keeps the location
                that used to be hard-coded here.
        """
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.gen_f =64
        self.disc_f =64
        self.linear_decay = linear_decay

        # Calculate output shape of D (PatchGAN)
        patch = int(self.img_rows / 2**6 )
        self.disc_patch = (patch, patch, 1)

        self.learning_rate_initial = 0.0001
        self.learning_rate = self.learning_rate_initial

        optimizer_gen = Adam(self.learning_rate,0.5,0.999)
        optimizer_disc = Adam(self.learning_rate,0.5,0.999)

        if dataset_path is None:
            # historical default, kept so existing callers behave as before
            dataset_path='E:\\machine_learning_image_data\\CelebA\\Img\\img_align_celeba'
        # number of domains (attribute labels)
        self.label_dim = 40

        ### Build and Compile Discriminator model

        self.D = self.build_discriminator()

        if continue_training:
            self.D.load_weights("models\\stargan_discriminator_weights-v1.h5")

        img_real = Input(shape = self.img_shape)
        img_generated = Input(shape = self.img_shape)

        # D returns two outputs per image: the source (real/fake) map and the label logits
        valid,valid_origin_label = self.D(img_real)
        fake, _ = self.D(img_generated)

        # Construct weighted average between real and fake images
        random_weighted_average = RandomWeightedAverage(batch_size=16)
        img_interpolated = random_weighted_average([img_real,img_generated])

        # Determine validity of weighted sample
        validity_interpolated, _ = self.D(img_interpolated)

        # Use Python partial to provide loss function with additional
        # 'averaged_samples' argument
        partial_gp_loss = partial(self.gradient_penalty_loss,
                          averaged_samples=img_interpolated)
        partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names

        # Loss weights, in output order: real score 1, fake score 1,
        # gradient penalty 10, label classification 5.
        self.D_model = Model([img_real,img_generated], [valid,fake,validity_interpolated,valid_origin_label])
        self.D_model.compile(loss=[self.wasserstein_loss,self.wasserstein_loss,partial_gp_loss,self.classification_loss],\
                                        loss_weights = [1,1,10,5],\
                                        optimizer = optimizer_disc)

        ### Build and Compile Generator model

        self.G = self.build_generator()

        if continue_training:
            self.G.load_weights("models\\stargan_generator_weights-v1.h5")

        origin_label = Input(shape=(self.label_dim,))
        target_label = Input(shape=(self.label_dim,))

        # translate to the target labels, then translate back with the original labels
        img_generated_G = self.G([img_real,target_label])
        img_reconstruct = self.G([img_generated_G,origin_label])

        valid_G,valid_target_label = self.D(img_generated_G)

        # Flags are set before G_model is compiled; D_model was compiled above
        # while D was still trainable.
        self.D.trainable = False
        self.G.trainable = True

        # Loss weights, in output order: adversarial 1, classification 5, reconstruction (MAE) 10.
        self.G_model = Model([img_real,target_label,origin_label], [valid_G,valid_target_label,img_reconstruct])
        self.G_model.compile(loss = [self.wasserstein_loss,self.classification_loss,'mae'],\
                                       loss_weights =[1,5,10],optimizer = optimizer_gen)

        self.data_loader = DataLoader(img_res=(self.img_rows,self.img_cols), dataset_path = dataset_path)

    
    def classification_loss(self,y,y_pred):
        """Mean sigmoid cross-entropy, treating ``y_pred`` as raw logits and ``y`` as the labels."""
        per_element = tf.nn.sigmoid_cross_entropy_with_logits(logits=y_pred, labels=y)
        return tf.reduce_mean(per_element)
        
    def wasserstein_loss(self,y,y_pred):
        """Mean of the element-wise product of target ``y`` and prediction ``y_pred``."""
        signed_score = y * y_pred
        return K.mean(signed_score)
        
    def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        """
        Gradient penalty from the prediction on the weighted real / fake samples.

        ``y_true`` is not used. The result is the batch mean of
        ``(1 - ||grad||_2)^2`` where ``grad`` is the gradient of ``y_pred``
        with respect to ``averaged_samples``; no extra scaling is applied here.
        """
        grads = K.gradients(y_pred, averaged_samples)[0]
        # every axis except the leading (sample) axis
        non_batch_axes = np.arange(1, len(grads.shape))
        # per-sample euclidean norm: square, sum over the non-batch axes, sqrt
        l2_norm = K.sqrt(K.sum(K.square(grads), axis=non_batch_axes))
        # squared distance of each sample's norm from 1
        penalty_per_sample = K.square(1 - l2_norm)
        # averaged over the batch
        return K.mean(penalty_per_sample)
        
    def build_generator(self):
        """Build the generator: (image, label vector) -> translated image.

        The label vector is replicated over every pixel and concatenated to
        the image channels, then passed through three convolution blocks,
        six residual blocks, two transposed-convolution blocks and a final
        7x7 convolution with instance normalization and tanh.

        Returns:
            A Keras ``Model`` with inputs ``[input_img, input_label]`` and a
            single image output. The model summary is printed as a side effect.
        """

        def depth_wise_concatenate(input_img, input_label):
            # Replicate spatially and concatenate domain information.
            # One copy of the label per pixel: rows * cols (the former
            # img_rows**2 only matched the reshape for square images).
            n_pixels = self.img_rows * self.img_cols
            label = Lambda(lambda x: K.repeat(x, n_pixels))(input_label)
            label = Reshape((self.img_rows, self.img_cols, self.label_dim))(label)
            x = Concatenate()([input_img, label])
            return x


        def conv2d(layer_input, filters, f_size=4,strides=2,padding = 'valid'):
            # zero-pad by 1, convolve, instance-normalize, ReLU
            init = RandomNormal(stddev=0.02)
            d = ZeroPadding2D((1,1))(layer_input)
            d = Conv2D(filters, kernel_size=f_size, strides=strides, padding=padding,kernel_initializer=init)(d)
            d = InstanceNormalization(axis=-1)(d)
            d = Activation('relu')(d)
            return d

        def residual(layer_input,filters):
            # two padded 3x3 convolutions with a skip connection around them
            init = RandomNormal(stddev=0.02)
            # first layer

            x = ZeroPadding2D((1,1))(layer_input)
            x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='valid',kernel_initializer=init)(x)
            x = InstanceNormalization(axis=-1)(x)
            x = Activation('relu')(x)
            # second layer

            x = ZeroPadding2D((1,1))(x)
            x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='valid',kernel_initializer=init)(x)
            x = InstanceNormalization(axis=-1)(x)
            # merge
            x = layers.add([x, layer_input])
            return x

        def deconv2d(layer_input, filters, f_size=4,padding = 'same'):
            # stride-2 transposed convolution, instance-normalize, ReLU
            init = RandomNormal(stddev=0.02)

            u = Conv2DTranspose(filters, kernel_size=f_size, strides=2, padding=padding,kernel_initializer=init)(layer_input)
            u = InstanceNormalization(axis=-1)(u)
            u = Activation('relu')(u)
            return u

        input_img = Input(shape=self.img_shape)
        input_label = Input(shape=(self.label_dim,))

        model = depth_wise_concatenate(input_img, input_label)

        # NOTE(review): conv2d pads by 1 again on top of this (3,3) padding —
        # confirm the extra border is intended before changing it, since the
        # saved weights were trained with this layout.
        model = ZeroPadding2D((3,3))(model)
        model = conv2d(model,self.gen_f,f_size=7,strides=1)
        model = conv2d(model,self.gen_f*2)
        model = conv2d(model,self.gen_f*4)

        #6 residual blocks
        for _ in range(6):
            model = residual(model,self.gen_f*4)

        model = deconv2d(model,self.gen_f*2)
        model = deconv2d(model,self.gen_f)

        model = ZeroPadding2D((3,3))(model)
        model = Conv2D(filters = 3,kernel_size=7,strides=1,padding='valid')(model)

        model =InstanceNormalization(axis=-1)(model)

        output_img = Activation('tanh')(model)
        output_model = Model(inputs=[input_img,input_label],outputs=output_img)
        output_model.summary()

        return output_model
        
    def build_discriminator(self):
        """Build the discriminator: image -> [source map, attribute logits].

        Six stride-2 convolution blocks (filter counts ``disc_f`` x 1, 2, 4,
        8, 16, 32) feed two heads: a padded 3x3 convolution with one filter
        (D_src) and a convolution with ``label_dim`` filters that is reshaped
        to a flat vector (D_cls). The model summary is printed as a side effect.
        """

        def disc_layer(layer_input, filters, f_size=4, strides=2,normalization=False):
            # zero-pad by 1, convolve, optionally instance-normalize, LeakyReLU(0.01)
            weight_init = RandomNormal(stddev=0.02)
            padded = ZeroPadding2D((1,1))(layer_input)
            features = Conv2D(filters, kernel_size=f_size, strides=strides, padding='valid',kernel_initializer=weight_init)(padded)
            if normalization:
                features = InstanceNormalization(axis=-1)(features)
            return LeakyReLU(alpha=0.01)(features)

        input_img = Input(shape=self.img_shape)

        # shared trunk; none of the blocks uses normalization
        trunk = input_img
        for multiplier in (1, 2, 4, 8, 16, 32):
            trunk = disc_layer(trunk, self.disc_f*multiplier)

        # D_src (2,2,1)
        src_padded = ZeroPadding2D((1,1))(trunk)
        output_d_src = Conv2D(1, kernel_size=3, strides=1, padding='valid')(src_padded)

        # D_cls (1,1,self.label_dim) -> (self.label_dim,)
        cls_kernel = int(self.img_rows/64)
        cls_map = Conv2D(filters = self.label_dim,kernel_size = cls_kernel, strides = 1,padding = 'valid')(trunk)
        output_d_cls = Reshape((self.label_dim,))(cls_map)

        output_model = Model(input_img,[output_d_src,output_d_cls])
        output_model.summary()

        return output_model
    
    def train(self,epochs,batch_size=16,sample_interval=500,resume_epoch=0,resume_batch=0):
        """Alternate one discriminator step and one generator step per batch.

        Args:
            epochs: epoch index at which training stops (exclusive).
            batch_size: number of images per step.
            sample_interval: a sample image grid is written whenever the batch
                counter is a multiple of this value.
            resume_epoch: first epoch index to run.
            resume_batch: starting value of the batch counter in the first
                epoch that is run; later epochs start counting from 0.

        Side effects: prints progress for every batch, writes sample images,
        and saves D and G weights into ``models`` whenever the batch counter
        is a multiple of 2000.
        """
        start_time = datetime.datetime.now()

        # Adversarial loss ground truths # (-1,2,2,1)
        valid = - np.ones((batch_size,) + self.disc_patch)
        fake = np.ones((batch_size,) + self.disc_patch)
        dummy = np.zeros((batch_size,) + self.disc_patch) # for gradient penalty of WGAN-GP

        for epoch in range(resume_epoch,epochs):

            if self.linear_decay:
                self.learning_rate = self.lr_linear_decay(epoch,epochs)
                # Write the new rate into both compiled optimizers; updating
                # self.learning_rate alone never reached them.
                K.set_value(self.D_model.optimizer.lr, self.learning_rate)
                K.set_value(self.G_model.optimizer.lr, self.learning_rate)

            for batch_i, (imgs, origin_label, target_label) in enumerate(self.data_loader.load_batch(batch_size),start=resume_batch):

                # When the counter starts at resume_batch it would otherwise
                # run past the number of batches in one epoch.
                if batch_i> self.data_loader.num_batches:
                    break

                # The ground-truth arrays above are sized for a full batch;
                # a batch that lost images while loading cannot be trained on.
                if len(imgs) != batch_size:
                    print("skipping incomplete batch %d (%d of %d images)" % (batch_i, len(imgs), batch_size))
                    continue

                # ----------------------
                #  Train Discriminator
                # ----------------------

                imgs_generated = self.G.predict([imgs,target_label])

                d_loss = self.D_model.train_on_batch(x = [imgs,imgs_generated], y = [valid,fake,dummy,origin_label])

                # ----------------------
                #  Train Generator
                # ----------------------

                # G_model produces the reconstruction itself, so no separate
                # G.predict call is needed here.
                g_loss = self.G_model.train_on_batch(x = [imgs,target_label,origin_label], y=[valid,target_label,imgs])

                elapsed_time = datetime.datetime.now() - start_time

                print ("[epoch %d/%d], [batch %d/%d], [Dloss: %f] [G loss: %05f] time: %s"\

                       % (epoch, epochs, batch_i, self.data_loader.num_batches,d_loss[0], g_loss[0], elapsed_time))
                print ("current learning rate:%05f" %self.learning_rate)

                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

                # save the model weights at every interval of 2000 batches
                if batch_i % 2000 == 0:

                    os.makedirs('models', exist_ok=True)

                    self.D.save_weights('models\\stargan_discriminator_weights-v1.h5')
                    self.G.save_weights('models\\stargan_generator_weights-v1.h5')

            # Only the first resumed epoch starts from resume_batch; reset it
            # however the inner loop ended.
            resume_batch = 0
    
    def sample_images(self,epoch,batch_i):
        """Save a 5 x 6 grid of translation samples to ``images/``.

        Each row uses one randomly loaded image: column 0 shows the input,
        columns 1-5 show the generator output for five target labels. Every
        target label starts as a copy of the image's own label with a few
        attributes switched (blond hair, eyeglasses, mustache, smiling, young).

        Args:
            epoch, batch_i: only used to build the output file name.
        """
        os.makedirs('images/',exist_ok = True)
        r, c = 5,6

        grid_rows = []
        for k in range(r):

            # load 1 image and its label row
            imgs,imgs_origin_label = self.data_loader.load_data(batch_size=1)

            # one copy of the original label per generated column: (1,40) -> (5,40)
            target_label = np.repeat(imgs_origin_label,c - 1,axis=0)

            # column indices follow the attribute order of the label table
            target_label[0,9] = 1   # Blond_Hair
            target_label[0,8] = 0   # ... without Black_Hair
            target_label[0,11] = 0  # ... without Brown_Hair (index 10 is Blurry)
            target_label[0,17] = 0  # ... without Gray_Hair
            target_label[1,15] = 1  # Eyeglasses
            target_label[2,22] = 1  # Mustache
            target_label[2,20] = 1  # ... with Male
            target_label[2,24] = 0  # ... without No_Beard
            target_label[3,31] = 1  # Smiling
            target_label[4,39] = 1  # Young
            target_label[4,2] = 1   # ... with Attractive
            target_label[4,3] = 0   # ... without Bags_Under_Eyes

            # 1 image generates 5 different faces from the 5 target labels
            imgs_gen = self.G.predict([np.repeat(imgs,c - 1,axis=0),target_label])
            # input first, then its 5 translations
            grid_rows.append(imgs)
            grid_rows.append(imgs_gen)

        # empty seed array keeps the result 4-D and typed like the old np.append chain
        imgs_generated_illustration = np.concatenate(
            [np.empty(shape=(0,) + self.img_shape)] + grid_rows)

        # Rescale images to 0 - 1 and clip so imshow gets values inside its valid range
        imgs_generated_illustration = 0.5 * imgs_generated_illustration + 0.5
        imgs_generated_illustration = np.clip(imgs_generated_illustration, 0.0, 1.0)

        label_title = ['Input','Blond_Hair','Eyeglasses','Mustache','Smiling','Young']

        fig, axs = plt.subplots(r,c)
        # 20 x 15 is about 1440 x 1080 pixels
        fig.set_size_inches(20, 15, forward=True)

        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(imgs_generated_illustration[cnt,:,:,:])
                if i == 0:
                    axs[i,j].set_title(label_title[j],fontsize = 25)

                axs[i,j].axis('off')
                cnt += 1

        fig.savefig(f"images/Sample_Epoch_No_{epoch}_Batch_No_{batch_i}.png")
        # close this specific figure so repeated sampling does not accumulate figures
        plt.close(fig)
    
    
    def image_debug(self,imgs,origin_label,target_label):
        """Show the first input image and then its generated translation.

        ``origin_label`` is accepted but not used by the code below.
        """
        translated = self.G.predict([imgs,target_label])

        # map both batches from [-1, 1] back to [0, 1] before displaying
        rescaled_inputs = 0.5 * imgs + 0.5
        rescaled_outputs = 0.5 * translated + 0.5

        for picture in (rescaled_inputs[0], rescaled_outputs[0]):
            plt.imshow(picture)
            plt.show()
            plt.close()

    # Generate image by personal image
    def generate_image(self):
        """Translate the ``*.jpeg`` photos in ``samples/`` and save a comparison grid.

        Up to five photos are used, in sorted file-name order. Each row of the
        grid shows one input photo followed by its translations for five
        fixed target labels (blond hair, eyeglasses, mustache + male, smiling,
        young + attractive). The grid is written to
        ``samples/generated_image.png``. When the folder holds no ``*.jpeg``
        file, a message is printed and nothing is generated.
        """
        print('sample generation started...')
        imgs_folder = 'samples'
        # make dir named imgs_folder
        os.makedirs(imgs_folder, exist_ok=True)

        def read_image(img_path):
            # NOTE(review): the helpers this method used to call (self.load_data
            # with a glob pattern, self.read_image) do not exist on this class.
            # This replacement resizes straight to the network input size and
            # scales to [-1, 1]; confirm whether a crop is wanted as well.
            with Image.open(img_path) as handle:
                img = handle.convert('RGB')
            # PIL expects (width, height)
            img = img.resize((self.img_cols, self.img_rows), Image.BILINEAR)
            return (np.array(img) / 255 - 0.5) * 2

        imgs_names = sorted(glob.glob('{}/*.jpeg'.format(imgs_folder)))

        c = 6
        # one grid row per available photo, at most 5
        r = min(5, len(imgs_names))
        if r == 0:
            print('no *.jpeg images found in {}/ - nothing to generate'.format(imgs_folder))
            return

        # read images into np array
        imgs = np.array([read_image(name) for name in imgs_names[:r]])

        # One target label per generated column; every attribute is 0 except
        # the listed indices (attribute order of the label table):
        # 9 Blond_Hair | 15 Eyeglasses | 20 Male + 22 Mustache | 31 Smiling | 2 Attractive + 39 Young
        target_label = np.zeros((c - 1, self.label_dim))
        switched_on = [[9], [15], [20, 22], [31], [2, 39]]
        for row, attribute_indices in enumerate(switched_on):
            target_label[row, attribute_indices] = 1

        # first column is the original photo, other columns are generated photos
        imgs_generated_illustration = np.empty(shape=(0,) + self.img_shape)

        for k in range(r):
            # keep the batch axis: (1, H, W, C)
            source = imgs[k:k + 1]
            # the same photo once per target label
            imgs_gen = self.G.predict([np.repeat(source, c - 1, axis=0),target_label])
            # 6 images: input + 5 translations
            imgs_concat = np.concatenate([source,imgs_gen])
            imgs_generated_illustration =np.append(imgs_generated_illustration,imgs_concat,axis=0)

        # Rescale images 0 - 1
        imgs_generated_illustration = 0.5 * imgs_generated_illustration + 0.5
        imgs_generated_illustration = np.clip(imgs_generated_illustration, 0.0, 1.0)

        label_title = ['Input','Blond_Hair','Eyeglasses','Mustache','Smiling','Young']

        # squeeze=False keeps axs two-dimensional even when there is a single row
        fig, axs = plt.subplots(r,c,squeeze=False)
        # 20 x 15 is about 1440 x 1080 pixels
        fig.set_size_inches(20, 15, forward=True)

        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(imgs_generated_illustration[cnt,:,:,:])
                if i == 0:
                    axs[i,j].set_title(label_title[j],fontsize = 25)

                axs[i,j].axis('off')
                cnt += 1

        fig.savefig("samples/generated_image.png")
        plt.close(fig)
        print('...sample generation finished')





    # Calculate the learning rate from the current epoch number: it falls linearly from the initial rate at epoch 0 towards 0 at the final epoch.
    def lr_linear_decay(self,epoch,epochs):
        
        lr = self.learning_rate_initial
        
        remaining_epoch = epochs - epoch
#         remaining_batch = batches - batch_i

        lr = self.learning_rate_initial * (remaining_epoch /epochs ) 
        
        return lr
        
        
    

In [None]:
if __name__ == '__main__':
    # continue_training=True loads the saved weights first; False trains from scratch
    stargan = StarGAN(linear_decay=True, continue_training=True)
    # resume_epoch / resume_batch say where an interrupted run picks up again
    stargan.train(11, resume_batch=8000, resume_epoch=6)

In [None]:
#4,10000

In [None]:
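# The 40 attribute labels in column order (read row by row; the code indexes them from 0, so 5_o_Clock_Shadow is 0 and Young is 39)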

# 5_o_Clock_Shadow  Arched_Eyebrows Attractive      Bags_Under_Eyes    Bald
# Bangs             Big_Lips        Big_Nose        Black_Hair         Blond_Hair
# Blurry            Brown_Hair      Bushy_Eyebrows  Chubby             Double_Chin
# Eyeglasses        Goatee          Gray_Hair       Heavy_Makeup       High_Cheekbones
# Male              Mouth_Slightly_Open Mustache    Narrow_Eyes        No_Beard 
# Oval_Face         Pale_Skin       Pointy_Nose     Receding_Hairline  Rosy_Cheeks
# Sideburns         Smiling         Straight_Hair   Wavy_Hair          Wearing_Earrings
# Wearing_Hat       Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young 

In [None]:
# df_labels.columns.get_loc('Young')

In [None]:
# [0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 0 0 1] blond
# [0 1 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0 1 1 0 0 0 1 1 0 1] blond

# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 1 0 0 1 1 1 1 0 0 0 1 0 1 0 0 0 0 0 0] eyeglass

# [1 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0] mustache

#[0 1 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 0 0 0 0 1 1 0 1 0 0 0 1 0 0 1 0 1 0 0 1] smile

# [0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1] young

In [None]:
#blondy: 476  [0 1 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 0 0 1 0 1 0 0 1 0 0 1]
#glass: 676   [0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 1 0 0 0 1 0 1 0 0 1 0 0 0 0 0 0 1 0 1 0 0 1 1 0 0]
#mustache:689 [1 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 1 1 0 0 0 1 0 0 0 0 1 1 0 1 0 0 0 0 0 1]
#smile:718    [0 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 1 0 0 0 0 1 0 1 0 0 1 0 1 1 0 1]
# young :459  [0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]