In [1]:
import sys
import numpy as np
import json
import os
import pickle
import matplotlib.pyplot as plt

In [2]:
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from skimage.transform import resize
from tensorflow.keras.optimizers import Adam, RMSprop,SGD
from tensorflow.keras.initializers import RandomNormal
from functools import partial

In [3]:
from tensorflow.keras.applications.vgg19 import VGG19

In [4]:
# Set CUDA_VISIBLE_DEVICES to "1" for this process.
# NOTE(review): presumably this limits TensorFlow to the GPU with index 1; it is
# set after `import tensorflow`, so confirm it takes effect before any GPU is initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

# Build perceptual loss

In [5]:
from tensorflow.keras.applications.vgg19 import VGG19
class Cut_VGG19:
    """
    Class object that fetches keras' VGG19 model trained on the imagenet dataset
    and declares <layers_to_extract> as output layers. Used as feature extractor
    for the perceptual loss function.
    Args:
        patch_size: integer, defines the size of the input (patch_size x patch_size).
        layers_to_extract: list of integer indices into ``vgg.layers`` whose
            outputs are declared as output layers.
    Attributes:
        patch_size: the value passed to the constructor.
        input_shape: ``(patch_size, patch_size, 3)``.
        layers_to_extract: the list passed to the constructor.
        model: multi-output vgg architecture with <layers_to_extract> as output
            layers, or ``None`` when <layers_to_extract> is empty.
    """

    def __init__(self, patch_size, layers_to_extract):
        self.patch_size = patch_size
        self.input_shape = (patch_size,) * 2 + (3,)
        self.layers_to_extract = layers_to_extract
        # Always define the attribute so callers can test for "no extractor"
        # instead of hitting an AttributeError.
        self.model = None

        if len(self.layers_to_extract) > 0:
            self._cut_vgg()

    def _cut_vgg(self):
        """
        Loads pre-trained VGG, declares as output the intermediate
        layers selected by self.layers_to_extract, and stores the
        resulting model in self.model.
        """

        vgg = VGG19(weights='imagenet', include_top=False, input_shape=self.input_shape)
        # The extractor is only used to compute a loss; its weights stay frozen.
        vgg.trainable = False
        outputs = [vgg.layers[i].output for i in self.layers_to_extract]
        self.model = Model([vgg.input], outputs)

        
# Shared, frozen VGG19 feature extractor for 128x128 inputs; the outputs of
# vgg.layers[5] and vgg.layers[9] feed the perceptual loss below.
feature_extraction = Cut_VGG19(patch_size=128, layers_to_extract=[5, 9])
feature_extraction.model.trainable = False
def get_perceptual_loss(y_true,y_pred):
    """Weighted VGG feature distance plus a pixel-wise MSE term.

    Both image batches are passed through the module-level
    ``feature_extraction`` model. For each extracted layer the mean squared
    difference of the feature maps is weighted (1/16 for the first layer,
    1/8 for the second) and summed; 15 times the mean pixel-wise squared
    error is then added.

    Args:
        y_true: batch of target images.
        y_pred: batch of generated images.
    Returns:
        Scalar tensor holding the combined loss.
    """
    target_feats = feature_extraction.model(y_true)
    predicted_feats = feature_extraction.model(y_pred)

    # One weight per extracted layer; only two layers are supported here.
    layer_weights = tf.constant([1/16,1/8], dtype = tf.float32)
    feature_term = sum(
        layer_weights[k] * K.mean(K.square(predicted_feats[k] - target_feats[k]))
        for k in range(len(predicted_feats))
    )

    pixel_term = tf.reduce_mean(tf.keras.losses.mean_squared_error(y_true,y_pred))
    return feature_term + 15*pixel_term

In [6]:
class RandomWeightedAverage(tf.keras.layers.Layer):
    """Provides a (random) weighted average between real and generated image samples.

    Args:
        batch_size: number of samples per batch. One blending factor is drawn
            per sample, so both inputs must have exactly this many rows.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def call(self, inputs):
        """Keras entry point: blend ``inputs[0]`` and ``inputs[1]``.

        Without this override the inherited ``Layer.call`` is used and
        ``_merge_function`` is never invoked when the layer is called.
        """
        return self._merge_function(inputs)

    def _merge_function(self, inputs):
        """Return ``alpha * inputs[0] + (1 - alpha) * inputs[1]`` with random alpha."""
        # Shape (batch_size, 1, 1, 1): one factor per sample, broadcast over
        # the remaining axes of a 4-D image batch.
        alpha = K.random_uniform((self.batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

In [7]:
class WGANGP():
    """Wasserstein GAN that maps 128-d latent codes to 128x128x3 images.

    The constructor fixes every hyper-parameter, then builds and compiles
    three Keras models:

    * ``critic``    -- strided-convolution critic with a single linear output,
      compiled with the Wasserstein loss.
    * ``generator`` -- dense layer followed by transposed convolutions and a
      final ``tanh``; also compiled on its own with ``get_perceptual_loss`` so
      it can be fitted on (latent code, target image) pairs.
    * ``model``     -- generator followed by the critic (frozen at compile
      time), used for the adversarial generator update.

    NOTE(review): despite the class name, training constrains the critic by
    clipping its weights to [-0.01, 0.01]; ``gradient_penalty_loss`` is defined
    but never called inside this class.
    """

    def __init__(self):
        # --- critic hyper-parameters ---
        self.input_dim = (128,128,3)
        self.critic_conv_filters = [16,32,64,128]
        self.critic_conv_kernel_size = [3,3,3,3]
        self.critic_conv_strides = [2,2,2,2]
        self.critic_batch_norm_momentum = None
        self.critic_activation = 'leaky_relu'
        self.critic_dropout_rate = None
        self.critic_learning_rate = 1e-3

        # --- generator hyper-parameters ---
        self.generator_initial_dense_layer_size = (4,4, 128)
        # An entry equal to 2 selects UpSampling2D + Conv2D for that layer;
        # any other value selects a strided Conv2DTranspose.
        self.generator_upsample = [1,1,1,1,1]
        self.generator_conv_filters = [128,64,32,16,3]
        self.generator_conv_kernel_size = [3,3,3,3,3]
        self.generator_conv_strides = [2,2,2,2,2]
        self.generator_batch_norm_momentum =  0.8
        self.generator_activation = 'leaky_relu'
        self.generator_dropout_rate = None
        self.generator_learning_rate = 1e-3

        self.optimiser = 'rmsprop'

        self.z_dim = 128

        self.n_layers_critic = len(self.critic_conv_filters)
        self.n_layers_generator = len(self.generator_conv_filters)

        self.weight_init = RandomNormal(mean=0., stddev=0.02)
        self.grad_weight = 10
        self.batch_size = 128

        # --- training history ---
        self.d_losses = []
        self.g_losses = []
        self.epoch = 0

        self._build_critic()
        self._build_generator()

        self._build_adversarial()

    def gradient_penalty_loss(self,y_true, y_pred, discriminator,gradient_penalty_weight):
        """ Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.

        Args:
            y_true: unused; kept for the Keras loss signature.
            y_pred: batch of interpolated images (4-D).
            discriminator: callable model applied to ``y_pred``.
            gradient_penalty_weight: factor multiplying the penalty.
        Returns:
            ``gradient_penalty_weight * mean((||grad|| - 1)^2)``.
        """
        # GradientTape.watch needs a tensor; accept array inputs as well.
        y_pred = tf.convert_to_tensor(y_pred)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(y_pred)
            # 1. Get the discriminator output for this interpolated image.
            pred = discriminator(y_pred, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [y_pred])[0]
        # 3. Calculate the per-sample L2 norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp*gradient_penalty_weight

    def wasserstein(self, y_true, y_pred):
        """Wasserstein loss: negative mean of label times critic score."""
        return -K.mean(y_true * y_pred)

    def get_activation(self, activation):
        """Return a LeakyReLU(0.2) layer for 'leaky_relu', else ``Activation(activation)``."""
        if activation == 'leaky_relu':
            layer = LeakyReLU(alpha = 0.2)
        else:
            layer = Activation(activation)
        return layer

    def _build_critic(self):
        """Build ``self.critic``: stacked Conv2D blocks, Flatten, Dense(1) without activation."""
        critic_input = Input(shape=self.input_dim, name='critic_input')
        x = critic_input

        for i in range(self.n_layers_critic):
            x = Conv2D(
                filters = self.critic_conv_filters[i]
                , kernel_size = self.critic_conv_kernel_size[i]
                , strides = self.critic_conv_strides[i]
                , padding = 'same'
                , name = 'critic_conv_' + str(i)
                , kernel_initializer = self.weight_init
                )(x)

            # Batch norm is optional and never applied to the first block.
            if self.critic_batch_norm_momentum and i > 0:
                x = BatchNormalization(momentum = self.critic_batch_norm_momentum)(x)
            x = self.get_activation(self.critic_activation)(x)
            if self.critic_dropout_rate:
                x = Dropout(rate = self.critic_dropout_rate)(x)
        x = Flatten()(x)

        critic_output = Dense(1, activation=None
        , kernel_initializer = self.weight_init
        )(x)

        self.critic = Model(critic_input, critic_output)

    def _build_generator(self):
        """Build ``self.generator``: Dense -> Reshape -> conv stack, ``tanh`` on the last layer."""
        generator_input = Input(shape=(self.z_dim,), name='generator_input')

        x = generator_input

        x = Dense(np.prod(self.generator_initial_dense_layer_size), kernel_initializer = self.weight_init)(x)

        if self.generator_batch_norm_momentum:
            x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)

        x = self.get_activation(self.generator_activation)(x)

        x = Reshape(self.generator_initial_dense_layer_size)(x)

        if self.generator_dropout_rate:
            x = Dropout(rate = self.generator_dropout_rate)(x)

        for i in range(self.n_layers_generator):

            if self.generator_upsample[i] == 2:
                x = UpSampling2D()(x)
                x = Conv2D(
                filters = self.generator_conv_filters[i]
                , kernel_size = self.generator_conv_kernel_size[i]
                , padding = 'same'
                , name = 'generator_conv_' + str(i)
                , kernel_initializer = self.weight_init
                )(x)
            else:

                x = Conv2DTranspose(
                    filters = self.generator_conv_filters[i]
                    , kernel_size = self.generator_conv_kernel_size[i]
                    , padding = 'same'
                    , strides = self.generator_conv_strides[i]
                    , name = 'generator_conv_' + str(i)
                    , kernel_initializer = self.weight_init
                    )(x)

            if i < self.n_layers_generator - 1:

                if self.generator_batch_norm_momentum:
                    x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)

                x = self.get_activation(self.generator_activation)(x)

            else:
                # Last layer: squash the output to [-1, 1].
                x = Activation('tanh')(x)

        generator_output = x
        self.generator = Model(generator_input, generator_output)

    def get_opti(self, lr):
        """Return a fresh optimizer chosen by ``self.optimiser`` with learning rate ``lr``."""
        # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
        if self.optimiser == 'adam':
            opti = Adam(learning_rate=lr, beta_1=0.5)
        elif self.optimiser == 'rmsprop':
            opti = RMSprop(learning_rate=lr)
        else:
            opti = Adam(learning_rate=lr)

        return opti

    def set_trainable(self, m, val):
        """Set ``trainable`` to ``val`` on model ``m`` and on each of its layers."""
        m.trainable = val
        for l in m.layers:
            l.trainable = val

    def _build_adversarial(self):
        """Compile the critic, the combined generator+critic model and the generator."""
        self.critic.compile(
            optimizer=self.get_opti(self.critic_learning_rate)
            , loss = self.wasserstein
        )

        ### COMPILE THE FULL GAN

        # The critic is frozen while the combined model is compiled so that the
        # adversarial step only updates generator weights.
        self.set_trainable(self.critic, False)

        model_input = Input(shape=(self.z_dim,), name='model_input')
        model_output = self.critic(self.generator(model_input))
        self.model = Model(model_input, model_output)

        self.model.compile(
            optimizer=self.get_opti(self.generator_learning_rate)
            , loss=self.wasserstein
            )

        # Stand-alone generator, fitted directly against target images.
        self.generator.compile(
            optimizer=self.get_opti(self.generator_learning_rate)
            , loss=get_perceptual_loss
            )

        self.set_trainable(self.critic, True)

    def _draw_latent(self, x_train, batch_size):
        """Return one latent batch: ``next()`` of an iterator, or random rows of an array."""
        if hasattr(x_train, '__next__'):
            return next(x_train)
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        return x_train[idx]

    def train_critic(self, x_train,Y_train, batch_size, using_generator):
        """Run one supervised generator step and one critic step.

        Args:
            x_train: latent codes; an iterator of batches, or (only when
                ``using_generator`` is false) an array indexed with the same
                random rows as ``Y_train``.
            Y_train: target images; an iterator of batches when
                ``using_generator`` is true, otherwise an array.
            batch_size: number of rows sampled when arrays are given.
            using_generator: whether ``Y_train`` is an iterator.
        Returns:
            ``[d_loss, d_loss_real, d_loss_fake, g_perceptual]``.
        """
        if using_generator:
            # Advance both iterators together so that each latent batch stays
            # aligned with its image batch.
            true_imgs = next(Y_train)
            noise = next(x_train)
        else:
            idx = np.random.randint(0, Y_train.shape[0], batch_size)
            true_imgs = Y_train[idx]
            if hasattr(x_train, '__next__'):
                noise = next(x_train)
            else:
                noise = x_train[idx]

        gen_imgs = self.generator.predict(noise)

        # Labels follow the real batch size, so a short final chunk is usable.
        valid = np.ones((true_imgs.shape[0],1))
        fake = -np.ones((gen_imgs.shape[0],1))

        g_perceptual = self.generator.train_on_batch(noise,true_imgs)

        d_loss_real =   self.critic.train_on_batch(true_imgs, valid)
        d_loss_fake =   self.critic.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * (d_loss_real + d_loss_fake)

        # Clip critic weights once per step; batch-normalisation layers are
        # left untouched.
        for l in self.critic.layers:
            if 'batch_normalization' in l.name:
                continue
            l.set_weights([np.clip(w, -0.01, 0.01) for w in l.get_weights()])

        return [d_loss, d_loss_real, d_loss_fake, g_perceptual]

    def train_generator(self,x_train, batch_size):
        """Run one adversarial generator step and return the loss.

        Args:
            x_train: iterator of latent batches, or an array of latent codes
                from which ``batch_size`` random rows are drawn.
            batch_size: number of rows sampled when an array is given.
        """
        noise = self._draw_latent(x_train, batch_size)
        valid = np.ones((noise.shape[0],1), dtype=np.float32)
        return self.model.train_on_batch(noise, valid)

    def load_random_batch(self,X_train,Y_train,batch_size):
        """Pick ``batch_size`` distinct random rows and return iterators over them.

        The returned iterators yield one sample (row) at a time, X and Y in
        the same order. Not called elsewhere in this class.
        """
        picked = np.random.choice(X_train.shape[0], batch_size,replace=False)
        return iter(X_train[picked]), iter(Y_train[picked])

    def shuffle_data_batch(self,array_X,array_Y,batch_size):
        """Return two generators yielding aligned, shuffled chunks of X and Y.

        One permutation of the row indices is drawn and shared by both
        generators; each yields consecutive slices of at most ``batch_size``
        rows (the last one may be shorter). Rows are gathered lazily per chunk
        instead of copying the whole shuffled arrays up front.
        """
        indices = np.arange(array_X.shape[0])
        np.random.shuffle(indices)

        def split_into_chunks(l, n):
            for i in range(0, indices.shape[0], n):
                yield l[indices[i:i + n]]

        return split_into_chunks(array_X,batch_size), split_into_chunks(array_Y,batch_size)

    def train(self, x_train,Y_train,x_val,Y_val, batch_size, epochs, run_folder, print_every_n_batches = 10, n_critic = 5,using_generator = False):
        """Alternate ``n_critic`` critic steps with one generator step, ``epochs`` times.

        Args:
            x_train, Y_train: arrays of latent codes and target images.
            x_val, Y_val: validation arrays passed to ``sample_images``.
            batch_size: rows per training batch.
            epochs: number of iterations (each is ``n_critic`` critic steps
                plus one generator step, not a full pass over the data).
            run_folder: output folder handed to ``sample_images``.
            print_every_n_batches: log and sample every this many iterations.
            n_critic: critic steps per iteration.
            using_generator: when true, batches come from a fresh shuffled
                pass over the data each iteration (which must provide at least
                ``n_critic + 1`` chunks); otherwise rows are sampled at random
                directly from the arrays.
        """
        self.batch_size = batch_size

        for epoch in range(self.epoch, self.epoch + epochs):
            if using_generator:
                x_batches, y_batches = self.shuffle_data_batch(x_train,Y_train,batch_size)
            else:
                # train_critic's array branch needs `.shape` and indexing, so
                # hand it the raw arrays rather than chunk generators.
                x_batches, y_batches = x_train, Y_train

            for _ in range(n_critic):
                d_loss = self.train_critic(x_batches,y_batches, batch_size, using_generator)

            g_loss = self.train_generator(x_batches,batch_size)

            self.d_losses.append(d_loss)
            self.g_losses.append(g_loss)

            # If at save interval => save generated image samples
            if epoch % print_every_n_batches == 0:
                print ("%d [D loss: (%.3f)(R %.3f, F %.3f) Per %.3f]  [G loss: %.3f] " % (epoch, d_loss[0], d_loss[1], d_loss[2],d_loss[3], g_loss))
                self.sample_images(x_val,Y_val,run_folder)

            self.epoch+=1

    def sample_images(self,x_val,Y_val, run_folder):
        """Save a 4x4 grid alternating generated and real images.

        Uses at most the first 100 validation rows, computes
        ``get_perceptual_loss`` on them for the figure title, and writes
        ``<run_folder>/images_latent/sample_<epoch>.png``.
        """
        r, c = 4, 4

        latent_code = x_val[:100,:]
        y_true = Y_val[:100,:,:,:]

        gen_imgs = self.generator.predict(latent_code)
        # Perceptual loss
        perceptloss = get_perceptual_loss(y_true,gen_imgs)

        # Half of the grid cells show generated images, the other half real ones.
        indx = np.random.choice(y_true.shape[0], int(0.5*c*r) ,replace=False)

        # Rescale from [-1, 1] to [0, 1] and reverse the channel order
        # (presumably BGR -> RGB for display; confirm against the data pipeline).
        face_real = 0.5*(y_true[indx]+1)
        face_real = face_real[:,:,:,[2,1,0]]

        gen_imgs = 0.5 * (gen_imgs[indx] + 1)
        gen_imgs = np.clip(gen_imgs, 0, 1)
        gen_imgs = gen_imgs[:,:,:,[2,1,0]]

        fig, axs = plt.subplots(r, c, figsize=(15,15))
        fig.suptitle("Perceptual loss : %.3f" %(perceptloss))
        cnt = 0
        for i in range(r):
            for j in range(int(0.5*c)):
                axs[i,2*j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]))
                axs[i,2*j].axis('off')
                axs[i,2*j+1].imshow(np.squeeze(face_real[cnt, :,:,:]))
                axs[i,2*j+1].axis('off')
                cnt += 1

        out_path = os.path.join(run_folder, "images_latent/sample_%d.png" % self.epoch)
        # Create the output directory on first use instead of failing in savefig.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path)
        # Close this specific figure so repeated sampling does not accumulate figures.
        plt.close(fig)
    

In [8]:
# Instantiate the GAN; the constructor builds and compiles the critic,
# the generator and the combined generator+critic model.
GAN = WGANGP()

In [9]:
GAN.critic.summary()

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
critic_input (InputLayer)    [(None, 128, 128, 3)]     0         
_________________________________________________________________
critic_conv_0 (Conv2D)       (None, 64, 64, 16)        448       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 64, 64, 16)        0         
_________________________________________________________________
critic_conv_1 (Conv2D)       (None, 32, 32, 32)        4640      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 32, 32, 32)        0         
_________________________________________________________________
critic_conv_2 (Conv2D)       (None, 16, 16, 64)        18496     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 16, 16, 64)       

In [10]:
GAN.generator.summary()

Model: "functional_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
generator_input (InputLayer) [(None, 128)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 2048)              264192    
_________________________________________________________________
batch_normalization (BatchNo (None, 2048)              8192      
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 2048)              0         
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 128)         0         
_________________________________________________________________
generator_conv_0 (Conv2DTran (None, 8, 8, 128)         147584    
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 128)        

# Load training data

In [11]:
DATA_DIR = './data_train/data_split/'


def _read_pickle(filename):
    """Unpickle and return the object stored in DATA_DIR/<filename>."""
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(os.path.join(DATA_DIR, filename), "rb") as input_file:
        return pickle.load(input_file)


# Training split: face images and the matching fingerprint latent codes.
data_train_face = _read_pickle("data_train_face_train_3channels.pkl")
data_train_latentcode = _read_pickle("latentcode_finger_train_3channels.pkl")

# Validation split.
data_val_face = _read_pickle("data_train_face_val_3channels.pkl")
data_val_latentcode = _read_pickle("latentcode_finger_val_3channels.pkl")
    

In [12]:
# Rescale pixel values with (x - 127.5) / 127.5 and store them as float32
# (presumably mapping 0..255 inputs to [-1, 1], the range of the generator's tanh output; confirm the pickled dtype).
# NOTE: the variable is overwritten in place, so re-running this cell rescales already-rescaled data.
data_train_face = ((data_train_face-127.5)/127.5).astype('float32')

In [13]:
data_train_face.shape,data_train_latentcode[2].shape

((1802, 128, 128, 3), (1802, 128))

In [14]:
# Same rescaling as for the training faces: (x - 127.5) / 127.5, stored as float32.
# NOTE: overwrites the variable in place; re-running this cell rescales the data a second time.
data_val_face = ((data_val_face-127.5)/127.5).astype('float32')

In [15]:
data_val_face.shape,data_val_latentcode[2].shape

((150, 128, 128, 3), (150, 128))

In [16]:
# Element 2 of each pickled latent-code object is the (N, 128) code array
# (see the shapes displayed above); the faces are used as loaded and rescaled.
train_latentcode = data_train_latentcode[2]
train_face = data_train_face
valid_latentcode = data_val_latentcode[2]
valid_face = data_val_face

In [17]:
# from sklearn.model_selection import train_test_split
# train_latentcode,valid_latentcode,train_face,valid_face = train_test_split(data_train_latentcode[2],
#                                                              data_train_face,
#                                                              test_size=0.2,
#                                                              random_state=13)

In [18]:
train_latentcode.shape,valid_latentcode.shape,train_face.shape,valid_face.shape

((1802, 128), (150, 128), (1802, 128, 128, 3), (150, 128, 128, 3))

# Training GAN

In [1]:
# Train for 10000 iterations; each iteration runs n_critic critic steps and one
# generator step. Every 50 iterations the losses are printed and a sample grid
# is written under <run_folder>/images_latent/.
GAN.train(  
      train_latentcode
    , train_face
    , valid_latentcode
    , valid_face
    , batch_size = 256
    , epochs = 10000
    , run_folder = './model_save'
    , print_every_n_batches = 50
    , n_critic = 5
    , using_generator = True
)

In [29]:
# GAN.critic.save('./model_save/model/discriminator_3channel_068.h5')
# GAN.generator.save('./model_save/model/generator_3channel_068.h5')

# Validation

In [11]:
# Reload a saved generator; the custom loss must be supplied so Keras can
# deserialise the compiled model.
# NOTE(review): `gen` is not referenced by the later cells visible here, which
# predict with GAN.generator instead; confirm which model is meant to be validated.
gen = tf.keras.models.load_model('./model_save/model/generator_3channel_068.h5',custom_objects={'get_perceptual_loss': get_perceptual_loss})

In [56]:
# Load an extra set of fingerprint latent codes for qualitative testing.
# NOTE: pickle.load can execute arbitrary code; only use trusted files.
DATA_DIR = './data_train/data_split/'
with open(os.path.join(DATA_DIR,"latentcode_finger_longthieu.pkl"), "rb") as input_file:
    latent_longthieu = pickle.load(input_file)

In [57]:
latent_longthieu[2].shape

(4, 128)

In [2]:
# Test
r, c = 2, 2

latent_code = latent_longthieu[2]

gen_imgs = GAN.generator.predict(latent_code)


gen_imgs = 0.5 * (gen_imgs + 1)
gen_imgs = np.clip(gen_imgs, 0, 1)
gen_imgs = gen_imgs[:,:,:,[2,1,0]]


fig, axs = plt.subplots(r, c, figsize=(15,15))
cnt = 0
for i in range(r):
    for j in range(int(c)):
        axs[i,j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]))
        #axs[i,2*j].axis('off')
        cnt += 1
# plt.imshow(gen_imgs[1])
# plt.show()
# #fig.savefig(os.path.join('./model_save', "images/face_gens.png" ))
# plt.close()

In [3]:
# # Test
# r, c = 4, 4

# latent_code = valid_latentcode[:100,:]
# y_true = valid_face[:100,:,:,:]

# gen_imgs = GAN.generator.predict(latent_code)
# # Perceptual loss
# perceptloss = get_perceptual_loss(y_true,gen_imgs)

# indx = np.random.choice(y_true.shape[0], int(0.5*c*r) ,replace=False)

# face_real = 0.5*(y_true[indx]+1)
# face_real = face_real[:,:,:,[2,1,0]]

# gen_imgs = 0.5 * (gen_imgs[indx] + 1)
# gen_imgs = np.clip(gen_imgs, 0, 1)
# gen_imgs = gen_imgs[:,:,:,[2,1,0]]


# fig, axs = plt.subplots(r, c, figsize=(15,15))
# fig.suptitle("Perceptual loss : %.3f" %(perceptloss))
# cnt = 0
# for i in range(r):
#     for j in range(int(0.5*c)):
#         axs[i,2*j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]))
#         #axs[i,2*j].axis('off')
#         axs[i,2*j+1].imshow(np.squeeze(face_real[cnt, :,:,:]))
#         #axs[i,2*j+1].axis('off')
#         cnt += 1

# plt.show()
# #fig.savefig(os.path.join('./model_save', "images/face_gens.png" ))
# plt.close()

In [29]:
# Save a generated-vs-real comparison grid using the first 100 training samples
# (written under ./model_save/images_latent/).
GAN.sample_images(train_latentcode,train_face,"./model_save")

### 