In [3]:
import sys
import numpy as np
import json
import os
import pickle
import matplotlib.pyplot as plt

In [2]:
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.layers import Input, Concatenate, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from skimage.transform import resize
from tensorflow.keras.optimizers import Adam, RMSprop,SGD
from tensorflow.keras.initializers import RandomNormal
from functools import partial

In [3]:
from tensorflow.keras.applications.vgg19 import VGG19

In [4]:
# Restrict this process to CUDA device index 1. The variable only takes
# effect if it is set before the CUDA runtime is initialised, i.e. before
# TensorFlow runs its first GPU op.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

# Build the perceptual-loss feature extractor (VGG19)

In [5]:
from tensorflow.keras.applications.vgg19 import VGG19
class Cut_VGG19:
    """
    Wrap keras' ImageNet-trained VGG19 as a multi-output feature extractor.

    The convolutional base (``include_top=False``) is loaded, frozen, and the
    layers at the indices in ``layers_to_extract`` are declared as outputs.
    Used as the feature extractor of the perceptual loss.

    Args:
        patch_size: integer, side length of the square RGB input; the input
            shape is ``(patch_size, patch_size, 3)``.
        layers_to_extract: list of integer indices into ``VGG19(...).layers``
            whose outputs are exposed, in that order.

    Attributes:
        model: multi-output keras ``Model`` returning the activations of
            ``layers_to_extract``, or ``None`` when ``layers_to_extract`` is empty.
    """

    def __init__(self, patch_size, layers_to_extract):
        self.patch_size = patch_size
        self.input_shape = (patch_size,) * 2 + (3,)
        self.layers_to_extract = layers_to_extract
        # Always define the attribute so callers get ``None`` rather than an
        # AttributeError when no layers were requested.
        self.model = None

        if len(self.layers_to_extract) > 0:
            self._cut_vgg()

    def _cut_vgg(self):
        """
        Load pre-trained VGG19, freeze it and declare as outputs the
        intermediate layers selected by ``self.layers_to_extract``.
        """
        vgg = VGG19(weights='imagenet', include_top=False, input_shape=self.input_shape)
        vgg.trainable = False
        outputs = [vgg.layers[i].output for i in self.layers_to_extract]
        self.model = Model([vgg.input], outputs)

        
# Expose VGG19 layers 5 and 9 for 128x128 RGB inputs (presumably
# block2_conv2 / block3_conv3 in keras' layer order -- verify via
# VGG19(...).layers[i].name) and freeze the wrapper model as well.
feature_extraction = Cut_VGG19(patch_size=128, layers_to_extract=[5, 9])
feature_extraction.model.trainable = False


In [6]:
class RandomWeightedAverage(tf.keras.layers.Layer):
    """Provides a (random) weighted average between real and generated image samples.

    For inputs ``[a, b]`` returns ``alpha * a + (1 - alpha) * b`` with one
    ``alpha ~ U(0, 1)`` drawn per sample and broadcast over the remaining axes.

    Args:
        batch_size: nominal batch size, kept for backward compatibility. The
            number of ``alpha`` values actually drawn follows the runtime batch
            dimension of ``inputs[0]``, so a final partial batch also works.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def _merge_function(self, inputs):
        # assumes 4-D image tensors (batch, height, width, channels): the
        # (batch, 1, 1, 1) alpha broadcasts over the last three axes.
        alpha = K.random_uniform((tf.shape(inputs[0])[0], 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

    def call(self, inputs):
        # A plain Layer dispatches to ``call``; ``_merge_function`` is the hook
        # name of keras' private ``_Merge`` base class and is never invoked
        # automatically here, so delegate explicitly.
        return self._merge_function(inputs)

In [1]:
class WGANGP():
    """Conditional Wasserstein GAN: (latent code, 1-channel condition image) -> RGB image.

    Despite the name, the critic is constrained by weight clipping in
    ``train_discriminator`` rather than by a gradient penalty; ``grad_weight``
    is stored but not used in this class. The generator is trained on the
    Wasserstein loss plus ``get_perceptual_loss`` against the target image.

    Constructing the object builds, prints summaries of, and compiles the
    generator (``self.generator``), the critic (``self.discriminator``) and the
    combined model (``self.model``).
    """
    def __init__(self):
        self.input_dim = (128,128,3)
        self.optimiser = 'rmsprop'
        self.z_dim = 128
        ########################################
        # Condition branch: (128, 128, 1) -> five stride-2 convs -> (4, 4, 64).
        self.conditional_input = (128,128,1)
        self.conditional_conv_filters = [2,4,16,32,64]
        self.conditional_conv_kernel_size = [3,3,3,3,3]
        self.conditional_conv_strides = [2,2,2,2,2]
        #########################################
        # Generator: Dense -> (4, 4, 128), concatenated with the condition
        # features, then five stride-2 layers up to (128, 128, 3).
        self.generator_initial_dense_layer_size = (4,4, 128)
        self.generator_upsample = [1,1,1,1,1]   # 2 -> UpSampling2D + Conv2D, otherwise Conv2DTranspose
        self.generator_conv_filters = [128,64,32,16,3]
        self.generator_conv_kernel_size = [3,3,3,3,3]
        self.generator_conv_strides = [2,2,2,2,2]
        self.generator_batch_norm_momentum =  0.8
        self.generator_activation = 'leaky_relu'
        self.generator_dropout_rate = None
        self.generator_learning_rate = 1e-3
        ###########################################
        self.discriminator_conv_filters = [16,32,64,128,256]
        self.discriminator_conv_kernel_size = [3,3,3,3,3]
        self.discriminator_conv_strides = [2,2,2,2,2]
        self.discriminator_batch_norm_momentum = None
        self.discriminator_activation = 'leaky_relu'
        self.discriminator_dropout_rate = None
        self.discriminator_learning_rate = 1e-3
        ###########################################
        self.weight_init = RandomNormal(mean=0., stddev=0.02)
        self.grad_weight = 10   # not used: the critic uses weight clipping
        self.batch_size = 128

        self.n_layers_discriminator = len(self.discriminator_conv_filters)
        self.n_layers_generator = len(self.generator_conv_filters)
        self.n_layers_conditional = len(self.conditional_conv_filters)
        ###############################################
        self.d_losses = []
        self.g_losses = []
        self.epoch = 0
        self._build_generator()
        self.generator.summary()
        self._build_discriminator()
        self.discriminator.summary()
        print("#############################################")
        self._build_adversarial()
        self.model.summary()

    ####################### Loss ###########################
    def wasserstein(self, y_true, y_pred):
        """Wasserstein loss ``-mean(y_true * y_pred)``; labels are +1 (real) / -1 (fake)."""
        return -K.mean(y_true * y_pred)

    def get_perceptual_loss(self,y_true,y_pred):
        """VGG-feature (perceptual) loss plus 150 x pixel-wise MSE.

        Relies on the module-level ``feature_extraction`` (a ``Cut_VGG19``),
        which must exist before the loss is evaluated. Per-layer feature MSEs
        are weighted 1/16 and 1/8, so at most two extracted layers are supported.
        """
        # NOTE(review): images are in [-1, 1] and reach VGG19 without keras'
        # vgg19.preprocess_input -- confirm this is intentional.
        content_feature = feature_extraction.model(y_true)
        new_feature = feature_extraction.model(y_pred)
        perceptual_loss = 0
        weight = tf.constant([1/16,1/8], dtype = tf.float32)
        for i in range(len(new_feature)):
            perceptual_loss += weight[i]*K.mean(K.square(new_feature[i] - content_feature[i]))
        l2_loss = tf.reduce_mean(tf.keras.losses.mean_squared_error(y_true,y_pred))
        total_loss = perceptual_loss + 150*l2_loss
        return total_loss

    ################# Activation layer #####################
    def get_activation(self, activation):
        """Return a LeakyReLU(0.2) layer for 'leaky_relu', else a keras Activation layer."""
        if activation == 'leaky_relu':
            layer = LeakyReLU(alpha = 0.2)
        else:
            layer = Activation(activation)
        return layer

    ####################################################################
    #################### Build Discriminator Model #####################
    ####################################################################
    def _build_discriminator(self):
        """Build ``self.discriminator``: ``[image, condition]`` -> unbounded scalar score."""
        discriminator_input = Input(shape=self.input_dim, name='discriminator_input')
        # The 1-channel condition image is stacked onto the RGB image channels.
        target = Input(shape= self.conditional_input, name='target_image')
        x = Concatenate()([discriminator_input,target])

        for i in range(self.n_layers_discriminator):
            x = Conv2D(
                filters = self.discriminator_conv_filters[i]
                , kernel_size = self.discriminator_conv_kernel_size[i]
                , strides = self.discriminator_conv_strides[i]
                , padding = 'same'
                , name = 'discriminator_conv_' + str(i)
                , kernel_initializer = self.weight_init
                )(x)

            if self.discriminator_batch_norm_momentum and i > 0:
                x = BatchNormalization(momentum = self.discriminator_batch_norm_momentum)(x)
            x = self.get_activation(self.discriminator_activation)(x)
            if self.discriminator_dropout_rate:
                x = Dropout(rate = self.discriminator_dropout_rate)(x)
        x = Flatten()(x)

        # Linear output: a Wasserstein critic score, not a probability.
        discriminator_output = Dense(1, activation=None
        , kernel_initializer = self.weight_init
        )(x)

        self.discriminator = Model([discriminator_input, target], discriminator_output)

    ####################################################################
    #################### Build Generator Model #########################
    ####################################################################

    def _build_generator(self):
        """Build ``self.generator``: ``[noise (z_dim,), condition]`` -> tanh image."""
        ############  generator ###############
        generator_input = Input(shape=(self.z_dim,), name='generator_input')
        x = generator_input
        x = Dense(np.prod(self.generator_initial_dense_layer_size), kernel_initializer = self.weight_init)(x)

        if self.generator_batch_norm_momentum:
            x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)
        x = self.get_activation(self.generator_activation)(x)
        x = Reshape(self.generator_initial_dense_layer_size)(x)
        ########### Conditional #############
        conditional_input = Input(shape=self.conditional_input, name='conditional_input')
        x_label = conditional_input
        for i in range(self.n_layers_conditional):
            x_label = Conv2D(
                filters = self.conditional_conv_filters[i]
                # Use the configured kernel size (previously the filter count
                # was passed here by mistake, giving 2..64-wide kernels).
                , kernel_size = self.conditional_conv_kernel_size[i]
                , strides = self.conditional_conv_strides[i]
                , padding = 'same'
                , name = 'conditional_conv_' + str(i)
                , kernel_initializer = self.weight_init
                )(x_label)
            # Same guard as the other generator BN layers: momentum None disables BN.
            if self.generator_batch_norm_momentum:
                x_label = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x_label)
            x_label = self.get_activation(self.generator_activation)(x_label)
        ########### Concatenate ############
        x = Concatenate()([x, x_label])

        if self.generator_dropout_rate:
            x = Dropout(rate = self.generator_dropout_rate)(x)

        for i in range(self.n_layers_generator):

            if self.generator_upsample[i] == 2:
                x = UpSampling2D()(x)
                x = Conv2D(
                filters = self.generator_conv_filters[i]
                , kernel_size = self.generator_conv_kernel_size[i]
                , padding = 'same'
                , name = 'generator_conv_' + str(i)
                , kernel_initializer = self.weight_init
                )(x)
            else:
                x = Conv2DTranspose(
                    filters = self.generator_conv_filters[i]
                    , kernel_size = self.generator_conv_kernel_size[i]
                    , padding = 'same'
                    , strides = self.generator_conv_strides[i]
                    , name = 'generator_conv_' + str(i)
                    , kernel_initializer = self.weight_init
                    )(x)

            if i < self.n_layers_generator - 1:
                if self.generator_batch_norm_momentum:
                    x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)
                x = self.get_activation(self.generator_activation)(x)
            else:
                # Final layer: tanh keeps outputs in [-1, 1], the range of the faces.
                x = Activation('tanh')(x)

        generator_output = x
        self.generator = Model([generator_input,conditional_input], generator_output)

    def get_opti(self, lr):
        """Return a new optimiser for ``self.optimiser``: 'adam' (beta_1=0.5), 'rmsprop', else default Adam."""
        # ``learning_rate`` replaces the deprecated ``lr`` alias, which newer
        # Keras releases no longer accept.
        if self.optimiser == 'adam':
            return Adam(learning_rate=lr, beta_1=0.5)
        if self.optimiser == 'rmsprop':
            return RMSprop(learning_rate=lr)
        return Adam(learning_rate=lr)

    def set_trainable(self, m, val):
        """Set ``trainable`` on model ``m`` and on each of its layers."""
        m.trainable = val
        for l in m.layers:
            l.trainable = val

    def _build_adversarial(self):
        """Compile the critic, then build/compile the combined model with the critic frozen.

        Relies on Keras capturing the ``trainable`` state at ``compile()`` time:
        the combined model only updates the generator, while the critic
        (re-enabled afterwards) keeps training through its own compile.
        """
        self.discriminator.compile(
            optimizer=self.get_opti(self.discriminator_learning_rate)
            , loss = self.wasserstein
        )

        ### COMPILE THE FULL GAN
        self.set_trainable(self.discriminator, False)

        gen_noise_input, gen_label_input = self.generator.input
        gen_fake_image_output = self.generator.output

        disc_output = self.discriminator([gen_fake_image_output, gen_label_input])

        # Two outputs: critic score (Wasserstein loss) and the generated image
        # itself (perceptual + L2 loss against the target image).
        self.model = Model([gen_noise_input, gen_label_input], [disc_output,gen_fake_image_output])

        self.model.compile(
            optimizer = self.get_opti(self.generator_learning_rate)
            , loss=[self.wasserstein,self.get_perceptual_loss]
            )

        self.set_trainable(self.discriminator, True)

    def _clip_discriminator_weights(self, clip_value=0.01):
        """Clip every critic weight to [-clip_value, clip_value], skipping BatchNormalization layers."""
        for layer in self.discriminator.layers:
            if 'batch_normalization' in layer.name:
                # Clipping BN gamma/beta/moving statistics would break the
                # normalisation, so those layers are left untouched.
                continue
            layer.set_weights([np.clip(w, -clip_value, clip_value) for w in layer.get_weights()])

    def train_discriminator(self, x_train,Y_train,label_train, batch_size, using_generator):
        """Run one critic update on a real batch and a generated batch, then clip weights.

        Args:
            x_train, Y_train, label_train: with ``using_generator=True``,
                aligned iterators yielding latent-code / image / condition
                chunks (see ``shuffle_data_batch``); one chunk is consumed from
                each. Otherwise arrays indexed with a shared random index.
            batch_size: samples drawn in the array path. In the iterator path
                the actual chunk size is used, so a shorter last chunk works.
            using_generator: selects the input mode above.

        Returns:
            ``[mean loss, real-batch loss, fake-batch loss]``.
        """
        if using_generator:
            # Pull one chunk from each iterator so images, noise and
            # conditions stay aligned.
            true_imgs = next(Y_train)
            noise = next(x_train)
            label = next(label_train)
        else:
            idx = np.random.randint(0, Y_train.shape[0], batch_size)
            true_imgs = Y_train[idx]
            noise = x_train[idx]
            label = label_train[idx]

        n_samples = true_imgs.shape[0]
        valid = np.ones((n_samples,1))
        fake = -np.ones((n_samples,1))

        gen_imgs = self.generator.predict([noise,label])

        d_loss_real =   self.discriminator.train_on_batch([true_imgs,label], valid)
        d_loss_fake =   self.discriminator.train_on_batch([gen_imgs,label], fake)
        d_loss = 0.5 * (d_loss_real + d_loss_fake)

        # WGAN weight clipping keeps the critic (approximately) Lipschitz.
        self._clip_discriminator_weights()

        return [d_loss, d_loss_real, d_loss_fake]

    def train_generator(self,x_train,Y_train,label_train, batch_size):
        """Run one combined-model update on the next aligned chunk of each iterator.

        ``batch_size`` is kept for interface compatibility; the label size
        follows the actual chunk size.

        Returns:
            The combined model's ``train_on_batch`` result
            (total, Wasserstein and perceptual losses).
        """
        noise = next(x_train)
        label = next(label_train)
        true_images = next(Y_train)
        valid = np.ones((true_images.shape[0],1), dtype=np.float32)
        return self.model.train_on_batch([noise,label], [valid,true_images])

    def load_random_batch(self,X_train,Y_train,batch_size):
        """Return iterators over ``batch_size`` distinct random paired rows of X_train / Y_train.

        The iterators yield one sample at a time, not a whole batch.
        """
        random_samples_indices = np.random.choice(X_train.shape[0], batch_size, replace=False)
        X = np.asarray(X_train)[random_samples_indices]
        Y = np.asarray(Y_train)[random_samples_indices]
        return iter(X), iter(Y)

    def shuffle_data_batch(self,array_X,array_Y,array_label,batch_size):
        """Shuffle three arrays with one shared permutation and split each into chunks.

        Returns three generators yielding consecutive ``batch_size`` chunks
        (the last may be shorter); the i-th chunks of all three are aligned.
        """
        indices = np.random.permutation(array_X.shape[0])

        def split_into_chunks(l, n):
            for i in range(0, l.shape[0], n):
                yield l[i:i + n]

        return (split_into_chunks(array_X[indices], batch_size),
                split_into_chunks(array_Y[indices], batch_size),
                split_into_chunks(array_label[indices], batch_size))

    def train(self, x_train,Y_train,x_val,Y_val,label_train,label_val, batch_size, epochs, run_folder, print_every_n_batches = 10, n_critic = 5,using_generator = False):
        """Alternate ``n_critic`` critic updates with one generator update, ``epochs`` times.

        One "epoch" here is a single critic/generator round on a freshly
        shuffled copy of the data, not a full pass over it. With
        ``using_generator=True`` every update consumes the next chunk, so a
        round needs ``n_critic + 1`` chunks (StopIteration otherwise). With
        ``using_generator=False`` the critic samples random batches from the
        raw arrays and the generator step uses the first shuffled chunk.
        Every ``print_every_n_batches`` rounds the losses are printed and a
        sample grid is written under ``run_folder``.
        """
        self.batch_size = batch_size
        d_loss = [float('nan')] * 3   # defined even when n_critic == 0

        for epoch in range(self.epoch, self.epoch + epochs):
            x_train_gan,Y_train_gan,label_train_gan = self.shuffle_data_batch(x_train,Y_train,label_train,batch_size)
            if using_generator:
                critic_data = (x_train_gan, Y_train_gan, label_train_gan)
            else:
                critic_data = (x_train, Y_train, label_train)
            for _ in range(n_critic):
                d_loss = self.train_discriminator(*critic_data, batch_size, using_generator)

            g_loss = self.train_generator(x_train_gan,Y_train_gan,label_train_gan, batch_size)

            self.d_losses.append(d_loss)
            self.g_losses.append(g_loss)

            # If at save interval => save generated image samples
            if epoch % print_every_n_batches == 0:
                # g_loss = [total, Wasserstein, perceptual] from the two-output model.
                print ("%d [D loss: (%.3f)(R %.3f, F %.3f)]  [G total: %.3f, W loss: %.3f, P loss: %.3f] " % (epoch, d_loss[0], d_loss[1], d_loss[2], g_loss[0], g_loss[1], g_loss[2]))
                self.sample_images(x_val,Y_val,label_val,run_folder)
                #self.model.save_weights(os.path.join(run_folder, 'weights/weights-%d.h5' % (epoch)))
                #self.model.save_weights(os.path.join(run_folder, 'weights/weights.h5'))
                #self.save_model(run_folder)

            self.epoch+=1

    def sample_images(self,x_val,Y_val,label_val, run_folder):
        """Save a 4x4 grid of (generated, real) pairs drawn from the first 100 validation samples.

        Images are mapped from [-1, 1] to [0, 1] and their channel order is
        reversed ([2, 1, 0]; presumably stored as BGR) before display. The grid
        is written to ``<run_folder>/images_latent/sample_<epoch>.png``.
        """
        r, c = 4, 4

        latent_code = x_val[:100,:]
        y_true = Y_val[:100,:,:,:]
        label_true = label_val[:100,:,:,:]
        gen_imgs = self.generator.predict([latent_code,label_true])

        # r * c / 2 distinct samples; each fills a (generated, real) pair of panels.
        indx = np.random.choice(y_true.shape[0], int(0.5*c*r) ,replace=False)

        face_real = 0.5*(y_true[indx]+1)
        face_real = face_real[:,:,:,[2,1,0]]

        gen_imgs = 0.5 * (gen_imgs[indx] + 1)
        gen_imgs = np.clip(gen_imgs, 0, 1)
        gen_imgs = gen_imgs[:,:,:,[2,1,0]]

        fig, axs = plt.subplots(r, c, figsize=(15,15))
        cnt = 0
        for i in range(r):
            for j in range(int(0.5*c)):
                axs[i,2*j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]))
                axs[i,2*j].axis('off')
                axs[i,2*j+1].imshow(np.squeeze(face_real[cnt, :,:,:]))
                axs[i,2*j+1].axis('off')
                cnt += 1
        out_dir = os.path.join(run_folder, "images_latent")
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, "sample_%d.png" % self.epoch))
        plt.close(fig)
    

In [1]:
# Instantiate the model: this builds, prints summaries of, and compiles the
# generator, the critic and the combined model.
GAN = WGANGP()

# Load training and validation data

In [9]:
DATA_DIR = './data_train/data_split_v2/'


def _load_split_pickle(data_dir, file_name):
    """Unpickle ``file_name`` from ``data_dir``.

    Only use on trusted, locally produced files: ``pickle.load`` can execute
    arbitrary code.
    """
    with open(os.path.join(data_dir, file_name), "rb") as input_file:
        return pickle.load(input_file)


# Training split: faces, latent codes and 'finger' condition images.
data_train_face = _load_split_pickle(DATA_DIR, "data_train_face_train_3channels.pkl")
data_train_latentcode = _load_split_pickle(DATA_DIR, "latentcode_finger_train_3channels.pkl")
data_train_finger = _load_split_pickle(DATA_DIR, "data_train_finger_train_3channels.pkl")

# Validation split: the same three kinds of data.
data_val_face = _load_split_pickle(DATA_DIR, "data_train_face_val_3channels.pkl")
data_val_latentcode = _load_split_pickle(DATA_DIR, "latentcode_finger_val_3channels.pkl")
data_val_finger = _load_split_pickle(DATA_DIR, "data_train_finger_val_3channels.pkl")
    

In [10]:
from keras.preprocessing.image import img_to_array, array_to_img
def resize_data(data, size=(128, 128)):
    """Resize every image in ``data`` and stack the results into one array.

    Each image is converted with keras' ``array_to_img`` (whose default
    ``scale=True`` rescales it to 0-255 using its own min/max), resized with
    PIL's ``Image.resize`` and converted back with ``img_to_array``.

    Args:
        data: iterable of image arrays accepted by ``array_to_img``.
        size: target ``(width, height)`` passed to PIL; defaults to the
            128x128 resolution the models expect.

    Returns:
        ``np.ndarray`` stacking the resized images along a new first axis.
    """
    resized_images = [img_to_array(array_to_img(img).resize(size=size)) for img in data]
    return np.asarray(resized_images)

In [11]:
# Bring the condition images to the 128x128 input size used by the models.
data_train_finger, data_val_finger = (
    resize_data(data_train_finger),
    resize_data(data_val_finger),
)

In [12]:
# Sanity check: (max, median, min) of the raw condition images and faces.
for stats_array in (data_train_finger, data_train_face):
    print(np.max(stats_array), np.median(stats_array), np.min(stats_array))

255.0 16.0 0.0
255.0 148.0 0.0


In [13]:
# Scale each condition set into [0, 1] by its own global maximum.
train_finger_max = np.max(data_train_finger)
val_finger_max = np.max(data_val_finger)
data_train_finger = data_train_finger / train_finger_max
data_val_finger = data_val_finger / val_finger_max

In [14]:
# Binarise: pixels above 0.5 become +1, everything else -1 (float32).
data_train_finger, data_val_finger = (
    np.where(finger > .5, 1.0, -1.0).astype('float32')
    for finger in (data_train_finger, data_val_finger)
)

In [15]:
# The generator's condition input is (128, 128, 1) with values in {-1, +1};
# fail loudly here rather than deep inside Keras.
assert data_train_finger.shape[1:] == data_val_finger.shape[1:] == (128, 128, 1)
assert set(np.unique(data_train_finger).tolist()) <= {-1.0, 1.0}
data_train_finger.shape, data_val_finger.shape

((1621, 128, 128, 1), (150, 128, 128, 1))

In [16]:
# Map face pixels from [0, 255] to [-1, 1], the range of the generator's tanh.
# NOTE: not idempotent -- re-running this cell rescales the data a second time.
data_train_face = np.divide(data_train_face - 127.5, 127.5).astype(np.float32)

In [17]:
# Faces and the latent codes (element 2 of the pickled object) used for training.
(data_train_face.shape, data_train_latentcode[2].shape)

((1621, 128, 128, 3), (1621, 128))

In [18]:
# Same [0, 255] -> [-1, 1] mapping for the validation faces.
# NOTE: not idempotent -- re-running this cell rescales the data a second time.
data_val_face = np.divide(data_val_face - 127.5, 127.5).astype(np.float32)

In [19]:
# Validation faces and their latent codes (element 2 of the pickled object).
(data_val_face.shape, data_val_latentcode[2].shape)

((150, 128, 128, 3), (150, 128))

In [20]:
# Re-check (max, median, min) after normalisation: both should lie in [-1, 1].
for stats_array in (data_train_finger, data_train_face):
    print(np.max(stats_array), np.median(stats_array), np.min(stats_array))

1.0 -1.0 -1.0
1.0 0.16078432 -1.0


In [21]:
# Short names used by the training and evaluation cells below.
train_latentcode, valid_latentcode = data_train_latentcode[2], data_val_latentcode[2]
train_face, valid_face = data_train_face, data_val_face
label_train_finger, label_val_finger = data_train_finger, data_val_finger

In [22]:
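# Alternative (unused): random 80/20 split of the training data. Not needed
# because data_split_v2 already provides a separate validation split.
# from sklearn.model_selection import train_test_split
# train_latentcode,valid_latentcode,train_face,valid_face = train_test_split(data_train_latentcode[2],
#                                                              data_train_face,
#                                                              test_size=0.2,
#                                                              random_state=13)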

In [23]:
# Shapes of the latent codes and faces in the train / validation splits.
(train_latentcode.shape, valid_latentcode.shape, train_face.shape, valid_face.shape)

((1621, 128), (150, 128), (1621, 128, 128, 3), (150, 128, 128, 3))

In [2]:
# Show one (condition, face) training pair. Faces are stored in [-1, 1] with
# reversed channel order (the sampling code flips [2, 1, 0] before display),
# so map them back to [0, 1] RGB; imshow would otherwise clip the negatives.
sample_idx = 10
fig, (ax_finger, ax_face) = plt.subplots(1, 2)
ax_finger.imshow(np.squeeze(data_train_finger[sample_idx]), cmap='gray')
ax_finger.set_title('condition')
ax_finger.axis('off')
ax_face.imshow(0.5 * (data_train_face[sample_idx][:, :, [2, 1, 0]] + 1))
ax_face.set_title('face')
ax_face.axis('off')
plt.show()

# Train the GAN

In [3]:
# 10000 critic/generator rounds; a sample grid is saved every 50 rounds.
GAN.train(
    x_train=train_latentcode,
    Y_train=train_face,
    x_val=valid_latentcode,
    Y_val=valid_face,
    label_train=label_train_finger,
    label_val=label_val_finger,
    batch_size=128,
    epochs=10000,
    run_folder='./model_save',
    print_every_n_batches=50,
    n_critic=5,
    using_generator=True,
)

In [88]:
# Persist both networks as HDF5. Create the folder first so the save does not
# fail on a fresh checkout.
MODEL_DIR = './model_save/model'
os.makedirs(MODEL_DIR, exist_ok=True)
GAN.discriminator.save(os.path.join(MODEL_DIR, 'discriminator_3channel_new.h5'))
GAN.generator.save(os.path.join(MODEL_DIR, 'generator_3channel_new.h5'))

# Validation

In [19]:
# The perceptual loss is a WGANGP method, not a module-level function, so the
# custom object must come from the GAN instance (the bare name is undefined).
# NOTE(review): `gen` is not used by the cells below (they call GAN.generator),
# and this file name differs from the one saved above -- confirm the intent.
gen = tf.keras.models.load_model(
    './model_save/model/generator_3channel_conditional.h5',
    custom_objects={'get_perceptual_loss': GAN.get_perceptual_loss},
)

In [4]:
def test(validate=True):
    """Display a 4x4 grid of (generated, real) face pairs from GAN.generator.

    Args:
        validate: if truthy, draw from the first 150 validation examples;
            otherwise from training examples 900-999.
    """
    n_rows, n_cols = 4, 4

    if not validate:
        latent_code = train_latentcode[900:1000, :]
        y_true = train_face[900:1000, :, :, :]
        label_true = label_train_finger[900:1000, :, :, :]
    else:
        latent_code = valid_latentcode[:150, :]
        y_true = valid_face[:150, :, :, :]
        label_true = label_val_finger[:150, :, :, :]

    generated = GAN.generator.predict([latent_code, label_true])

    # n_rows * n_cols / 2 distinct samples; each fills two adjacent panels.
    picks = np.random.choice(y_true.shape[0], int(0.5 * n_cols * n_rows), replace=False)

    # Map [-1, 1] -> [0, 1] and reverse the channel order for display.
    face_real = (0.5 * (y_true[picks] + 1))[:, :, :, [2, 1, 0]]
    face_fake = np.clip(0.5 * (generated[picks] + 1), 0, 1)[:, :, :, [2, 1, 0]]

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 15))
    pairs_per_row = int(0.5 * n_cols)
    for sample_no in range(n_rows * pairs_per_row):
        row, pair = divmod(sample_no, pairs_per_row)
        fake_ax = axs[row, 2 * pair]
        real_ax = axs[row, 2 * pair + 1]
        fake_ax.imshow(np.squeeze(face_fake[sample_no, :, :, :]))
        fake_ax.axis('off')
        real_ax.imshow(np.squeeze(face_real[sample_no, :, :, :]))
        real_ax.axis('off')

    plt.show()
    plt.close()

In [5]:
# Visual check on training examples 900-999 (validate=False).
test(validate=False)

In [29]:
#GAN.sample_images(train_latentcode,train_face,"./model_save")

In [38]:
DATA_DIR = './data_train/data_split/'

# Augmented validation set: latent codes, faces and condition images.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
_augment_arrays = []
for _file_name in ("latentcode_val_augment.pkl", "face_val_augment.pkl", "finger_val_augment.pkl"):
    with open(os.path.join(DATA_DIR, _file_name), "rb") as input_file:
        _augment_arrays.append(pickle.load(input_file))
latentcode_augment, face_augment, finger_augment = _augment_arrays

In [39]:
# Shapes of the augmented latent codes, faces and condition images.
(latentcode_augment[2].shape, face_augment.shape, finger_augment.shape)

((300, 128), (300, 128, 128, 3), (300, 128, 128, 1))

In [40]:
# Same scaling and binarisation as the training condition images: divide by
# the global maximum, then map > 0.5 to +1 and the rest to -1 (float32).
finger_augment = np.where(finger_augment / np.max(finger_augment) > .5, 1.0, -1.0).astype('float32')

In [6]:
r, c = 8, 4

latentcode_datatrain = train_latentcode[200:500,:]
y_true_datatrain = train_face[200:500,:,:,:]
label_true_datatrain = label_train_finger[200:500,:,:,:]
print(latentcode_datatrain.shape)

# Jitter the augmented latent codes with Gaussian noise scaled by 0.2; the
# noise shape follows the latent-code array instead of a hard-coded (300, 128).
augment_latent = latentcode_augment[2]
noise_gauss = tf.keras.backend.random_normal(shape=augment_latent.shape)
gen_imgs_augment = GAN.generator.predict([augment_latent+noise_gauss*0.2,finger_augment])

# Perceptual loss is a WGANGP method; there is no module-level function.
# NOTE(review): this compares faces generated from the *augmented* set with
# training faces 200-499 -- confirm the two sets are meant to correspond.
perceptloss_augment = GAN.get_perceptual_loss(y_true_datatrain,gen_imgs_augment)

indx = np.random.choice(y_true_datatrain.shape[0], int(0.5*c*r) ,replace=False)

# [-1, 1] -> [0, 1] and reversed channel order for display.
face_real = 0.5*(y_true_datatrain[indx]+1)
face_real = face_real[:,:,:,[2,1,0]]

gen_imgs = 0.5 * (gen_imgs_augment[indx] + 1)
gen_imgs = np.clip(gen_imgs, 0, 1)
gen_imgs = gen_imgs[:,:,:,[2,1,0]]

# Each group of four panels: generated face, training face, augmented
# condition image, training condition image.
fig, axs = plt.subplots(r, c, figsize=(15,30))
cnt = 0
for i in range(r):
    for j in range(int(0.25*c)):
        axs[i,4*j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]))
        axs[i,4*j].axis('off')
        axs[i,4*j+1].imshow(np.squeeze(face_real[cnt, :,:,:]))
        axs[i,4*j+1].axis('off')
        axs[i,4*j+2].imshow(np.squeeze(finger_augment[indx][cnt, :,:,:]),cmap='gray')
        axs[i,4*j+2].axis('off')
        axs[i,4*j+3].imshow(np.squeeze(label_true_datatrain[indx][cnt, :,:,:]),cmap='gray')
        axs[i,4*j+3].axis('off')
        cnt += 1   # one sample per group of four panels

# Save before showing, and make sure the output folder exists.
out_dir = os.path.join('./model_save', 'images')
os.makedirs(out_dir, exist_ok=True)
fig.savefig(os.path.join(out_dir, "face_gens_1.png"))
plt.show()
plt.close(fig)

### 