In [None]:
import os
import signal

import torch
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow.keras import Input, Model
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Resizing, GlobalAveragePooling2D, concatenate, Lambda, Conv2DTranspose, Conv2D, BatchNormalization
import tensorflow.keras.backend as K

In [None]:
# Pretrained networks saved by earlier notebooks.
discriminator_resnet = tf.keras.models.load_model('./models/discriminator_resnet')
discriminator_model = tf.keras.models.load_model('./models/discriminator')
# NOTE(review): `generator` is not used by any code in this notebook; loading it
# may be unnecessary.
generator = tf.keras.models.load_model('./models/generator')

# Freeze the discriminator so that training the combined model updates only the
# generator. Setting `trainable` on the model recursively freezes every nested layer.
# NOTE(review): `discriminator_resnet` (the generator's feature backbone) is NOT
# frozen, so it is fine-tuned together with the generator — confirm this is intended.
discriminator_model.trainable = False

In [None]:
# Faces without glasses (the generator's input) and the template-glasses images.
no_glasses_array, glasses_array = (
    np.load('./data/T81-855_glasses_dataset/no_glasses.npy'),
    np.load('./data/T81-855_glasses_dataset/template_glasses.npy'),
)

In [None]:
# Model input: faces without glasses.
feature_array = no_glasses_array

# Keras requires every target array to have one row per input sample.
assert len(glasses_array) == len(feature_array), (
    f"glasses_array has {len(glasses_array)} samples, expected {len(feature_array)}"
)

# One target per model output, in output order:
#   1. adversarial: label 1 for every sample (presumably the discriminator's
#      "wearing glasses" class — TODO confirm); sized from the data instead of
#      a hard-coded sample count.
#   2. no-glass reconstruction: the input face itself, so the generator stays
#      close to the original image.
#   3. glass reconstruction: the template-glasses image for each face.
target_adversarial_array = np.ones((len(feature_array), 1))
target_reconstruction_array_no_glass = np.array(no_glasses_array, dtype=np.float32)
target_reconstruction_array_glass = np.array(glasses_array, dtype=np.float32)

# Previous single-target name, kept for any code that still refers to it.
target_reconstruction_array = target_reconstruction_array_glass

In [None]:
class SamplingLayer(tf.keras.layers.Layer):
    """Sample z = z_mean + exp(z_log_var / 2) * eps, with eps ~ N(0, 1).

    This is the reparameterisation trick: the randomness lives in ``eps``, so
    gradients can flow through ``z_mean`` and ``z_log_var``.

    Inputs: a pair ``[z_mean, z_log_var]`` of equally shaped 2-D tensors.
    """

    def call(self, inputs):
        mean, log_var = inputs
        noise_shape = (tf.shape(mean)[0], K.int_shape(mean)[1])
        noise = tf.random.normal(shape=noise_shape, mean=0, stddev=1)
        std = tf.exp(log_var * 0.5)
        return mean + std * noise


def create_generator_model(mask_scale=-70.0):
    """Build the glasses generator: an encoder/decoder that predicts an additive mask.

    The face is preprocessed and encoded by the global ``discriminator_resnet``.
    It then passes through a dense encoder and a sampled 128-d bottleneck, and is
    decoded into a 112x112x3 sigmoid mask in (0, 1). The mask is multiplied by
    ``mask_scale`` and added to the raw input image.

    Args:
        mask_scale: Multiplier applied to the sigmoid mask before it is added to
            the input. The default of -70.0 (the previously hard-coded value)
            means the generator can only darken pixels, by at most 70.

    Returns:
        A ``tf.keras.Model`` mapping a (112, 112, 3) image to an image of the
        same shape.
    """
    input_image = Input(shape=(112, 112, 3), name='no_glasses_face')
    preprocessed_image = preprocess_input(input_image)
    resnet_features = discriminator_resnet(preprocessed_image)

    # Encoder
    x1 = GlobalAveragePooling2D()(resnet_features)
    x2 = Dense(512, activation="relu")(x1)
    x3 = Dense(256, activation="relu")(x2)

    # Stochastic bottleneck (reparameterisation trick).
    # NOTE(review): no KL-divergence term is added, so z_log_var is unregularised
    # and this is not a full VAE.
    z_mean = Dense(128)(x3)
    z_log_var = Dense(128)(x3)
    z = SamplingLayer()([z_mean, z_log_var])

    # Decoder. It must consume d1 = [z, x3]. Building d2 on x3 alone (as before)
    # left the sampled z, and the whole sampling branch, disconnected from the output.
    d1 = concatenate([z, x3])
    d2 = Dense(512, activation="relu")(d1)
    d3 = Dense(4096, activation="relu")(d2)

    # 4096 = 64 * 64 * 1: reshape to a single-channel 64x64 map for the conv stack.
    x = tf.keras.layers.Reshape((64, 64, 1))(d3)

    # Spatial sizes: 64 -> 64 (stride-1 transpose) -> 128 (stride-2 transpose)
    # -> 112 (17x17 'valid' conv: 128 - 17 + 1) -> 112 ('same' conv).
    x = Conv2DTranspose(128, (3, 3), strides=(1, 1), padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = Conv2D(32, (17, 17), padding='valid', activation='relu')(x)
    x = BatchNormalization()(x)
    mask_image = Conv2D(3, (3, 3), padding='same', activation='sigmoid')(x)

    # Scale the mask. A Rescaling layer serialises cleanly, unlike a Python lambda.
    mask_image = tf.keras.layers.Rescaling(mask_scale)(mask_image)

    # Add the mask to the raw input image.
    generated_image = tf.keras.layers.Add()([input_image, mask_image])

    return Model(inputs=input_image, outputs=generated_image)

In [None]:
# Combined training graph: face -> generator -> frozen discriminator.
# The generated image is emitted twice so that two different reconstruction
# losses can be attached to it (one loss per output in `model.compile`).
face_input = Input(shape=(112, 112, 3), name='face_image')

generator_model = create_generator_model()
generated_face = generator_model(face_input)
discriminator_verdict = discriminator_model(preprocess_input(generated_face))

model = Model(
    inputs=face_input,
    outputs=[discriminator_verdict, generated_face, generated_face],
)

def adversarial_loss(y_true, y_pred):
    """Adversarial term: sparse categorical cross-entropy between the integer
    target labels and the discriminator's prediction, weighted by 10.

    NOTE(review): uses the default ``from_logits=False``, so ``y_pred`` is
    assumed to be class probabilities — confirm against the discriminator.
    """
    cross_entropy = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    return cross_entropy * 10

def _l2_distance_per_sample(y_true, y_pred):
    """Euclidean (L2) distance between each true/predicted pair, one value per sample.

    All non-batch axes are flattened first, so inputs of any rank >= 2 work.
    ``K.epsilon()`` is added under the square root because the derivative of
    sqrt is infinite at 0. Without it, an exact match (zero difference) yields
    NaN gradients.
    """
    diff = K.batch_flatten(y_true - y_pred)
    return K.sqrt(K.sum(K.square(diff), axis=1) + K.epsilon())


def recon_loss_no_glass(y_true, y_pred):
    """Mean per-sample L2 reconstruction distance, down-weighted by 1/200."""
    return K.mean(_l2_distance_per_sample(y_true, y_pred)) / 200


def recon_loss_glass(y_true, y_pred):
    """Mean per-sample L2 reconstruction distance (unweighted)."""
    return K.mean(_l2_distance_per_sample(y_true, y_pred))

# One loss per model output, in output order:
#   [discriminator verdict, generated image, generated image].
# Accuracy is only meaningful for the discriminator's class prediction, so it is
# tracked on that output alone rather than on the two reconstruction images.
model.compile(
    optimizer=RMSprop(learning_rate=2e-5),
    loss=[adversarial_loss, recon_loss_no_glass, recon_loss_glass],
    metrics=[['accuracy'], [], []],
)


In [None]:
class KeyboardInterruptCallback(tf.keras.callbacks.Callback):
    """Turn Ctrl-C during ``fit`` into a graceful stop that rolls back to the best epoch.

    While training runs, SIGINT is routed to ``interrupt_training`` instead of
    raising ``KeyboardInterrupt``.

    At every epoch end, the weights with the lowest ``val_loss`` seen so far are
    snapshotted, and the running best is written into the epoch logs as
    ``best_val_loss``.

    If training was interrupted, the snapshot is restored in ``on_train_end``,
    after the in-flight batch has finished, so that batch cannot overwrite the
    restored weights. The previous SIGINT handler is reinstated when training ends.
    """

    def __init__(self):
        super().__init__()
        self.original_sigint_handler = None
        self._training_pid = None
        self._interrupted = False
        self.best_epoch = None
        self.best_weights = None
        self.best_val_loss = float('inf')

    def on_train_begin(self, logs=None):
        # Reset per-run state so a second `fit` call never restores weights
        # snapshotted during an earlier run.
        self._interrupted = False
        self.best_epoch = None
        self.best_weights = None
        self.best_val_loss = float('inf')
        # Record the PID before installing the handler that reads it.
        self._training_pid = os.getpid()
        self.original_sigint_handler = signal.signal(signal.SIGINT, self.interrupt_training)

    def on_train_end(self, logs=None):
        try:
            if self._interrupted and self.best_weights is not None:
                print(f"Restoring weights from Epoch {self.best_epoch + 1}")
                self.model.set_weights(self.best_weights)
        finally:
            # signal.signal() returns None when the previous handler was not
            # installed from Python; passing None back would raise TypeError.
            previous_handler = self.original_sigint_handler
            if previous_handler is None:
                previous_handler = signal.default_int_handler
            signal.signal(signal.SIGINT, previous_handler)
            self.original_sigint_handler = None
            self._training_pid = None

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            logs = {}
        # Without validation data there is no val_loss; skip the snapshot
        # instead of raising KeyError.
        val_loss = logs.get('val_loss')
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_epoch = epoch
            self.best_val_loss = val_loss
            self.best_weights = self.model.get_weights()

        logs['best_val_loss'] = self.best_val_loss

    def interrupt_training(self, signum, frame):
        """SIGINT handler: request a stop instead of raising KeyboardInterrupt.

        Weights are not touched here: a batch may still be running, and it would
        overwrite them. Restoration happens in ``on_train_end``.
        """
        # Only the process that started training reacts (presumably to ignore
        # forked children that inherited this handler).
        if self._training_pid != os.getpid():
            return
        self._interrupted = True
        self.model.stop_training = True
        if self.best_epoch is None:
            print("Keyboard interrupt detected. No epoch with a val_loss has completed yet; "
                  "stopping without restoring weights.")
        else:
            print(f"Keyboard interrupt detected. Stopping; weights from Epoch "
                  f"{self.best_epoch + 1} will be restored.")

# Ctrl-C handling plus early stopping on the validation loss.
keyboard_interrupt_callback = KeyboardInterruptCallback()
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=20,
    restore_best_weights=True,
    verbose=1,
)
callbacks = [keyboard_interrupt_callback, early_stopping_callback]

In [None]:
# Train for up to 1000 epochs; the callbacks may end training earlier.
# validation_split holds out the last 5% of the samples (Keras takes them before
# shuffling), so check whether the arrays are ordered in some way.
model.fit(
    x=feature_array,
    y=[
        target_adversarial_array,
        target_reconstruction_array_no_glass,
        target_reconstruction_array_glass,
    ],
    batch_size=16,
    epochs=1000,
    validation_split=0.05,
    callbacks=callbacks,
)

In [None]:
# Quick visual check: run the generator on the first face of the dataset.
generated_batch = generator_model(no_glasses_array[:1])
# Clip before the uint8 cast. The generator adds a negative mask, which can push
# pixels below 0 (assuming images are in [0, 255]), and a plain cast does not clamp.
output = np.clip(np.asarray(generated_batch), 0, 255).astype(np.uint8)[0]

fig, ax = plt.subplots(figsize=(2.5, 2.5))
ax.imshow(output)
ax.set_title('Generated (sample 0)')
ax.axis('off');

In [None]:
# Persist the trained generator; this overwrites the model read from './models/generator'.
tf.keras.models.save_model(generator_model, 'models/generator')

In [None]:
# Run the generator over the whole dataset and store the results as uint8 images.
generated_glasses_array = generator_model.predict(no_glasses_array)
# Clip before casting. The additive negative mask can push pixels below 0
# (assuming images are in [0, 255]), and a plain uint8 cast does not clamp.
generated_glasses_array = np.clip(generated_glasses_array, 0, 255).astype(np.uint8)

np.save('./data/T81-855_glasses_dataset/generated_glasses_1.npy', generated_glasses_array)