In [1]:
import sys
import numpy as np
import json
import os
import pickle
import matplotlib.pyplot as plt
import logging

In [2]:
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.layers import Input, Concatenate, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from skimage.transform import resize
from tensorflow.keras.optimizers import Adam, RMSprop,SGD
from tensorflow.keras.initializers import RandomNormal
#from functools import partial

In [3]:
import keras

In [4]:
from tensorflow.keras.applications.vgg19 import VGG19

In [5]:
os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [6]:
# gpu = tf.config.experimental.list_physical_devices('GPU')[0]
# tf.config.experimental.set_memory_growth(gpu,True)

# Build Perceptual loss

In [7]:
from tensorflow.keras.applications.vgg19 import VGG19
class Cut_VGG19:
    """
    Class object that fetches keras' VGG19 model trained on the imagenet dataset
    and declares <layers_to_extract> as output layers. Used as feature extractor
    for the perceptual loss function.
    Args:
        layers_to_extract: list of layers to be declared as output layers.
        patch_size: integer, defines the size of the input (patch_size x patch_size).
    Attributes:
        loss_model: multi-output vgg architecture with <layers_to_extract> as output layers.
    """
    
    def __init__(self, patch_size, layers_to_extract):
        self.patch_size = patch_size
        self.input_shape = (patch_size,) * 2 + (3,)
        self.layers_to_extract = layers_to_extract
        
        if len(self.layers_to_extract) > 0:
            self._cut_vgg()
    
    def _cut_vgg(self):
        """
        Loads pre-trained VGG, declares as output the intermediate
        layers selected by self.layers_to_extract.
        """
        
        vgg = VGG19(weights='imagenet', include_top=False, input_shape=self.input_shape)
        vgg.trainable = False
        outputs = [vgg.layers[i].output for i in self.layers_to_extract]
        self.model = Model([vgg.input], outputs)

        
feature_extraction = Cut_VGG19(256,[5,9])   
feature_extraction.model.trainable = False


In [8]:
class Sampling(tf.keras.layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        z = z_mean + tf.exp(0.5 * z_log_var) * epsilon
        return z

In [9]:
class RandomWeightedAverage(tf.keras.layers.Layer):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
    """Provides a (random) weighted average between real and generated image samples"""
    def _merge_function(self, inputs):
        alpha = K.random_uniform((self.batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

In [10]:
class WGANGP():
    def __init__(self):
        self.input_dim = (256,256,3)
        self.optimiser = 'rmsprop'
        self.z_dim = 256
        ################ Encoder Model ########################
        self.encoder_input = (256,256,1)
        self.encoder_conv_filters = [64,128,256,512]
        self.encoder_conv_kernel_size = [3,3,3,3]
        self.encoder_conv_strides = [2,2,2,3]
        self.encoder_batch_norm_momentum =  None
        self.encoder_activation = 'relu'
        self.encoder_dropout_rate = None
        self.encoder_learning_rate = 1e-5
        
        ################ Decoder Model ########################
        self.decoder_initial_dense_layer_size = (16,16,512)
        self.decoder_conv_filters = [256,128,64,1]
        self.decoder_conv_kernel_size = [3,3,3,3]
        self.decoder_conv_strides = [2,2,2,2]
        self.decoder_batch_norm_momentum =  None
        self.decoder_activation = 'relu'
        self.decoder_dropout_rate = None
        self.decoder_learning_rate = 1e-5

        ################ Generator Model #########################
        self.generator_initial_dense_layer_size = (4,4,256)
        self.generator_upsample = [1,1,1,1,1,1]
        self.generator_conv_filters = [128,64,32,16,8,3]
        self.generator_conv_kernel_size = [3,3,3,3,3,3]
        self.generator_conv_strides = [2,2,2,2,2,2]
        self.generator_batch_norm_momentum =  0.8
        self.generator_activation = 'leaky_relu'
        self.generator_dropout_rate = None
        self.generator_learning_rate = 1e-3
        ################ Discriminator Model ###########################
        self.discriminator_conv_filters = [8,16,32,64,128,256]
        self.discriminator_conv_kernel_size = [3,3,3,3,3,3]
        self.discriminator_conv_strides = [2,2,2,2,2,2]
        self.discriminator_batch_norm_momentum = None
        self.discriminator_activation = 'leaky_relu'
        self.discriminator_dropout_rate = None
        self.discriminator_learning_rate = 1e-3
        ###########################################
        self.weight_init = RandomNormal(mean=0., stddev=0.02)
        self.grad_weight = 10
        self.batch_size = 128
        ############################################
        self.n_layers_encoder = len(self.encoder_conv_filters)
        self.n_layers_decoder = len(self.decoder_conv_filters)
        self.n_layers_discriminator = len(self.discriminator_conv_filters)
        self.n_layers_generator = len(self.generator_conv_filters)
        ###############################################                               
        self.d_losses = []
        self.g_losses = []
        self.epoch = 0
        ###############################################
        self._build_encoder()
        self.encoder_model.summary()
        self._build_decoder()
        self.decoder_model.summary()
        ###############################################
        self._build_generator()
        self.generator.summary()
        self._build_discriminator()
        self.discriminator.summary()
        ###############################################
        self._build_adversarial()
        self.gan_model.summary()
        self.vae_model.summary()
        self.model.summary()
        LOG_DIR = "./logs/vae_gan.log"
        logging.basicConfig(filename=LOG_DIR,  
                    level=logging.DEBUG,
                    format="[%(asctime)s] [%(name)s] [%(message)s]",
                    filemode="w")
    ####################### Loss ###########################
    def wasserstein(self, y_true, y_pred):
        alpha = 100
        return -K.mean(y_true * y_pred)*alpha
    
    def get_perceptual_loss(self,y_true,y_pred):
        content_feature = feature_extraction.model(y_true)
        new_feature = feature_extraction.model(y_pred)
        perceptual_loss = 0
        weight = tf.constant([1/16,1/8], dtype = tf.float32)
        for i in range(len(new_feature)):
            perceptual_loss += weight[i]*K.mean(K.square(new_feature[i] - content_feature[i]))
        l2_loss = tf.reduce_mean(tf.keras.losses.mean_squared_error(y_true,y_pred))
        total_loss = perceptual_loss + 100*l2_loss
        return total_loss

    def log_normal_pdf(self,sample, mean, logvar, raxis=1):
        log2pi = tf.math.log(2. * np.pi)
        return tf.reduce_sum(
          -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
          axis=raxis)
    
    def VAE_reconstruct_loss(self, y_true, y_pred):
        cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=y_pred, labels=y_true)
        logpx_z = -tf.reduce_mean(cross_ent, axis=[1, 2, 3])
        return logpx_z
    
    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean
    
    def VAE_MC_loss(self,y_true, y_pred):
        mean_true = y_true[:,:256]
        log_var_true = y_true[:,256:]
        z_mean = y_pred[:,:256]
        z_log_var = y_pred[:,256:]
        z = self.reparameterize(z_mean,z_log_var)
        logpz = self.log_normal_pdf(z, mean_true, log_var_true)
        logqz_x = self.log_normal_pdf(z, z_mean, z_log_var)
        return -tf.reduce_mean(logpz - logqz_x) 
    
        
    ################# Activation layer #####################                                                                
    def get_activation(self, activation):
        if activation == 'leaky_relu':
            layer = LeakyReLU(alpha = 0.2)
        else:
            layer = Activation(activation)
        return layer
    ####################################################################
    #################### Build Encoder Model ###########################
    ####################################################################
    def _build_encoder(self):
        encoder_input_layer = Input(shape=self.encoder_input, name='encoder_input')
        x = encoder_input_layer
        for i in range(self.n_layers_encoder):
            x = Conv2D(
                filters = self.encoder_conv_filters[i]
                , kernel_size = self.encoder_conv_kernel_size[i]
                , strides = self.encoder_conv_strides[i]
                , padding = 'same'
                , name = 'encoder_conv_' + str(i)
                , kernel_initializer = self.weight_init
                ,activation='relu'
                )(x)

            if self.encoder_batch_norm_momentum and i > 0:
                x = BatchNormalization(momentum = self.encoder_batch_norm_momentum)(x)
            #x = self.get_activation(self.encoder_activation)(x)
            if self.encoder_dropout_rate:
                x = Dropout(rate = self.encoder_dropout_rate)(x)
        x = Flatten()(x)
        
#         encoder_output = Dense(1, activation=None
#         , kernel_initializer = self.weight_init
#         )(x)
        z_mean = Dense(self.z_dim, name="z_mean",kernel_initializer = self.weight_init)(x)
        z_log_var = Dense(self.z_dim, name="z_log_var", kernel_initializer = self.weight_init)(x)
    
        z = Sampling()([z_mean, z_log_var])
        encoder_output = [z,z_mean,z_log_var]
        self.encoder_model = Model(encoder_input_layer,encoder_output ,name="Encoder")
    ####################################################################
    #################### Build Decoder Model ###########################
    ####################################################################
    def _build_decoder(self):
        decoder_input_layer = Input(shape=(self.z_dim,), name='decoder_input')
        x = decoder_input_layer
        x = Dense(np.prod(self.decoder_initial_dense_layer_size), kernel_initializer = self.weight_init)(x)
        if self.decoder_batch_norm_momentum:
            x = BatchNormalization(momentum = self.decoder_batch_norm_momentum)(x)       
        x = self.get_activation(self.decoder_activation)(x)
        x = Reshape(self.decoder_initial_dense_layer_size)(x)
        for i in range(self.n_layers_decoder):
            if i == self.n_layers_decoder-1:
                x = Conv2DTranspose(
                filters = self.decoder_conv_filters[i]
                , kernel_size = self.decoder_conv_kernel_size[i]
                , strides = self.decoder_conv_strides[i]
                , padding = 'same'
                , name = 'decoder_conv_' + str(i)
                , kernel_initializer = self.weight_init
                ,activation='tanh'
                )(x)
            else:
                x = Conv2DTranspose(
                    filters = self.decoder_conv_filters[i]
                    , kernel_size = self.decoder_conv_kernel_size[i]
                    , strides = self.decoder_conv_strides[i]
                    , padding = 'same'
                    , name = 'decoder_conv_' + str(i)
                    , kernel_initializer = self.weight_init
                    ,activation='relu'
                    )(x)
                #x = self.get_activation(self.decoder_activation)(x)
        
        decoder_output = x
        self.decoder_model = Model(decoder_input_layer,decoder_output)
    ####################################################################
    #################### Build Discriminator Model #####################
    ####################################################################
    def _build_discriminator(self):
        
        discriminator_input = Input(shape=self.input_dim, name='discriminator_input')
        x = discriminator_input
        
        for i in range(self.n_layers_discriminator):
            x = Conv2D(
                filters = self.discriminator_conv_filters[i]
                , kernel_size = self.discriminator_conv_kernel_size[i]
                , strides = self.discriminator_conv_strides[i]
                , padding = 'same'
                , name = 'discriminator_conv_' + str(i)
                , kernel_initializer = self.weight_init
                )(x)

            if self.discriminator_batch_norm_momentum and i > 0:
                x = BatchNormalization(momentum = self.discriminator_batch_norm_momentum)(x)
            x = self.get_activation(self.discriminator_activation)(x)
            if self.discriminator_dropout_rate:
                x = Dropout(rate = self.discriminator_dropout_rate)(x)

                
        x = Flatten()(x)
        
        discriminator_output = Dense(1, activation=None
        , kernel_initializer = self.weight_init
        )(x)
        x = self.get_activation('sigmoid')(x)
        self.discriminator = Model(discriminator_input, discriminator_output,name="Discriminator")
        
    ####################################################################
    #################### Build Generator Model #########################
    ####################################################################
    
    def _build_generator(self):
        ############  generator ###############
        generator_input_layer = Input(shape=(self.z_dim,), name='generator_input')
        x = generator_input_layer
        x = Dense(np.prod(self.generator_initial_dense_layer_size), kernel_initializer = self.weight_init)(x)
        
        if self.generator_batch_norm_momentum:
            x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)       
        x = self.get_activation(self.generator_activation)(x)
        x = Reshape(self.generator_initial_dense_layer_size)(x)
        
        if self.generator_dropout_rate:
            x = Dropout(rate = self.generator_dropout_rate)(x)

        for i in range(self.n_layers_generator):

            if self.generator_upsample[i] == 2:
                x = UpSampling2D()(x)
                x = Conv2D(
                filters = self.generator_conv_filters[i]
                , kernel_size = self.generator_conv_kernel_size[i]
                , padding = 'same'
                , name = 'generator_conv_' + str(i)
                , kernel_initializer = self.weight_init
                )(x)
            else:

                x = Conv2DTranspose(
                    filters = self.generator_conv_filters[i]
                    , kernel_size = self.generator_conv_kernel_size[i]
                    , padding = 'same'
                    , strides = self.generator_conv_strides[i]
                    , name = 'generator_conv_' + str(i)
                    , kernel_initializer = self.weight_init
                    )(x)

            if i < self.n_layers_generator - 1:

                if self.generator_batch_norm_momentum:
                    x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)

                x = self.get_activation(self.generator_activation)(x)
                
            else:
                x = Activation('tanh')(x)

        generator_output = x
        self.generator = Model(generator_input_layer, generator_output,name="Generator")


    def get_opti(self, lr):
        if self.optimiser == 'adam':
            opti = Adam(lr=lr, beta_1=0.5)
        elif self.optimiser == 'rmsprop':
            opti = RMSprop(lr=lr)
        else:
            opti = Adam(lr=lr)

        return opti


    def set_trainable(self, m, val):
        m.trainable = val
        for l in m.layers:
            l.trainable = val

    def _build_adversarial(self):
                
        self.discriminator.compile(
            optimizer=self.get_opti(self.discriminator_learning_rate) 
            , loss = self.wasserstein
        )
        #######################################
        self.set_trainable(self.discriminator, False)
        input_gen = self.generator.input
        gen_fake_image_output = self.generator(input_gen)
        disc_output = self.discriminator(gen_fake_image_output)
        self.gan_model = Model(input_gen, [disc_output, gen_fake_image_output])
        #================ Model =========================
        self.gan_model.compile(
            optimizer = self.get_opti(self.encoder_learning_rate)
            , loss=[self.wasserstein,self.get_perceptual_loss]
            )

        self.set_trainable(self.discriminator, True)
        
        ############ BUILD VAE ################
        self.set_trainable(self.discriminator, False)
        self.set_trainable(self.generator, False)
        
        input_encoder = self.encoder_model.input
        encoder_output = self.encoder_model(input_encoder)
        z,z_mean,z_logvar = encoder_output[0],encoder_output[1],encoder_output[2]
        
        output_decoder = self.decoder_model(z)
        self.vae_model = Model(input_encoder,[output_decoder,tf.concat([z_mean,z_logvar],1)], name = "vae_model")
        self.vae_model.compile(
            optimizer=self.get_opti(self.encoder_learning_rate)
            , loss=[self.VAE_reconstruct_loss,self.VAE_MC_loss]
            )
        self.set_trainable(self.discriminator, True)
        self.set_trainable(self.generator, True)

        ############# BUILD FULL MULTITASK MODEL
        input_encoder = self.encoder_model.input
        encoder_output = self.encoder_model(input_encoder)
        
        z,z_mean,z_logvar = encoder_output[0],encoder_output[1],encoder_output[2]
        
        output_decoder = self.decoder_model(z)
        
        gen_fake_image_output = self.generator(z)
        
        disc_output = self.discriminator(gen_fake_image_output)
        
        self.model = Model(input_encoder, [gen_fake_image_output,output_decoder,tf.concat([z_mean,z_logvar],1)])
        #================ Model =========================
        self.model.compile(
            optimizer = self.get_opti(self.encoder_learning_rate)
            , loss=[self.get_perceptual_loss,self.VAE_reconstruct_loss,self.VAE_MC_loss]
            )

        
    def train_discriminator(self, x_train,Y_train, batch_size):

        valid = np.ones((batch_size,1))
        fake = -np.ones((batch_size,1))

        input_imgs = x_train
        true_imgs  = Y_train
        
        latent_code = self.encoder_model.predict(input_imgs)[0]
        gen_imgs = self.generator.predict(latent_code)
                
        d_loss_real =   self.discriminator.train_on_batch(true_imgs, valid)
        d_loss_fake =   self.discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * (d_loss_real + d_loss_fake)

        for l in self.discriminator.layers:
            weights = l.get_weights()
            weights = [np.clip(w, -0.01, 0.01) for w in weights]
            l.set_weights(weights)

        for l in self.discriminator.layers:
        
            weights = l.get_weights()
            if 'batch_normalization' in l.get_config()['name']:
                pass
                # weights = [np.clip(w, -0.01, 0.01) for w in weights[:2]] + weights[2:]
            else:
                weights = [np.clip(w, -0.01, 0.01) for w in weights]
            
            l.set_weights(weights)

        return [d_loss, d_loss_real, d_loss_fake]

    def train_generator(self,x_train,Y_train, batch_size):
        valid = np.ones((batch_size,1), dtype=np.float32)
        true_images = Y_train
        input_images = x_train
        
        latencode = self.encoder_model.predict(input_images)
        gans_loss = self.gan_model.train_on_batch(latencode[0], [valid,true_images])
        return gans_loss
    
    def train_vae(self,x_train,Y_train,batch_size):
        
        valid_z = np.zeros((batch_size,512),dtype=np.float32)
        
        return self.vae_model.train_on_batch(x_train,[x_train,valid_z])
    
    def train_full_model(self,x_train,Y_train,batch_size):
        valid_z = np.zeros((batch_size,512),dtype=np.float32)
        
        return self.model.train_on_batch(x_train,[Y_train,x_train,valid_z])

    
    def shuffle_data_batch(self,array_X,array_Y,batch_size):
        indices = np.arange(array_X.shape[0])
        np.random.shuffle(indices)
        array_X = array_X[indices]
        array_Y = array_Y[indices]
        
        def split_into_chunks(l, n):
            for i in range(0, l.shape[0], n):
                yield l[i:i + n]  
        array_X = split_into_chunks(array_X,batch_size)
        array_Y = split_into_chunks(array_Y,batch_size)
        
        return array_X,array_Y

        
    def train(self,x_train,Y_train,x_val,Y_val, batch_size, epochs, run_folder, print_every_n_batches = 10, n_critic = 5,using_generator = False):
        self.batch_size = batch_size
        self.n_steps = int(x_train.shape[0]/self.batch_size)-1
        for epoch in range(self.epoch, self.epoch + epochs):
            x_train_gan,Y_train_gan = self.shuffle_data_batch(x_train,Y_train,batch_size)
            
            for step in range(self.n_steps):
                x, Y = next(x_train_gan),next(Y_train_gan)
                for _ in range(n_critic):
                    d_loss = self.train_discriminator(x,Y, batch_size)
                    indices = np.random.shuffle(np.arange(self.batch_size))
                    x = np.squeeze(x[indices],axis=0)
                    Y = np.squeeze(Y[indices],axis=0)
                for _ in range(5):
                    vae_loss = self.train_vae(x,Y,batch_size)
                    indices = np.random.shuffle(np.arange(self.batch_size))
                    x = np.squeeze(x[indices],axis=0)
                    Y = np.squeeze(Y[indices],axis=0)
                g_loss = self.train_generator(x,Y, batch_size)
                model_loss = self.train_full_model(x,Y,batch_size)
                # Plot the progress

                self.d_losses.append(d_loss)
                self.g_losses.append(g_loss)

            # If at save interval => save generated image samples
            if epoch % print_every_n_batches == 0:
                logging.info(json.dumps({'epoch':epoch,'d_loss':d_loss,'g_loss':g_loss,'vae_loss':vae_loss,'full_model_loss':model_loss})) 
                print ("%d [D loss: (%.3f)(R %.3f, F %.3f)]  [Gan loss: %.3f, Per loss: %.3f,W: %.3f ]" % (epoch, d_loss[0], d_loss[1], d_loss[2], g_loss[0], g_loss[0], g_loss[1]))
                print("%d [VAE loss: %.3f (Re %.3f Mc %.3f) Full %.3f Per %.3f Re %.3f Mc %.3f ]"%(epoch,vae_loss[0],vae_loss[1],vae_loss[2],model_loss[0],model_loss[1],model_loss[2],model_loss[3]))
                print("==========================================================================================")
                self.sample_images(x_val,Y_val,run_folder)

            self.epoch+=1



    def sample_images(self,x_val,Y_val, run_folder):
        # Test
        r, c = 8, 4

        y_true = Y_val
        input_model = x_val
        
        latent_code = self.encoder_model.predict(input_model)
        gen_imgs = self.generator.predict(latent_code[0])
        
        gen_finger = self.decoder_model.predict(latent_code[0])
        # Perceptual loss
        #perceptloss = get_perceptual_loss(y_true,gen_imgs)

        indx = np.random.choice(y_true.shape[0], int(0.25*c*r) ,replace=False)
        finger_real = x_val[indx]
        face_real = 0.5*(y_true[indx]+1)
        
        #face_real = face_real[:,:,:,[2,1,0]]

        gen_imgs = 0.5 * (gen_imgs[indx] + 1)
        gen_imgs = np.clip(gen_imgs, 0, 1)
        #gen_imgs = gen_imgs[:,:,:,[2,1,0]]
        gen_finger = gen_finger[indx]

        fig, axs = plt.subplots(r, c, figsize=(15,30))
        #fig.suptitle("Perceptual loss : %.3f" %(perceptloss))
        cnt = 0
        for i in range(r):
            for j in range(int(0.25*c)):
                axs[i,4*j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]))
                axs[i,4*j].axis('off')
                axs[i,4*j+1].imshow(np.squeeze(face_real[cnt, :,:,:]))
                axs[i,4*j+1].axis('off')
                axs[i,4*j+2].imshow(np.squeeze(gen_finger[cnt, :,:,:]),cmap ="gray")
                axs[i,4*j+2].axis('off')
                axs[i,4*j+3].imshow(np.squeeze(finger_real[cnt, :,:,:]),cmap ="gray")
                axs[i,4*j+3].axis('off')
                cnt += 1
        fig.savefig(os.path.join(run_folder, "images/sample_%d.png" % self.epoch))
        plt.close()
    

In [11]:
GAN = WGANGP()

Model: "Encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D)         (None, 128, 128, 64) 640         encoder_input[0][0]              
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D)         (None, 64, 64, 128)  73856       encoder_conv_0[0][0]             
__________________________________________________________________________________________________
encoder_conv_2 (Conv2D)         (None, 32, 32, 256)  295168      encoder_conv_1[0][0]             
____________________________________________________________________________________________

# Load data train

In [20]:
DATA_DIR = '../Finger_enhancement/data_cropped/data_train/'
with open(os.path.join(DATA_DIR,"data_face_3571_train.pkl"), "rb") as input_file:
    data_train_face = pickle.load(input_file)
with open(os.path.join(DATA_DIR,"data_fingerprint_3571_train.pkl"), "rb") as input_file:
    data_train_finger = pickle.load(input_file)

    
with open(os.path.join(DATA_DIR,"data_face_3571_val.pkl"), "rb") as input_file:
    data_val_face = pickle.load(input_file)
with open(os.path.join(DATA_DIR,"data_fingerprint_3571_val.pkl"), "rb") as input_file:
    data_val_finger = pickle.load(input_file)
    

In [21]:
data_train_face.shape,data_train_finger.shape,data_val_face.shape,data_val_finger.shape

((2856, 256, 256, 3),
 (2856, 256, 256, 1),
 (715, 256, 256, 3),
 (715, 256, 256, 1))

In [22]:
print(np.max(data_train_finger),np.median(data_train_finger),np.min(data_train_finger))
print(np.max(data_train_face),np.median(data_train_face),np.min(data_train_face))

1 0.0 0
255 137.0 0


In [23]:
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import img_to_array

In [24]:
def augment_data_face_fingerprint(data_face,data_finger, num_augment_percent=0.5):
    num_data = data_face.shape[0]
    num_data_augment = int(data_face.shape[0]*num_augment_percent)
    index_data = np.random.choice(data_face.shape[0],num_data_augment,replace=False)
    
    data_aug_face = []
    data_aug_finger = []
    for i in index_data:
        data_aug_face.append(data_face[i,:,:,:])
        data_aug_finger.append(data_finger[i,:,:,:])
    data_aug_face = np.asarray(data_aug_face)    
    data_aug_finger = np.asarray(data_aug_finger) 

    # create image data augmentation generator
    datagen_face = ImageDataGenerator(horizontal_flip=True)
    datagen_face.fit(data_aug_face)
    
    datagen_finger = ImageDataGenerator(horizontal_flip=False,height_shift_range=0.1,width_shift_range=0.1,shear_range=0.5,rotation_range=20)
    datagen_finger.fit(data_aug_finger)
    
    it_face = datagen_face.flow(data_aug_face,batch_size=num_data_augment,shuffle=False)
    it_finger = datagen_finger.flow(data_aug_finger,batch_size=num_data_augment,shuffle=False)
    
    results_face = np.concatenate([data_face,it_face.next()],axis=0)
    results_finger = np.concatenate([data_finger,it_finger.next()],axis=0)
    #np.random.shuffle(results)
    return results_face,results_finger

def augment_data_only_fingerprint(data_finger, num_augment_percent = 0.5):
    
    num_data = data_finger.shape[0]
    num_data_augment = int(data_finger.shape[0]*num_augment_percent)
    index_batch = np.arange(0,num_data_augment)
    index_data = np.random.choice(num_data,num_data_augment,replace=False)
    ###############################
    data_aug_finger = []
    for i in index_data:
        data_aug_finger.append(data_finger[i,:,:,:])   
    data_aug_finger = np.asarray(data_aug_finger) 
    ################################
    datagen_finger = ImageDataGenerator(horizontal_flip=False,height_shift_range=0.2,width_shift_range=0.2,shear_range=0.1,rotation_range=20)
    datagen_finger.fit(data_aug_finger)
    ###############################
    it_finger = datagen_finger.flow(data_aug_finger,batch_size=num_data_augment,shuffle=False)
    finger_augmented = it_finger.next()
    ################################
    result_finger = data_finger
    for index in zip(index_data,index_batch):
        result_finger[index[0],:,:,:]=finger_augmented[index[1],:,:,:]
    return result_finger


In [25]:
data_train_face, data_train_finger = augment_data_face_fingerprint(data_train_face,data_train_finger)

In [26]:
data_train_face.shape,data_train_finger.shape,data_val_face.shape,data_val_finger.shape

((4284, 256, 256, 3),
 (4284, 256, 256, 1),
 (715, 256, 256, 3),
 (715, 256, 256, 1))

In [27]:
data_train_finger = ((data_train_finger)/np.max(data_train_finger))
data_val_finger = ((data_val_finger)/np.max(data_val_finger))

In [28]:
data_train_finger = np.where(data_train_finger > .5, 1.0, 0.0).astype('float32')
data_val_finger = np.where(data_val_finger > .5, 1.0, 0.0).astype('float32')

In [29]:
data_train_finger.shape, data_val_finger.shape

((4284, 256, 256, 1), (715, 256, 256, 1))

In [30]:
data_train_face = ((data_train_face-127.5)/127.5).astype('float32')

In [31]:
data_train_face.shape

(4284, 256, 256, 3)

In [32]:
data_val_face = ((data_val_face-127.5)/127.5).astype('float32')

In [33]:
data_val_face.shape

(715, 256, 256, 3)

In [34]:
print(np.max(data_train_finger),np.median(data_train_finger),np.min(data_train_finger))
print(np.max(data_train_face),np.median(data_train_face),np.min(data_train_face))

1.0 0.0 0.0
1.0 0.07450981 -1.0


In [35]:
train_face = data_train_face
val_face = data_val_face
train_finger = data_train_finger
val_finger = data_val_finger

In [36]:
# from sklearn.model_selection import train_test_split
# train_latentcode,valid_latentcode,train_face,valid_face = train_test_split(data_train_latentcode[2],
#                                                              data_train_face,
#                                                              test_size=0.2,
#                                                              random_state=13)

In [37]:
train_face.shape,val_face.shape,train_finger.shape, val_finger.shape

((4284, 256, 256, 3),
 (715, 256, 256, 3),
 (4284, 256, 256, 1),
 (715, 256, 256, 1))

In [1]:
plt.subplot(1, 2, 1)
plt.imshow(data_train_finger[10],cmap='gray')
plt.subplot(1, 2, 2)
plt.imshow((data_train_face[10]+1)*0.5)
plt.show()

In [12]:
DATA_DIR = './data_train/data_get_latent/'
with open(os.path.join(DATA_DIR,"data_face_2856_train.pkl"), "rb") as input_file:
    data_train_face = pickle.load(input_file)
with open(os.path.join(DATA_DIR,"data_fingerprint_2856_train.pkl"), "rb") as input_file:
    data_train_finger = pickle.load(input_file)

    
with open(os.path.join(DATA_DIR,"data_face_2856_val.pkl"), "rb") as input_file:
    data_val_face = pickle.load(input_file)
with open(os.path.join(DATA_DIR,"data_fingerprint_2856_val.pkl"), "rb") as input_file:
    data_val_finger = pickle.load(input_file)

In [13]:
train_face = data_train_face
val_face = data_val_face
train_finger = data_train_finger
val_finger = data_val_finger

# Training GAN

In [None]:
GAN.train(
    train_finger  
    , train_face
    , val_finger
    , val_face
    , batch_size = 16
    , epochs = 20000
    , run_folder = './model_save'
    , print_every_n_batches = 10
    , n_critic = 2
    , using_generator = True
)

0 [D loss: (-1976.351)(R -3934.305, F -18.398)]  [Gan loss: 293.306, Per loss: 293.306,W: 279.281 ]
0 [VAE loss: -0.888 (Re -0.874 Mc -0.015) Full 13.167 Per 14.062 Re -0.874 Mc -0.022 ]


In [59]:
GAN.encoder_model.save('./model_save/encoder_17200_1379.h5')
GAN.discriminator.save('./model_save/discriminator_17200_1379.h5')
GAN.generator.save('./model_save/generator_17200_1379.h5')

# Validation

In [19]:
gen = tf.keras.models.load_model('./model_save/model/generator_3channel_conditional.h5',custom_objects={'get_perceptual_loss': get_perceptual_loss})

In [102]:
def test(validate=True):    # Test
    r, c = 8, 4
    if validate :
        y_true = val_face
        x_true = val_finger
    else:
        y_true = train_face[900:1000,:,:,:]
        x_true = train_finger[900:1000,:,:,:]
    
    latent_code = GAN.encoder_model.predict(x_true)
    gen_imgs = GAN.generator.predict(latent_code[0])
    gen_finger = GAN.decoder_model.predict(latent_code[0])
    # Perceptual loss
    #perceptloss = get_perceptual_loss(y_true,gen_imgs)

    indx = np.random.choice(y_true.shape[0], int(0.25*c*r) ,replace=False)
    
    finger_real = x_true[indx]
    
    face_real = 0.5*(y_true[indx]+1)

    gen_imgs = 0.5 * (gen_imgs[indx] + 1)
    gen_imgs = np.clip(gen_imgs, 0, 1)
 
    gen_finger = gen_finger[indx]
    print(np.max(gen_finger), np.min(gen_finger),np.median(gen_finger))
    fig, axs = plt.subplots(r, c, figsize=(15,30))
    #fig.suptitle("Perceptual loss : %.3f" %(perceptloss))
    cnt = 0
    for i in range(r):
        for j in range(int(0.25*c)):
            axs[i,4*j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]))
            axs[i,4*j].axis('off')
            axs[i,4*j+1].imshow(np.squeeze(face_real[cnt, :,:,:]))
            axs[i,4*j+1].axis('off')
            axs[i,4*j+2].imshow(np.squeeze(finger_real[cnt, :,:,:]),cmap ="gray")
            axs[i,4*j+2].axis('off')
            axs[i,4*j+3].imshow(np.squeeze(gen_finger[cnt, :,:,:]),cmap ="gray")
            axs[i,4*j+3].axis('off')
            cnt += 1

    plt.show()
    #fig.savefig(os.path.join('./model_save', "images/face_gens.png" ))
    plt.close()

In [2]:
test(validate=True)

In [53]:
#GAN.sample_images(train_latentcode,train_face,"./model_save")

In [38]:
DATA_DIR = './data_train/data_split/'
with open(os.path.join(DATA_DIR,"latentcode_val_augment.pkl"), "rb") as input_file:
    latentcode_augment = pickle.load(input_file)
with open(os.path.join(DATA_DIR,"face_val_augment.pkl"), "rb") as input_file:
    face_augment = pickle.load(input_file)
with open(os.path.join(DATA_DIR,"finger_val_augment.pkl"), "rb") as input_file:
    finger_augment = pickle.load(input_file)

In [39]:
latentcode_augment[2].shape,face_augment.shape,finger_augment.shape

((300, 128), (300, 128, 128, 3), (300, 128, 128, 1))

In [40]:
finger_augment = ((finger_augment)/np.max(finger_augment))
finger_augment = np.where(finger_augment > .5, 1.0, -1.0).astype('float32')
#finger_augment = ((finger_augment-127.5)/127.5).astype('float32')

In [3]:
r, c = 8, 4

latentcode_datatrain = train_latentcode[200:500,:]
y_true_datatrain = train_face[200:500,:,:,:]
label_true_datatrain = label_train_finger[200:500,:,:,:]
print(latentcode_datatrain.shape)
noise_gauss = tf.keras.backend.random_normal(shape=(300, 128))
gen_imgs_augment = GAN.generator.predict([latentcode_augment[2]+noise_gauss*0.2,finger_augment])

# Perceptual loss
perceptloss_augment = get_perceptual_loss(y_true_datatrain,gen_imgs_augment)

indx = np.random.choice(y_true_datatrain.shape[0], int(0.5*c*r) ,replace=False)

face_real = 0.5*(y_true_datatrain[indx]+1)
face_real = face_real[:,:,:,[2,1,0]]

gen_imgs = 0.5 * (gen_imgs_augment[indx] + 1)
gen_imgs = np.clip(gen_imgs, 0, 1)
gen_imgs = gen_imgs[:,:,:,[2,1,0]]


fig, axs = plt.subplots(r, c, figsize=(15,30))
#fig.suptitle("Perceptual loss : %.3f" %(perceptloss_augment))
cnt = 0
for i in range(r):
    for j in range(int(0.25*c)):
        axs[i,4*j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]))
        axs[i,4*j].axis('off')
        axs[i,4*j+1].imshow(np.squeeze(face_real[cnt, :,:,:]))
        axs[i,4*j+1].axis('off')
        axs[i,4*j+2].imshow(np.squeeze(finger_augment[indx][cnt, :,:,:]),cmap='gray')
        axs[i,4*j+2].axis('off')
        axs[i,4*j+3].imshow(np.squeeze(label_true_datatrain[indx][cnt, :,:,:]),cmap='gray')
        axs[i,4*j+3].axis('off')
    cnt += 1

plt.show()
fig.savefig(os.path.join('./model_save', "images/face_gens_1.png" ))
plt.close()

### 