In [None]:
# Much of the credit for this code goes to:
# https://github.com/eriklindernoren/Keras-GAN/blob/master/cyclegan/cyclegan.py
# https://github.com/tjwei/GANotebooks/blob/master/CycleGAN-keras.ipynb

# TensorFlow version: 1.14. Versions >= 2.0 cannot be used with the code below
# because of compatibility issues (e.g. InstanceNormalization, InputSpec).
# If you insist on using TensorFlow >= 2.0, the InstanceNormalization and ReflectionPadding2D code must be modified.

In [None]:
from PIL import Image
import numpy as np
import pandas as pd
import glob
from random import randint, shuffle
import matplotlib.pyplot as plt
import datetime
import os
import tensorflow as tf


#### Keras APIs
from keras.models import Model,load_model
from keras.layers import Layer,InputLayer, Input,Reshape, Conv2D, Conv2DTranspose,\
Dense, Flatten,BatchNormalization, Activation, ZeroPadding2D, LeakyReLU, UpSampling2D,MaxPooling2D,Dropout,Concatenate
from keras import layers
### pip install git+https://www.github.com/keras-team/keras-contrib.git
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization, InputSpec
from keras.optimizers import Adam,RMSprop,Adadelta,SGD
import keras.backend as K
from keras.initializers import RandomNormal

In [None]:
class CycleGAN():
    """CycleGAN for unpaired image-to-image translation, built with Keras.

    Domain X is the "A" image set and domain Y the "B" image set of ``dataset_name``.
    ``g_G`` translates X -> Y and ``g_F`` translates Y -> X; ``d_Y`` judges images of
    domain Y and ``d_X`` images of domain X (PatchGAN discriminators trained with MSE).

    Each generator is trained through its own combined model
    (``generator_training_model_F`` / ``generator_training_model_G``) whose three losses
    are: adversarial (MSE), identity mapping (MAE, weight ``lambda_id``) and
    cycle consistency (MAE, weight ``lambda_cycle``). Inside each combined model the
    other generator and both discriminators are frozen.
    """

    # Default folder containing <dataset_name>/{trainA,trainB,testA,testB}/*.jpg.
    DEFAULT_DATA_ROOT = 'E:\\machine_learning_image_data\\cycle_gan\\datasets'

    def __init__(self, continue_training=True, linear_decay=False, load_size=256,
                 dataset_name="horse2zebra", model_dir='models'):
        """Build and compile the two generators, two discriminators and combined models.

        Args:
            continue_training: if True, load the weight files previously written by
                ``train`` from ``model_dir`` for all four networks.
            linear_decay: if True, ``train`` linearly decays the learning rate of both
                optimizers (see ``lr_linear_decay``).
            load_size: side length, in pixels, that every image is resized to.
            dataset_name: dataset folder name; also part of every checkpoint file name.
            model_dir: folder checkpoints are loaded from and saved to.
        """
        self.load_size = load_size
        # 9 residual blocks for 256x256 inputs, 6 otherwise.
        self.num_residual = 9 if self.load_size == 256 else 6
        self.img_shape = (self.load_size, self.load_size, 3)
        self.dataset_name = dataset_name
        self.model_dir = model_dir
        self.linear_decay = linear_decay

        # Output shape of D (PatchGAN): its four stride-2 convolutions downsample by 2**4.
        patch = self.load_size // 2**4
        self.disc_patch = (patch, patch, 1)

        self.gen_f = 64
        self.disc_f = 64

        self.learning_rate_initial = 0.0002
        self.learning_rate = self.learning_rate_initial

        # Kept on the instance so that train() can change their learning rate.
        self.optimizer_gen = Adam(self.learning_rate, 0.5)
        self.optimizer_disc = Adam(self.learning_rate, 0.5)

        # Loss weights
        self.lambda_cycle = 10                       # Cycle-consistency loss
        self.lambda_id = 0.5 * self.lambda_cycle     # Identity mapping loss

        # Discriminators: compiled while still trainable, so train_on_batch updates them.
        self.d_X = self.build_discriminator()
        self.d_X.summary()
        self.d_Y = self.build_discriminator()

        if continue_training:
            self.d_X.load_weights(self._checkpoint_path('discriminator_X_weights'))
            self.d_Y.load_weights(self._checkpoint_path('discriminator_Y_weights'))

        for disc in (self.d_X, self.d_Y):
            disc.compile(loss='mse', optimizer=self.optimizer_disc, loss_weights=[.25])

        # Generators: g_G maps X -> Y, g_F maps Y -> X.
        self.g_G = self.build_generator()
        self.g_G.summary()
        self.g_F = self.build_generator()

        if continue_training:
            self.g_G.load_weights(self._checkpoint_path('generator_G_weights'))
            self.g_F.load_weights(self._checkpoint_path('generator_F_weights'))

        # Input images from domain X and domain Y.
        img_X = Input(shape=self.img_shape)
        img_Y = Input(shape=self.img_shape)
        # Translate images to the other domain.
        fake_Y = self.g_G(img_X)
        fake_X = self.g_F(img_Y)
        # Translate images back to the original domain.
        reconstruct_X = self.g_F(fake_Y)
        reconstruct_Y = self.g_G(fake_X)
        # Identity mapping of images.
        img_X_identity = self.g_F(img_X)
        img_Y_identity = self.g_G(img_Y)
        # Discriminator scores of the translated images (targets are "valid" when
        # training the generators).
        valid_X = self.d_X(fake_X)
        valid_Y = self.d_Y(fake_Y)

        # Keras fixes the set of trainable weights when a model is compiled, so these
        # flags only matter for the compile() calls that follow them.
        self.d_X.trainable = False
        self.d_Y.trainable = False

        # Combined model that trains g_F only.
        self.generator_training_model_F = Model(inputs=[img_X, img_Y],
                                                outputs=[valid_X, img_X_identity, reconstruct_X])
        self.g_G.trainable = False
        self.g_F.trainable = True
        self.generator_training_model_F.compile(loss=['mse', 'mae', 'mae'],
                                                loss_weights=[1, self.lambda_id, self.lambda_cycle],
                                                optimizer=self.optimizer_gen)

        # Combined model that trains g_G only.
        self.generator_training_model_G = Model(inputs=[img_Y, img_X],
                                                outputs=[valid_Y, img_Y_identity, reconstruct_Y])
        self.g_G.trainable = True
        self.g_F.trainable = False
        self.generator_training_model_G.compile(loss=['mse', 'mae', 'mae'],
                                                loss_weights=[1, self.lambda_id, self.lambda_cycle],
                                                optimizer=self.optimizer_gen)

    def _checkpoint_path(self, kind):
        """Return the checkpoint file path for ``kind`` (e.g. 'generator_G_weights')."""
        return os.path.join(self.model_dir, 'cyclegan_%s_%s-v3.h5' % (self.dataset_name, kind))

    def load_data(self, file_pattern):
        """Return the file names matching the glob ``file_pattern`` (unsorted)."""
        return glob.glob(file_pattern)

    def read_image(self, fn):
        """Load ``fn`` as an RGB float32 array of shape (load_size, load_size, 3) in [-1, 1].

        Returns None, after printing a warning, when the file cannot be read, so that
        callers can skip it (one corrupt file should not abort loading a dataset).
        """
        try:
            with Image.open(fn) as src:  # closes the underlying file handle
                im = src.convert('RGB').resize((self.load_size, self.load_size), Image.BILINEAR)
        except Exception as e:
            print('skipping unreadable image {}: {}'.format(fn, e))
            return None
        # float32 halves memory versus float64 and is what Keras computes in anyway.
        return np.asarray(im, dtype=np.float32) / 127.5 - 1.0

    def _load_images(self, file_pattern):
        """Read all files matching ``file_pattern`` into an (N, H, W, 3) array.

        Unreadable files are skipped; when nothing is readable an empty (0, H, W, 3)
        array is returned so that shapes stay consistent.
        """
        images = [self.read_image(fn) for fn in self.load_data(file_pattern)]
        images = [img for img in images if img is not None]
        if not images:
            return np.empty((0,) + self.img_shape, dtype=np.float32)
        return np.stack(images)

    def preprocess(self, dataset_name, data_root=None):
        """Load the trainA, trainB, testA and testB splits of ``dataset_name``.

        Args:
            dataset_name: sub-folder of ``data_root`` holding the four split folders.
            data_root: root folder; defaults to ``DEFAULT_DATA_ROOT``.

        Returns:
            Tuple ``(train_A, train_B, test_A, test_B)`` of image arrays. Both the train
            and test splits are used for training (see ``train``).
        """
        print("loading the training data in the {} dataset...".format(dataset_name))
        root = os.path.join(self.DEFAULT_DATA_ROOT if data_root is None else data_root,
                            dataset_name)
        splits = tuple(self._load_images(os.path.join(root, split, '*.jpg'))
                       for split in ('trainA', 'trainB', 'testA', 'testB'))
        print("...loading finished.")
        return splits

    def build_generator(self):
        """Build a ResNet-style generator mapping an ``img_shape`` image to ``img_shape``.

        Layout: reflection pad 3 + 7x7 conv, two stride-2 downsampling convs,
        ``num_residual`` residual blocks, two stride-2 transposed-conv upsampling
        blocks, reflection pad 3 + 7x7 conv to 3 channels, then tanh.
        """

        def conv2d(layer_input, filters, f_size=3, strides=2, padding='same'):
            """Conv2D -> InstanceNormalization -> ReLU."""
            init = RandomNormal(stddev=0.02)
            d = Conv2D(filters, kernel_size=f_size, strides=strides, padding=padding,
                       kernel_initializer=init)(layer_input)
            d = InstanceNormalization(axis=-1)(d)
            d = Activation('relu')(d)
            return d

        def residual(layer_input, filters):
            """Two reflection-padded 3x3 conv + IN layers with a skip connection."""
            init = RandomNormal(stddev=0.02)
            # first layer
            x = ReflectionPadding2D((1, 1))(layer_input)
            x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='valid',
                       kernel_initializer=init)(x)
            x = InstanceNormalization(axis=-1)(x)
            x = Activation('relu')(x)
            # second layer
            x = ReflectionPadding2D((1, 1))(x)
            x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='valid',
                       kernel_initializer=init)(x)
            x = InstanceNormalization(axis=-1)(x)
            # merge
            return layers.add([x, layer_input])

        def deconv2d(layer_input, filters, f_size=3):
            """Stride-2 Conv2DTranspose (2x upsampling) -> IN -> ReLU."""
            init = RandomNormal(stddev=0.02)
            u = Conv2DTranspose(filters, kernel_size=f_size, strides=2, padding='same',
                                kernel_initializer=init)(layer_input)
            u = InstanceNormalization(axis=-1)(u)
            u = Activation('relu')(u)
            return u

        input_img = Input(shape=self.img_shape)

        model = ReflectionPadding2D((3, 3))(input_img)
        model = conv2d(model, self.gen_f, f_size=7, strides=1, padding='valid')
        model = conv2d(model, self.gen_f * 2)
        model = conv2d(model, self.gen_f * 4)

        # self.num_residual residual blocks (9 for 256x256 inputs, 6 otherwise)
        for _ in range(self.num_residual):
            model = residual(model, self.gen_f * 4)

        model = deconv2d(model, self.gen_f * 2)
        model = deconv2d(model, self.gen_f)
        model = ReflectionPadding2D((3, 3))(model)
        # NOTE(review): this head applies IN + ReLU (inside conv2d) and a second IN
        # before tanh, unlike a plain conv -> tanh head. Kept unchanged so that existing
        # weight files still load layer-for-layer.
        model = conv2d(model, filters=3, f_size=7, strides=1, padding='valid')
        model = InstanceNormalization(axis=-1)(model)

        output_img = Activation('tanh')(model)

        return Model(input_img, output_img)

    def build_discriminator(self):
        """Build a PatchGAN discriminator producing a ``disc_patch``-shaped score map.

        Four stride-2 4x4 conv blocks (64, 128, 256, 512 filters, LeakyReLU 0.2, instance
        norm on all but the first) followed by a 1-filter 4x4 conv with no activation;
        the raw scores are trained with MSE against 1 (real) / 0 (fake).
        """

        def disc_layer(layer_input, filters, f_size=4, strides=2, normalization=True):
            """Conv2D -> (optional InstanceNormalization) -> LeakyReLU(0.2)."""
            init = RandomNormal(stddev=0.02)
            d = Conv2D(filters, kernel_size=f_size, strides=strides, padding='same',
                       kernel_initializer=init)(layer_input)
            if normalization:
                d = InstanceNormalization(axis=-1)(d)
            d = LeakyReLU(alpha=0.2)(d)
            return d

        input_img = Input(shape=self.img_shape)

        model = disc_layer(input_img, self.disc_f, normalization=False)
        model = disc_layer(model, self.disc_f * 2)
        model = disc_layer(model, self.disc_f * 4)
        model = disc_layer(model, self.disc_f * 8)

        output = Conv2D(1, kernel_size=4, strides=1, padding='same')(model)

        return Model(input_img, output)

    @staticmethod
    def _sample_batch(train, test, batch_size):
        """Draw ``batch_size`` random images (with replacement) from ``train`` or ``test``.

        The source array is chosen in proportion to its size, so that every image of the
        union is equally likely to be picked.

        Raises:
            ValueError: if both arrays are empty.
        """
        n_train, n_test = train.shape[0], test.shape[0]
        if n_train + n_test == 0:
            raise ValueError('no images available to sample from')
        source = train if np.random.random() < n_train / (n_train + n_test) else test
        idx = np.random.randint(0, source.shape[0], batch_size)
        return source[idx]

    def _set_learning_rate(self, lr):
        """Apply ``lr`` to both optimizers and record it in ``self.learning_rate``."""
        self.learning_rate = lr
        for optimizer in (self.optimizer_gen, self.optimizer_disc):
            K.set_value(optimizer.lr, lr)

    def _save_checkpoints(self):
        """Save both full generator models and the weights of all four networks."""
        os.makedirs(self.model_dir, exist_ok=True)
        self.g_G.save(self._checkpoint_path('generator_G'))
        self.g_F.save(self._checkpoint_path('generator_F'))
        for net, kind in ((self.d_X, 'discriminator_X_weights'),
                          (self.d_Y, 'discriminator_Y_weights'),
                          (self.g_G, 'generator_G_weights'),
                          (self.g_F, 'generator_F_weights')):
            net.save_weights(self._checkpoint_path(kind))

    def train(self, iterations, batch_size=1, sample_interval=50, resume_step=0):
        """Train for iterations ``resume_step`` .. ``iterations - 1``.

        Each iteration trains g_F, d_X, g_G and d_Y once on a random batch. Progress is
        printed and a sample figure saved every ``sample_interval`` iterations.
        Checkpoints are saved every 2000 iterations and after the last iteration.
        """
        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        print('valid shape: ', valid.shape)

        # Train and test folders are both used for training.
        self.train_A, self.train_B, self.test_A, self.test_B = self.preprocess(self.dataset_name)

        for iteration in range(resume_step, iterations):
            imgs_A = self._sample_batch(self.train_A, self.test_A, batch_size)
            imgs_B = self._sample_batch(self.train_B, self.test_B, batch_size)

            # Translate images to the opposite domain with the current generators.
            fake_B = self.g_G.predict(imgs_A)
            fake_A = self.g_F.predict(imgs_B)

            if self.linear_decay:
                self._set_learning_rate(self.lr_linear_decay(iteration, iterations))

            # Train generator g_F (g_G and both discriminators frozen at compile time).
            g_loss_F = self.generator_training_model_F.train_on_batch(
                [imgs_A, imgs_B], [valid, imgs_A, imgs_A])

            # Train discriminator d_X: real domain-X images vs. g_F translations.
            dX_loss = np.add(self.d_X.train_on_batch(imgs_A, valid),
                             self.d_X.train_on_batch(fake_A, fake))

            # Train generator g_G (g_F and both discriminators frozen at compile time).
            g_loss_G = self.generator_training_model_G.train_on_batch(
                [imgs_B, imgs_A], [valid, imgs_B, imgs_B])

            # Train discriminator d_Y: real domain-Y images vs. g_G translations.
            dY_loss = np.add(self.d_Y.train_on_batch(imgs_B, valid),
                             self.d_Y.train_on_batch(fake_B, fake))

            # Combine generator and discriminator losses.
            g_loss = np.add(g_loss_F, g_loss_G)
            d_loss = np.add(dX_loss, dY_loss)

            elapsed_time = datetime.datetime.now() - start_time

            if iteration % sample_interval == 0:
                print("%d [Dloss: %f] [G loss: %05f] time: %s"
                      % (iteration, d_loss, g_loss[0], elapsed_time))
                print("current learning rate:%05f" % self.learning_rate)
                self.sample_plot(iteration)

            if iteration % 2000 == 0:
                self._save_checkpoints()

        # Keep the work done since the last periodic checkpoint.
        if iterations > resume_step and (iterations - 1) % 2000 != 0:
            self._save_checkpoints()

        print('training finished.')

    def lr_linear_decay(self, iteration, iterations, decay_start_step=80000):
        """Return the learning rate for ``iteration`` under a linear-decay schedule.

        The rate stays at ``learning_rate_initial`` until ``decay_start_step`` and then
        decays linearly to 0 at ``iterations``; it never goes below 0. If ``iterations``
        does not exceed ``decay_start_step`` there is no decay window and the initial
        rate is returned.
        """
        if iteration < decay_start_step or iterations <= decay_start_step:
            return self.learning_rate_initial
        progress = (iteration - decay_start_step) / (iterations - decay_start_step)
        return max(0.0, self.learning_rate_initial * (1 - progress))

    def sample_plot(self, iteration):
        """Save a 2x3 figure (original / translated / reconstructed, for A and B).

        The figure is written to images/<dataset_name>/<iteration>.png, using one random
        training image from each domain.
        """
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 3
        idx_A = np.random.randint(0, self.train_A.shape[0], size=1)
        idx_B = np.random.randint(0, self.train_B.shape[0], size=1)
        sample_imgs_A = self.train_A[idx_A]
        sample_imgs_B = self.train_B[idx_B]

        # Translate images to the other domain
        sample_fake_B = self.g_G.predict(sample_imgs_A)
        sample_fake_A = self.g_F.predict(sample_imgs_B)
        # Translate back to original domain
        sample_reconstruct_A = self.g_F.predict(sample_fake_B)
        sample_reconstruct_B = self.g_G.predict(sample_fake_A)

        gen_imgs = np.concatenate([sample_imgs_A, sample_fake_B, sample_reconstruct_A,
                                   sample_imgs_B, sample_fake_A, sample_reconstruct_B])
        print('gen_imgs shape: ', gen_imgs.shape)

        # Rescale images from [-1, 1] to [0, 1]; clip so imshow does not warn on overshoot.
        gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0.0, 1.0)

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        # 20 x 15 is about 1440 x 1080 pixels
        fig.set_size_inches(20, 15, forward=True)
        for i in range(r):
            for j in range(c):
                ax = axs[i, j]
                ax.imshow(gen_imgs[i * c + j])
                ax.set_title(titles[j], fontsize=50)
                ax.axis('off')
        fig.savefig("images/%s/%d.png" % (self.dataset_name, iteration))
        plt.close(fig)

    def generate_image(self, generator='G', imgs_folder='samples'):
        """Translate every ``imgs_folder``/*.jpeg image and save side-by-side figures.

        Args:
            generator: 'G' to use g_G (X -> Y) or 'F' to use g_F (Y -> X).
            imgs_folder: folder to read images from; each result is saved there as
                translated_<k>.png.

        Raises:
            ValueError: if ``generator`` is neither 'G' nor 'F'.
        """
        generators = {'G': self.g_G, 'F': self.g_F}
        if generator not in generators:
            raise ValueError("generator must be 'G' or 'F', got %r" % (generator,))

        print('sample generation started...')
        os.makedirs(imgs_folder, exist_ok=True)
        imgs = self._load_images('{}/*.jpeg'.format(imgs_folder))
        if imgs.shape[0] == 0:
            print('no readable *.jpeg images found in {}'.format(imgs_folder))
            print('...sample generation finished')
            return

        print('use generator {}'.format(generator))
        sample_fake_img = generators[generator].predict(imgs)

        titles = ['Original', 'Translated']
        for k in range(imgs.shape[0]):
            pair = np.concatenate([imgs[k:k + 1], sample_fake_img[k:k + 1]])
            # Rescale images from [-1, 1] to [0, 1]
            pair = np.clip(0.5 * pair + 0.5, 0.0, 1.0)

            fig, axs = plt.subplots(1, 2)
            fig.set_size_inches(10, 8, forward=True)
            for j, ax in enumerate(axs):
                ax.imshow(pair[j])
                ax.set_title(titles[j], fontsize=50)
                ax.axis('off')
            fig.savefig("%s/translated_%d.png" % (imgs_folder, k))
            plt.close(fig)

        print('...sample generation finished')
        
# reflection padding taken from
# https://github.com/fastai/courses/blob/master/deeplearning2/neural-style.ipynb
class ReflectionPadding2D(Layer):
    """Reflection-pad the two spatial axes of a 4D (batch, rows, cols, channels) tensor.

    ``padding=(pad_rows, pad_cols)`` adds ``pad_rows`` rows at the top and bottom and
    ``pad_cols`` columns at the left and right, consistently in ``call`` and
    ``compute_output_shape``.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Keras' Layer.__init__ resets input_spec, so it must run before we set ours.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        # Unknown (None) spatial dimensions stay unknown.
        rows = None if s[1] is None else s[1] + 2 * self.padding[0]
        cols = None if s[2] is None else s[2] + 2 * self.padding[1]
        return (s[0], rows, cols, s[3])

    def call(self, x, mask=None):
        pad_rows, pad_cols = self.padding
        return tf.pad(x, [[0, 0], [pad_rows, pad_rows], [pad_cols, pad_cols], [0, 0]], 'REFLECT')

    def get_config(self):
        # Include padding so models saved with Model.save() rebuild with the right sizes.
        config = super(ReflectionPadding2D, self).get_config()
        config['padding'] = self.padding
        return config




In [None]:
if __name__ == '__main__':
    # Build every network and resume from the weights stored under models/.
    gan = CycleGAN(linear_decay=True, continue_training=True)
    # To resume training, uncomment and set resume_step to the last saved iteration:
    # gan.train(iterations=160001, batch_size=1, sample_interval=500, resume_step=146000)
    # To translate the *.jpeg images in samples/ with generator F, uncomment:
    # gan.generate_image(generator='F')

In [None]:
# Set resume_step=<last saved iteration> in gan.train() to resume training.