### Colorization of panchromatic space images with a Generative Adversarial Network

Author: Patrycja Cieplicka

Date: 12 Jan 2020

Implementation of the Generative Adversarial Network class (`GAN_two`) for the second colorization model.

## Init

In [1]:
from ipynb.fs.full.packages import *

Using TensorFlow backend.


In [2]:
# Seed NumPy's and TensorFlow's global random generators with the same value
# so that random draws (e.g. the noise images used in GAN_two.train) are
# intended to repeat between runs.
SEED = 1
np.random.seed(SEED)
set_random_seed(SEED)


In [3]:
class GAN_two():
    """Conditional GAN that colorizes single-channel 128x128 images.

    The generator is a U-Net-style encoder/decoder that maps a
    ``g_input_shape`` (128, 128, 1) image to a ``d_input_shape`` (128, 128, 3)
    image with a ``tanh`` output. The discriminator is conditional: it receives
    the single-channel input together with a 3-channel image and outputs the
    probability that the pair is real. ``self.gan`` chains the generator into
    the frozen discriminator and is what trains the generator.
    """

    def __init__(self):
        """Build and compile the generator, the discriminator and the combined GAN."""
        self.g_input_shape = (128, 128, 1)
        self.d_input_shape = (128, 128, 3)

        # Compiling generator. It is never fit directly in train(); it
        # learns through self.gan.
        self.generator = self.build_generator()
        opt = Adam(lr=0.0001)
        self.generator.compile(loss='binary_crossentropy', optimizer=opt)
        print('Generator Summary...')
        # Model.summary() prints by itself and returns None; wrapping it in
        # print() would emit a stray "None".
        self.generator.summary()

        # Compiling discriminator
        self.discriminator = self.build_discriminator()
        opt = Adam(lr=0.0001)
        self.discriminator.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
        print('Discriminator Summary...')
        self.discriminator.summary()

        # Initialize GAN: grayscale -> generator -> discriminator([gray, color]).
        gan_input = Input(shape=self.g_input_shape)
        img_color = self.generator(gan_input)
        # Freeze the discriminator only for the combined model compiled below.
        # Keras captures `trainable` at compile time, so the discriminator
        # compiled above still updates its own weights when fit directly.
        self.discriminator.trainable = False
        real_or_fake = self.discriminator([gan_input, img_color])
        self.gan = Model(inputs=gan_input, outputs=real_or_fake)

        # Compiling GAN
        opt = Adam(lr=0.001)
        self.gan.compile(loss='binary_crossentropy', optimizer=opt)
        print('\n')
        print('GAN summary...')
        self.gan.summary()

    @staticmethod
    def _conv_bn_relu(x, filters, strides=1):
        """Apply 3x3 'same' Conv2D -> BatchNormalization -> ReLU to tensor ``x``."""
        x = Conv2D(filters, (3, 3), padding='same', strides=strides)(x)
        x = BatchNormalization()(x)
        return Activation('relu')(x)

    @staticmethod
    def _last_metric(history, *keys):
        """Return the final value of the first of ``keys`` present in a Keras History dict.

        Several keys can be given because Keras versions differ in the metric
        name, e.g. 'acc' versus 'accuracy'.

        Raises:
            KeyError: if none of ``keys`` is in ``history``.
        """
        for key in keys:
            if key in history:
                return history[key][-1]
        raise KeyError('none of %s found in training history' % (keys,))

    def build_generator(self):
        """Build the U-Net-style generator.

        Encoder: four stride-2 convolutions take the spatial size
        128 -> 64 -> 32 -> 16 -> 8, each followed by (or, for the last,
        forming) a channel-widening stage. Decoder: four UpSampling2D stages
        take 8 -> 16 -> 32 -> 64 -> 128. The first three concatenate the
        encoder tensor of the same resolution (skip connections).

        Returns:
            keras Model mapping ``self.g_input_shape`` to a 3-channel image of
            the same spatial size with values in (-1, 1) (tanh).
        """
        g_input = Input(shape=self.g_input_shape)

        # Encoder
        # 128 -> 64
        conv1 = self._conv_bn_relu(g_input, 64, strides=2)
        conv2 = self._conv_bn_relu(conv1, 128)
        # 64 -> 32
        conv3 = self._conv_bn_relu(conv2, 128, strides=2)
        conv4 = self._conv_bn_relu(conv3, 256)
        # 32 -> 16
        conv5 = self._conv_bn_relu(conv4, 256, strides=2)
        # Chain conv -> BN -> ReLU properly. Previously the BN and ReLU were
        # applied to conv5, so this 512-filter conv never entered the graph.
        conv6 = self._conv_bn_relu(conv5, 512)
        # 16 -> 8 (bottleneck)
        conv7 = self._conv_bn_relu(conv6, 512, strides=2)

        # Decoder
        # 8 -> 16, skip connection from conv6
        conv8 = self._conv_bn_relu(UpSampling2D(size=(2, 2))(conv7), 512)
        conv8 = Concatenate(axis=-1)([conv8, conv6])
        conv9 = self._conv_bn_relu(conv8, 512)

        # 16 -> 32, skip connection from conv4
        conv10 = self._conv_bn_relu(UpSampling2D(size=(2, 2))(conv9), 256)
        conv10 = Concatenate(axis=-1)([conv10, conv4])
        conv11 = self._conv_bn_relu(conv10, 256)

        # 32 -> 64, skip connection from conv2
        conv12 = self._conv_bn_relu(UpSampling2D(size=(2, 2))(conv11), 128)
        conv12 = Concatenate(axis=-1)([conv12, conv2])
        conv13 = self._conv_bn_relu(conv12, 128)

        # 64 -> 128 (no skip connection at full resolution)
        conv14 = self._conv_bn_relu(UpSampling2D(size=(2, 2))(conv13), 64)

        # Project to 3 output channels squashed into (-1, 1).
        conv15 = Conv2D(3, (3, 3), padding='same')(conv14)
        conv15 = Activation('tanh')(conv15)

        return Model(inputs=g_input, outputs=conv15)

    def build_discriminator(self):
        """Build the conditional discriminator.

        Its inputs are ``[single_channel, color]`` in that order. They are
        concatenated along the channel axis into a 4-channel 128x128 tensor.
        Seven Conv -> BatchNorm -> LeakyReLU stages (the first with stride 1,
        the rest with stride 2) reduce this to 2x2x512. The stages are
        followed by Dropout(0.4), Flatten and a single sigmoid unit.

        Returns:
            keras Model giving the probability that the input pair is real.
        """
        # Created in the original order (color input first) so auto-generated
        # layer names stay the same.
        d_input_lab = Input(shape=self.d_input_shape)
        d_input_l = Input(shape=self.g_input_shape)

        x = concatenate([d_input_l, d_input_lab], axis=3)
        # (filters, strides) per stage; spatial size
        # 128 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2
        for filters, strides in ((32, 1), (64, 2), (128, 2), (128, 2),
                                 (256, 2), (256, 2), (512, 2)):
            x = Conv2D(filters, (3, 3), padding='same', strides=strides)(x)
            x = BatchNormalization()(x)
            x = LeakyReLU()(x)

        final = Dropout(.4)(x)
        final = Flatten()(final)
        final = Dense(1)(final)
        final = Activation('sigmoid')(final)

        return Model(inputs=[d_input_l, d_input_lab], outputs=final)

    def train(self, X_train_L, X_train_LAB, epochs):
        """Train the GAN for ``epochs`` full passes over the data.

        Each epoch does three things:

        1. Colorize every input with the current generator.
        2. Fit the discriminator for one epoch on real pairs labelled 0.9
           (one-sided label smoothing). Every third epoch, also fit it on
           uniform-noise images labelled 0. Then fit it on generated pairs
           labelled 0.
        3. Fit the combined GAN for one epoch with target 1, pushing the
           generator to fool the frozen discriminator.

        At epoch 0 and every 250 epochs, and once more after the last epoch,
        the loss/accuracy plot and both models (.h5) are written to '../'.

        Args:
            X_train_L: single-channel inputs, shape (n,) + self.g_input_shape.
            X_train_LAB: matching 3-channel target images, shape
                (n,) + self.d_input_shape. NOTE(review): presumably scaled to
                [-1, 1] to match the generator's tanh output and the noise
                images; confirm against the preprocessing code.
            epochs: number of training epochs.
        """
        g_losses = []
        d_losses = []
        d_acc = []
        n = len(X_train_L)
        # Labels: real = 1, fake = 0. Real pairs shown to the discriminator
        # use a smoothed 0.9; the GAN is fit against a hard 1.
        y_train_fake = np.zeros([n, 1])
        y_train_real_s = np.full([n, 1], 0.9)
        y_train_real = np.ones([n, 1])

        for e in range(epochs):
            print("Epochs:")
            print(e)

            # Generate images with the current generator.
            generated_images = self.generator.predict(X_train_L, verbose=1)

            # Train discriminator: real pairs first...
            self.discriminator.fit(x=[X_train_L, X_train_LAB], y=y_train_real_s, batch_size=8, epochs=1, shuffle=True)
            # ...every third epoch also on pure uniform noise in [-1, 1)
            # labelled fake...
            if e % 3 == 2:
                noise = np.random.rand(n, *self.d_input_shape) * 2 - 1
                self.discriminator.fit(x=[X_train_L, noise], y=y_train_fake, batch_size=8, epochs=1)
            # ...and finally on generated pairs labelled fake. Only this last
            # pass is recorded, so d_acc tracks accuracy on generated images.
            d_hist = self.discriminator.fit(x=[X_train_L, generated_images], y=y_train_fake, batch_size=8, epochs=1, shuffle=True).history
            d_loss = self._last_metric(d_hist, 'loss')
            disc_acc = self._last_metric(d_hist, 'acc', 'accuracy')
            d_losses.append(d_loss)
            d_acc.append(disc_acc)
            print('d_loss:', d_loss)

            # Train GAN on grayscale images with "real" targets. Only the
            # generator's weights change because the discriminator is frozen
            # inside self.gan.
            g_hist = self.gan.fit(x=X_train_L, y=y_train_real, batch_size=8, epochs=1).history
            g_loss = self._last_metric(g_hist, 'loss')

            # Record Losses/Acc
            g_losses.append(g_loss)
            print('Generator Loss: ', g_loss)
            print("Discriminator Accuracy: ", disc_acc)

            if e % 5 == 4:
                print(e + 1, "epochs done")
            if e % 250 == 0:
                # Fresh figure, so curves from earlier checkpoints don't
                # accumulate in this image.
                plt.clf()
                self.plot_losses(g_losses, 'Generative Loss UNet Semi', e)
                self.plot_losses(d_acc, 'Discriminative Accuracy Unet Semi', e)
                self.generator.save('../unet_model_semi_batch_' + str(e) + '.h5')
                self.discriminator.save('../u_disc_model_semi_batch_' + str(e) + '.h5')

        # Save final outputs.
        plt.clf()
        self.plot_losses(g_losses, 'Generative Loss', epochs)
        self.plot_losses(d_acc, 'Discriminative Accuracy', epochs)
        self.generator.save('../unet_model_full_batch_' + str(epochs) + '.h5')
        self.discriminator.save('../u_disc_model_full_batch_' + str(epochs) + '.h5')

    def plot_losses(self, metric, label, epochs):
        """Add one labelled metric curve to the current figure and save it.

        This draws onto the current matplotlib figure and does not clear it.
        Consecutive calls with the same ``epochs`` therefore produce one image
        holding all their curves. ``train`` clears the figure before each such
        group.

        Args:
            metric: sequence of per-epoch values to plot.
            label: legend label for the curve.
            epochs: epoch count used in the title and in the file name
                '../plot_<epochs>_epochs.png'.
        """
        plt.plot(metric, label=label)
        plt.title('GAN Accuracy and Loss Over ' + str(epochs) + ' Epochs')
        # Without a legend, the label passed to plot() is never shown.
        plt.legend()
        plt.savefig('../plot_' + str(epochs) + '_epochs.png')