### Colorization of panchromatic space images with a Generative Adversarial Network

Author: Patrycja Cieplicka

Date: 12 Jan 2020

Implementation of the Generative Adversarial Network class used as the main colorization model.

In [5]:
from ipynb.fs.full.packages import *

In [6]:
# Seed NumPy's global RNG and the framework-level seed (set_random_seed comes
# from the packages notebook; presumably TensorFlow's) with one shared value
# so that training runs are reproducible.
_SEED = 1
np.random.seed(_SEED)
set_random_seed(_SEED)

In [9]:
class GAN_ab():
    """Conditional GAN that predicts the two "ab" channels from the "L" channel.

    The generator is a U-Net-style encoder/decoder mapping an image of shape
    ``g_input_shape`` to ``g_output_shape`` through a tanh output. The
    discriminator scores an (L, ab) pair as real (near ``label``) or fake (0).
    ``self.gan`` chains generator -> discriminator with the discriminator
    frozen, so fitting it updates only the generator's weights.
    (The channel naming suggests CIELAB data — confirm against the data
    pipeline.)
    """

    def __init__(self, lr_gen, lr_disc):
        """Build and compile the generator, discriminator and combined GAN.

        Args:
            lr_gen: Adam learning rate for the generator and for the combined
                GAN (which only updates generator weights).
            lr_disc: Adam learning rate for the discriminator.
        """
        # Input shape for the generator (one channel).
        self.g_input_shape = (128, 128, 1)
        # Output shape for the generator (two channels).
        self.g_output_shape = (128, 128, 2)

        # Compile the generator. train() never fits it directly (only through
        # self.gan), so this loss is not what drives its updates.
        print("Generator architecture")
        self.generator = self.build_generator()
        self.generator.compile(loss='binary_crossentropy', optimizer=Adam(lr=lr_gen))
        # summary() prints itself and returns None; wrapping it in print()
        # used to emit a stray "None".
        self.generator.summary()

        # Compile the discriminator.
        print('Discriminator architecture')
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=Adam(lr=lr_disc),
                                   metrics=['accuracy'])
        self.discriminator.summary()

        # Build the combined GAN. The discriminator is frozen *after* its own
        # compile: it still trains when fitted directly, but its weights stay
        # fixed when self.gan is fitted.
        gan_input = Input(shape=self.g_input_shape)
        img_color = self.generator(gan_input)
        self.discriminator.trainable = False
        real_or_fake = self.discriminator([gan_input, img_color])
        self.gan = Model(inputs=gan_input, outputs=real_or_fake)

        # Compile the GAN.
        self.gan.compile(loss='binary_crossentropy', optimizer=Adam(lr=lr_gen))
        print('\n')
        print('GAN summary...')
        self.gan.summary()

    def _conv_block(self, x, filters, strides=1, leaky=False):
        """Apply Conv2D(3x3, 'same') -> BatchNormalization -> (Leaky)ReLU to x."""
        x = Conv2D(filters, (3, 3), padding='same', strides=strides)(x)
        x = BatchNormalization()(x)
        return LeakyReLU()(x) if leaky else Activation('relu')(x)

    def _up_block(self, x, filters, skip=None):
        """Upsample x by 2, apply a conv block, then concatenate `skip` if given."""
        x = UpSampling2D(size=(2, 2))(x)
        x = self._conv_block(x, filters)
        if skip is not None:
            x = Concatenate(axis=-1)([x, skip])
        return x

    def build_generator(self):
        """Build the U-Net-style generator model.

        Returns:
            A Keras Model mapping ``g_input_shape`` to ``g_output_shape``.
            Spatial sizes in the comments assume the default 128 x 128 input.
        """
        g_input = Input(shape=self.g_input_shape)

        # Encoder.
        conv1 = self._conv_block(g_input, 64, strides=2)   # 64 x 64
        conv2 = self._conv_block(conv1, 128)               # 64 x 64 (skip)
        conv3 = self._conv_block(conv2, 128, strides=2)    # 32 x 32
        conv4 = self._conv_block(conv3, 256)               # 32 x 32 (skip)
        conv5 = self._conv_block(conv4, 256, strides=2)    # 16 x 16
        # BN/ReLU used to be applied to conv5 here, silently discarding this
        # 512-filter conv and feeding the wrong features to the skip below.
        conv6 = self._conv_block(conv5, 512)               # 16 x 16 (skip)
        conv7 = self._conv_block(conv6, 512, strides=2)    # 8 x 8

        # Decoder: upsample, convolve, concatenate the matching encoder skip.
        conv8 = self._up_block(conv7, 512, skip=conv6)     # 16 x 16
        conv9 = self._conv_block(conv8, 512)
        conv10 = self._up_block(conv9, 256, skip=conv4)    # 32 x 32
        conv11 = self._conv_block(conv10, 256)
        conv12 = self._up_block(conv11, 128, skip=conv2)   # 64 x 64
        conv13 = self._conv_block(conv12, 128)
        conv14 = self._up_block(conv13, 64)                # 128 x 128

        # Output channels squashed to [-1, 1] by tanh.
        conv15 = Conv2D(self.g_output_shape[-1], (3, 3), padding='same')(conv14)
        conv15 = Activation('tanh')(conv15)

        return Model(inputs=g_input, outputs=conv15)

    def build_discriminator(self):
        """Build the discriminator scoring an (L, ab) pair.

        Returns:
            A Keras Model taking ``[L, ab]`` inputs and producing one sigmoid
            probability per sample. Spatial sizes assume a 128 x 128 input.
        """
        d_input_l = Input(shape=self.g_input_shape)
        d_input_ab = Input(shape=self.g_output_shape)

        # Stack both inputs along the channel axis.
        x = concatenate([d_input_l, d_input_ab], axis=3)
        x = self._conv_block(x, 32, leaky=True)               # 128 x 128
        x = self._conv_block(x, 64, strides=2, leaky=True)    # 64 x 64
        x = self._conv_block(x, 128, strides=2, leaky=True)   # 32 x 32
        x = self._conv_block(x, 128, strides=2, leaky=True)   # 16 x 16
        x = self._conv_block(x, 256, strides=2, leaky=True)   # 8 x 8
        x = self._conv_block(x, 256, strides=2, leaky=True)   # 4 x 4
        x = self._conv_block(x, 512, strides=2, leaky=True)   # 2 x 2

        x = Dropout(.4)(x)
        x = Flatten()(x)
        x = Dense(1)(x)
        x = Activation('sigmoid')(x)

        return Model(inputs=[d_input_l, d_input_ab], outputs=x)

    @staticmethod
    def _last_metric(history, *keys):
        """Return the last value stored under the first of `keys` in `history`.

        Keras 2.3+ records the 'accuracy' metric under 'accuracy' instead of
        'acc', so callers pass both spellings to stay version-independent.

        Raises:
            KeyError: if none of `keys` is present.
        """
        for key in keys:
            if key in history:
                return history[key][-1]
        raise KeyError('none of %r found in training history' % (keys,))

    def train(self, X_train_L, X_train_AB, epochs, label):
        """Alternate discriminator and generator updates for `epochs` epochs.

        Args:
            X_train_L: generator inputs (L channel).
            X_train_AB: real ab channels paired with X_train_L.
            epochs: number of training epochs.
            label: target value used for real images.

        Outputs: nothing is returned. Plots and models are written under
        '../' at every epoch divisible by 250 and once more at the end.
        """
        g_losses = []  # generator (combined GAN) loss per epoch
        d_losses = []  # discriminator loss per epoch (last fit: generated images)
        d_acc = []     # discriminator accuracy per epoch (last fit: generated images)

        n = len(X_train_L)  # number of training images

        y_train_fake = np.zeros([n, 1])         # fake pictures label, always 0
        y_train_real = np.full([n, 1], label)   # real pictures label

        for e in range(epochs):
            print("Epochs:")
            print(e)

            # Generator colorizes the whole training set.
            generated_images = self.generator.predict(X_train_L, verbose=1)

            # Train the discriminator.
            # 1) Real (L, ab) pairs.
            d_loss = self.discriminator.fit(x=[X_train_L, X_train_AB], y=y_train_real,
                                            batch_size=8, epochs=1, shuffle=True)
            # 2) Every third epoch, uniform noise in [-1, 1) as extra fakes.
            if e % 3 == 2:
                noise = np.random.rand(n, *self.g_output_shape) * 2 - 1
                d_loss = self.discriminator.fit(x=[X_train_L, noise], y=y_train_fake,
                                                batch_size=8, epochs=1)
            # 3) Generated images as fakes.
            d_loss = self.discriminator.fit(x=[X_train_L, generated_images], y=y_train_fake,
                                            batch_size=8, epochs=1, shuffle=True)

            # Record the discriminator history of the last fit.
            d_losses.append(self._last_metric(d_loss.history, 'loss'))
            d_acc.append(self._last_metric(d_loss.history, 'acc', 'accuracy'))
            print('d_loss:', d_losses[-1])

            # Train the generator through the frozen discriminator: push its
            # colorizations toward the "real" label.
            g_loss = self.gan.fit(x=X_train_L, y=y_train_real, batch_size=8, epochs=1)

            g_losses.append(self._last_metric(g_loss.history, 'loss'))
            print('Generator Loss: ', g_losses[-1])
            print("Discriminator Accuracy: ", d_acc[-1])

            if e % 5 == 4:
                print(e + 1,"batches done")

            # Checkpoint every 250th epoch.
            if e % 250 == 0:
                self._save_checkpoint(g_losses, d_acc, e,
                                      '../unet_model_semi_batch_',
                                      '../u_disc_model_semi_batch_')

        # Save the final model.
        self._save_checkpoint(g_losses, d_acc, epochs,
                              '../unet_model_ab_full_batch_',
                              '../u_disc_model_ab_full_batch_')

    def _save_checkpoint(self, g_losses, d_acc, epochs, gen_prefix, disc_prefix):
        """Plot both training curves on a fresh figure and save both models.

        Uses a new figure per checkpoint so that earlier checkpoints' curves
        do not accumulate in later plots.
        """
        plt.figure()
        self.plot_losses(g_losses, 'Generative Loss', epochs)
        self.plot_losses(d_acc, 'Discriminative Accuracy', epochs)
        plt.close()
        self.generator.save(gen_prefix + str(epochs) + '.h5')
        self.discriminator.save(disc_prefix + str(epochs) + '.h5')

    def plot_losses(self, metric, label, epochs):
        """Add `metric` to the current figure and save it.

        Draws on the current matplotlib figure, so consecutive calls with the
        same `epochs` overlay their curves in one image. The file
        '../plot_<epochs>_epochs.png' is overwritten on each call.

        Args:
            metric: sequence of per-epoch values to plot.
            label: legend label for this curve.
            epochs: epoch count used in the title and file name.
        """
        plt.plot(metric, label=label)
        plt.title('GAN Accuracy and Loss Over ' + str(epochs) + ' Epochs')
        plt.legend()
        plt.savefig('../plot_' + str(epochs) + '_epochs.png')