**Hi :)**

**To run our program, please follow these steps:**

**1. Connect to your Google Drive, then upload and unzip the required files (e.g. datasets)**

**2. Fill the globals section**

**3. Choose which code cell to run (according to the dataset: Monet/Photos)**

**4. Choose the main type (train or test)**

**Please connect to your Google Drive to upload the files**

In [None]:
from google.colab import drive 
# Mount Google Drive at /content/gdrive; the unzip cells below read the archives from /content/gdrive/MyDrive/...
drive.mount('/content/gdrive', force_remount=True)

**Unzip Dataset [Example below]**

In [None]:
# Extract the datasets (arbitrary masks, Photos train/valid, Monet train/valid) from Google Drive.
# Adjust the paths to your own Drive layout; the folder names are expected to match the
# example paths in the globals cells below (e.g. "./train_photos", "./arbi") - TODO confirm.
# NOTE(review): unzip asks before overwriting existing files, so re-running this cell waits for
# input; add `-o` (overwrite) or `-n` (never overwrite) when re-running.
!unzip '/content/gdrive/MyDrive/DeepLearning/Datasets/arbi.zip'
!unzip '/content/gdrive/MyDrive/DeepLearning/Datasets/Photos/train_photos.zip'
!unzip '/content/gdrive/MyDrive/DeepLearning/Datasets/Photos/valid_photos.zip'
!unzip '/content/gdrive/MyDrive/DeepLearning/Datasets/Monet/train_preview.zip'
!unzip '/content/gdrive/MyDrive/DeepLearning/Datasets/Monet/valid_preview.zip'



**Unzip Trained models + Test files [Example below]**

In [None]:
# Extract the trained models (separate and combined) and the test image/mask pairs.
# NOTE(review): as above, add `-o`/`-n` to avoid the overwrite prompt when re-running.
!unzip '/content/gdrive/MyDrive/DeepLearning/Trained_Models/Separate_Models/Trained_Models.zip'
!unzip '/content/gdrive/MyDrive/DeepLearning/Trained_Models/Combined_Models/Trained_Combined_Models.zip'
!unzip '/content/gdrive/MyDrive/DeepLearning/Test/test.zip'


**For running trained model - please fill the following globals**


In [None]:
# For running a trained model (used by test_data) - please fill the following globals
DATA_PATH = ""  # folder of photo/mask image pairs, FOR EXAMPLE: "./test/monet_blocks"
TRAINED_MODEL_NAME = ""  # saved Keras generator file, FOR EXAMPLE: "trained_model_monet_blocks.h5"


**For running training process - please fill the following globals**


In [None]:
# For running the training process - please fill the following globals
# MASKS_TO_TRAIN (read by load_data): "center", "blocks", "all" (random mix of the three mask
# types per image); any other value, e.g. "region", uses the arbitrary masks.
MASKS_TO_TRAIN = ""  # FOR EXAMPLE: "all" or "center" or "blocks" or "region"
TRAIN_PATH = ""  # FOR EXAMPLE: "./train_photos" or "./train_preview"
VALID_PATH = ""  # FOR EXAMPLE: "./valid_photos" or "./valid_preview"
ARBITRARY_MASK_PATH = ""    # folder of arbitrary mask images (read by get_arbitrary_mask), FOR EXAMPLE: "./arbi"

**Code for 'Photos' dataset**

In [None]:
# imports:
import os
import sys
import random
import numpy as np
from random import shuffle
from PIL import Image
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Conv2DTranspose, Conv2D, \
    MaxPooling2D, Activation, BatchNormalization, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from tensorflow.keras.layers.experimental.preprocessing import Resizing
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import layers
from tensorflow import keras

# globals (shared by every function in this cell):
IMG_SIZE = 128  # every image is resized to IMG_SIZE x IMG_SIZE
MASK_SIZE = 128  # side length of the discriminator input (see MISS_SHAPE)
RGB_DIM = 3  # number of colour channels
EPOCHS = 300  # number of training epochs (presumably passed to train() from the main cell - confirm)
IMAGE_SHAPE = (IMG_SIZE, IMG_SIZE, RGB_DIM)  # generator input shape
MISS_SHAPE = (MASK_SIZE, MASK_SIZE, RGB_DIM)  # discriminator input shape
LAMBDA_RECON = 0.999  # weight of the reconstruction (mse) loss in the combined model
LAMBDA_ADVR = 0.001  # weight of the adversarial (binary cross-entropy) loss in the combined model
OPTIMIZER = Adam(0.0002, 0.5)  # learning rate 0.0002, beta_1 0.5; used by init_contextEncoder
BATCH_SIZE = 64
DISC_OUTPUT_SIZE = 14  # spatial size of the PatchGAN discriminator output for 128x128 inputs
MIN_RANDOM_BLOCKS_AMOUNT = 3  # minimum number of square holes per image ("blocks" masks)
MAX_RANDOM_BLOCKS_AMOUNT = 9  # maximum number of square holes per image ("blocks" masks)
RANDOM_BLOCK_SIZE = 21  # side length in pixels of each random square hole
MIN_VALID_TOTAL_LOSS = sys.maxsize  # lowest validation loss seen so far; "infinity" before the first epoch
arbitrary_masks = []  # arbitrary hole masks (1 = hole), filled by get_arbitrary_mask()
inverse_arbitrary_masks = []  # complements of arbitrary_masks (1 = pixel kept)
valid_total_loss = []  # mean validation total loss, one entry per epoch
train_total_loss = []  # mean training total loss, one entry per epoch


def init_discriminator_patch_gan():
    """Build the PatchGAN discriminator.

    The network is built in four stages:
    - a MISS_SHAPE image goes through three stride-2 4x4 convolutions (64, 128 and 256 filters);
    - a zero-padded stride-1 512-filter convolution follows;
    - a final zero-padded 1-filter sigmoid convolution produces the scores.

    For 128x128 inputs the output is a 14x14x1 map of per-patch "real" scores
    (DISC_OUTPUT_SIZE).

    Returns:
        keras Model: functional wrapper (Input -> Sequential layer stack).
    """
    net = Sequential()

    # Down-sampling 128 -> 64 -> 32 -> 16; the first block has no batch normalization.
    for position, filters in enumerate((64, 128, 256)):
        first_layer_kwargs = {"input_shape": MISS_SHAPE} if position == 0 else {}
        net.add(Conv2D(filters, kernel_size=4, strides=2, padding="same", **first_layer_kwargs))
        net.add(LeakyReLU(alpha=0.2))
        if position > 0:
            net.add(BatchNormalization(momentum=0.8))

    # 16 -> pad to 18 -> valid 4x4 conv -> 15
    net.add(tf.keras.layers.ZeroPadding2D())
    net.add(Conv2D(512, kernel_size=4, strides=1, padding="valid"))
    net.add(LeakyReLU(alpha=0.2))
    net.add(BatchNormalization(momentum=0.8))

    # 15 -> pad to 17 -> valid 4x4 conv -> 14 sigmoid patch scores
    net.add(tf.keras.layers.ZeroPadding2D())
    net.add(Conv2D(1, kernel_size=4, strides=1, padding="valid", activation='sigmoid'))

    net.summary()

    image_in = Input(shape=MISS_SHAPE)
    return Model(image_in, net(image_in))


def init_generator():
    """Build the generator: init_encoder followed by init_decoder in one Sequential stack.

    Returns:
        keras Model: maps an IMAGE_SHAPE masked image to the generated image.
    """
    stack = Sequential()
    for add_layers in (init_encoder, init_decoder):
        add_layers(stack)
    stack.summary()

    masked_input = Input(shape=IMAGE_SHAPE)
    return Model(masked_input, stack(masked_input))


def init_encoder(model):
    """Append the AlexNet-style encoder layers to the Sequential ``model``.

    Spatial sizes for a 128x128 input, step by step:
    - 11x11/4 valid conv: 30
    - 3x3/2 pool: 14
    - 5x5 same conv: 14
    - pool: 6
    - three 3x3 same convs: 6
    - pool: 2
    - flatten: 2*2*256 = 1024 features
    - 9216-unit dense layer: 6*6*256, the shape that init_decoder reshapes to.

    Args:
        model: keras Sequential model extended in place. It is expected to be empty, because
            the first layer declares the input shape.
    """
    # 1st Convolutional Layer; declares the model input (the generator input shape)
    model.add(Conv2D(filters=96, input_shape=IMAGE_SHAPE, kernel_size=(11, 11), strides=(4, 4)))
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    # Pooling
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))

    # 2nd Convolutional Layer
    model.add(Conv2D(filters=256, kernel_size=(5, 5), padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    # Pooling
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))

    # 3rd Convolutional Layer
    model.add(Conv2D(filters=384, kernel_size=(3, 3), padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    # 4th Convolutional Layer
    model.add(Conv2D(filters=384, kernel_size=(3, 3), padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    # 5th Convolutional Layer
    model.add(Conv2D(filters=256, kernel_size=(3, 3), padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    # Pooling
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))

    # Passing it to a dense layer
    model.add(Flatten())

    # 1st Dense Layer. Keras only honours input_shape on the first layer of a Sequential model,
    # so the former input_shape=(128 * 128 * 3,) here was ignored and only misdescribed the input.
    model.add(Dense(9216))
    model.add(Activation('relu'))
    model.add(Dropout(0.4))


def init_decoder(model):
    """Append the up-sampling decoder layers to the Sequential ``model``.

    The encoder's 9216-feature vector is reshaped to 6x6x256. Five stride-2 transposed
    convolutions upsample it 6 -> 12 -> 24 -> 48 -> 96 -> 192. The result is resized to
    IMG_SIZE x IMG_SIZE with RGB_DIM channels, and a tanh activation is applied.

    NOTE(review): tanh outputs lie in (-1, 1) while load_data scales images to [0, 1]; confirm
    this is intended.

    Args:
        model: keras Sequential model extended in place (after init_encoder).
    """
    model.add(Reshape((6, 6, 256)))

    # First up-convolution has no activation / batch normalization
    model.add(Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"))

    # Up-convolution -> ReLU -> batch normalization blocks
    for n_filters in (64, 64, 32):
        model.add(Conv2DTranspose(n_filters, kernel_size=4, strides=2, padding="same"))
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))

    # Final up-convolution to RGB, resized back to the image size
    model.add(Conv2DTranspose(RGB_DIM, kernel_size=4, strides=2, padding="same"))
    model.add(Resizing(IMG_SIZE, IMG_SIZE))
    model.add(Activation('tanh'))


def init_contextEncoder():
    """Build and compile the generator, the discriminator and the combined training model.

    The combined model maps a masked image to [generated image, discriminator score on it]. It
    is trained with LAMBDA_RECON * mse + LAMBDA_ADVR * binary cross-entropy, with the
    discriminator frozen inside it.

    Returns:
        tuple: (generator, discriminator, combined).
    """
    # inits and compiles the discriminator (it stays trainable when trained on its own)
    discriminator = init_discriminator_patch_gan()
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=OPTIMIZER,
                          metrics=['accuracy'])
    # inits the generator
    generator = init_generator()
    masked_img = Input(shape=IMAGE_SHAPE)
    gen_pred = generator(masked_img)

    # Freeze the discriminator inside the combined model only (it was compiled above while trainable)
    discriminator.trainable = False

    real = discriminator(gen_pred)

    combined = Model(masked_img, [gen_pred, real])
    # Give the combined model its own optimizer with the same hyper-parameters. Sharing one
    # instance between two compiled models mixes their Adam state (e.g. the step counter), and the
    # tf.keras optimizers introduced in TF 2.11 reject updating variables they were not built for.
    combined_optimizer = type(OPTIMIZER).from_config(OPTIMIZER.get_config())
    combined.compile(loss=['mse', 'binary_crossentropy'],
                     loss_weights=[LAMBDA_RECON, LAMBDA_ADVR],
                     optimizer=combined_optimizer)

    return generator, discriminator, combined


def create_center_mask(photo):
    """Cut the central square (IMG_SIZE/4 .. 3*IMG_SIZE/4 on both axes) out of ``photo``.

    Args:
        photo: object with a ``.numpy()`` method (a tf tensor in load_data) holding an
            (H, W, C) image.

    Returns:
        tuple: (masked_img, missing_part, photo, mask) as NumPy arrays of the photo's dtype.
        - masked_img is the photo with the square set to 0.
        - missing_part is 1 everywhere except the square, which keeps the photo's pixels.
        - photo is the input converted to an array.
        - mask is 1 inside the square and 0 elsewhere.
    """
    image = photo.numpy()
    low, high = IMG_SIZE // 4, 3 * IMG_SIZE // 4  # 32 and 96 for 128x128 images
    square = (slice(low, high), slice(low, high), slice(None))

    mask = np.zeros_like(image)
    mask[square] = 1

    missing_part = np.ones_like(image)
    missing_part[square] = image[square]

    masked_img = image.copy()
    masked_img[square] = 0

    return masked_img, missing_part, image, mask


def create_random_blocks_mask(photo):
    """Cut several random RANDOM_BLOCK_SIZE x RANDOM_BLOCK_SIZE squares out of ``photo``.

    The number of squares is drawn with random.randint(MIN_RANDOM_BLOCKS_AMOUNT,
    MAX_RANDOM_BLOCKS_AMOUNT). Squares may overlap and always lie fully inside the
    IMG_SIZE x IMG_SIZE image.

    Args:
        photo: object with a ``.numpy()`` method (a tf tensor in load_data) holding an
            (H, W, C) image.

    Returns:
        tuple: (masked_img, missing_part, photo, mask) as NumPy arrays of the photo's dtype.
        - masked_img is the photo with the squares set to 0.
        - missing_part is 1 everywhere except the squares, which keep the photo's pixels.
        - mask is 1 inside the squares and 0 elsewhere.
    """
    image = photo.numpy()
    masked_img = image.copy()
    missing_part = np.ones_like(image)
    mask = np.zeros_like(image)

    highest_corner = IMG_SIZE - RANDOM_BLOCK_SIZE  # 107 for 128x128 images and 21-pixel blocks
    squares_amount = random.randint(MIN_RANDOM_BLOCKS_AMOUNT, MAX_RANDOM_BLOCKS_AMOUNT)
    for _ in range(squares_amount):
        top = random.randint(0, highest_corner)
        left = random.randint(0, highest_corner)
        region = (slice(top, top + RANDOM_BLOCK_SIZE), slice(left, left + RANDOM_BLOCK_SIZE), slice(None))

        # `image` is never modified, so overlapping squares still copy original pixels
        missing_part[region] = image[region]
        mask[region] = 1
        masked_img[region] = 0

    return masked_img, missing_part, image, mask


def create_arbitrary_mask(photo):
    """Mask ``photo`` with a randomly chosen mask loaded by get_arbitrary_mask.

    Args:
        photo: image with the same shape as the loaded masks (IMG_SIZE x IMG_SIZE x 3). It may be
            a tf tensor (as produced in load_data) or a NumPy array.

    Returns:
        tuple: (masked_img, missing_part, photo, mask) as NumPy arrays of the photo's dtype,
        matching create_center_mask / create_random_blocks_mask.
        - masked_img keeps only the pixels outside the hole.
        - missing_part keeps only the pixels inside the hole (0 elsewhere).
        - mask is 1 inside the hole and 0 elsewhere.

    Raises:
        ValueError: if no arbitrary masks have been loaded yet.
    """
    if len(arbitrary_masks) == 0:
        raise ValueError("no arbitrary masks loaded - call get_arbitrary_mask() first")

    # Convert to NumPy like the sibling mask functions do (they call photo.numpy()), so load_data
    # collects arrays of a single type and dtype instead of a mix of tensors and arrays.
    photo = np.asarray(photo)

    arbi_mask_idx = np.random.randint(0, len(arbitrary_masks))
    arbitrary_mask = arbitrary_masks[arbi_mask_idx]
    inverse_arbitrary_mask = inverse_arbitrary_masks[arbi_mask_idx]

    masked_img = np.multiply(inverse_arbitrary_mask, photo).astype(photo.dtype)
    missing_part = np.multiply(arbitrary_mask, photo).astype(photo.dtype)
    # astype copies, so callers never alias the shared global mask array
    mask = np.asarray(arbitrary_mask).astype(photo.dtype)

    return masked_img, missing_part, photo, mask


def get_arbitrary_mask():
    """Load every mask image in ARBITRARY_MASK_PATH into the arbitrary-mask globals.

    Each file goes through the same steps:
    - it is converted to grayscale and resized to IMG_SIZE x IMG_SIZE;
    - get_inverse_mask binarises it, so non-zero pixels become the hole;
    - the result is replicated over RGB_DIM channels.

    The results replace ``arbitrary_masks`` / ``inverse_arbitrary_masks`` as float64 arrays of
    shape (num_masks, IMG_SIZE, IMG_SIZE, RGB_DIM). Hidden files such as ".DS_Store" are skipped.
    """
    global arbitrary_masks, inverse_arbitrary_masks
    masks_3d = []
    inverse_masks_3d = []
    for mask_file in sorted(os.listdir(ARBITRARY_MASK_PATH)):
        if mask_file.startswith("."):
            continue
        # `with` closes the underlying file handle once the pixels are loaded
        with Image.open(os.path.join(ARBITRARY_MASK_PATH, mask_file)) as image:
            # LANCZOS is the filter formerly exposed as Image.ANTIALIAS, which Pillow 10 removed.
            gray = image.convert("L").resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
        arbi_mask, inverse_arbi_mask = get_inverse_mask(np.asarray(gray))

        masks_3d.append(np.stack([arbi_mask] * RGB_DIM, axis=2).astype(np.float64))
        inverse_masks_3d.append(np.stack([inverse_arbi_mask] * RGB_DIM, axis=2).astype(np.float64))

    # Rebind the globals instead of appending to them, so calling this function twice works (the
    # previous version turned the lists into arrays, which have no append()).
    arbitrary_masks = np.asarray(masks_3d)
    inverse_arbitrary_masks = np.asarray(inverse_masks_3d)


def shuffle_data(train_x, train_y, train_z):
    """Shuffle three index-aligned arrays with one shared random permutation.

    Args:
        train_x, train_y, train_z: arrays indexable by an integer index array (e.g. NumPy
            arrays) with at least len(train_x) rows each.

    Returns:
        tuple: the three arrays reordered by the same permutation, so rows stay aligned.
    """
    order = np.random.permutation(len(train_x))
    return tuple(array[order] for array in (train_x, train_y, train_z))


def load_data(photos_path):
    """Load every image in ``photos_path`` and build masked training samples.

    Processing steps:
    - Images are visited in random order.
    - Each image is resized to IMG_SIZE x IMG_SIZE and scaled to [0, 1].
    - Each image is masked according to MASKS_TO_TRAIN:
      - "center": central square;
      - "blocks": random squares;
      - "all": a random choice of the three mask types per image;
      - any other value: arbitrary masks from get_arbitrary_mask.

    Hidden files (e.g. macOS ".DS_Store") are skipped.

    Args:
        photos_path: directory containing the images, e.g. "./train_photos".

    Returns:
        tuple of np.ndarray: (masked_imgs, missing_parts, photos, masks), aligned on axis 0.
    """
    resize_and_rescale = tf.keras.Sequential([
        layers.experimental.preprocessing.Resizing(IMG_SIZE, IMG_SIZE),
        layers.experimental.preprocessing.Rescaling(1. / 255),
    ])
    mask_builders = (create_center_mask, create_random_blocks_mask, create_arbitrary_mask)

    file_names = [name for name in os.listdir(photos_path) if not name.startswith(".")]
    shuffle(file_names)

    masked_imgs, missing_parts, photos, masks = [], [], [], []
    for file_name in file_names:
        sample_photo = _read_rgb_image(os.path.join(photos_path, file_name))
        edited_photo = resize_and_rescale(sample_photo)
        if MASKS_TO_TRAIN == "all":
            builder = mask_builders[random.randint(0, 2)]
        elif MASKS_TO_TRAIN == "center":
            builder = create_center_mask
        elif MASKS_TO_TRAIN == "blocks":
            builder = create_random_blocks_mask
        else:
            builder = create_arbitrary_mask
        masked_img, missing_part, photo, mask = builder(edited_photo)
        masked_imgs.append(masked_img)
        missing_parts.append(missing_part)
        photos.append(photo)
        masks.append(mask)

    return np.asarray(masked_imgs), np.asarray(missing_parts), np.asarray(photos), np.asarray(masks)


def _read_rgb_image(path):
    """Read an image as an (H, W, RGB_DIM) array on the 0..255 scale that the 1/255 rescale expects.

    plt.imread returns JPEGs as uint8 in [0, 255] but PNGs as floats in [0, 1], possibly with an
    alpha channel. Normalising here keeps the later Rescaling(1/255) correct for both formats.
    """
    image = np.asarray(plt.imread(path))
    if image.ndim == 2:  # grayscale: replicate into RGB_DIM channels
        image = np.stack([image] * RGB_DIM, axis=-1)
    image = image[..., :RGB_DIM]  # drop an alpha channel if present
    if np.issubdtype(image.dtype, np.floating):
        image = image * 255.0
    return image


def get_inverse_mask(arbitrary_mask):
    """Binarise a mask image and return it together with its complement.

    Every non-zero pixel is treated as part of the hole. The whole array is processed with
    vectorised comparisons, replacing a per-pixel Python loop over IMG_SIZE x IMG_SIZE, so masks
    of any size are fully binarised. The input is not modified.

    Args:
        arbitrary_mask: 2-D array (e.g. a grayscale image loaded as uint8).

    Returns:
        tuple: (arbitrary_mask_2d, inverse_arbitrary_mask_2d), both with the input's dtype.
        The first is 1 where the input is non-zero and 0 elsewhere; the second is the opposite.
    """
    mask = np.asarray(arbitrary_mask)
    hole = mask != 0
    return hole.astype(mask.dtype), (~hole).astype(mask.dtype)


def test(generator, masked_imgs_valid, missing_parts_valid, photos_valid, masks_valid, batch_size):
    """Predict the first ``batch_size`` validation samples and plot up to 7 of them.

    Each figure shows six panels:
    - the photo;
    - the masked input;
    - the raw generator output;
    - the missing part;
    - the generator output restricted to the mask;
    - the masked input with the hole filled by that prediction.
    """
    masked_imgs_valid = masked_imgs_valid[:batch_size]
    missing_parts_valid = missing_parts_valid[:batch_size]
    photos_valid = photos_valid[:batch_size]
    masks_valid = masks_valid[:batch_size]
    gen_pred_valid = generator.predict(masked_imgs_valid)
    y_hat_multi_mask = np.multiply(gen_pred_valid, masks_valid)
    filled_imgs_valid = np.add(y_hat_multi_mask, masked_imgs_valid)

    panels = (
        (photos_valid, "photo_valid"),
        (masked_imgs_valid, "masked_imgs"),
        (gen_pred_valid, "gen_pred valid"),
        (missing_parts_valid, "missing_part"),
        (y_hat_multi_mask, "y_hat_multi_mask"),
        (filled_imgs_valid, "filled_imgs_valid2"),
    )
    # Never index past the available samples: the former fixed range(7) raised IndexError on
    # validation sets with fewer than 7 images.
    for i in range(min(7, len(masked_imgs_valid))):
        _, ax = plt.subplots(1, 6, figsize=(25, 25))
        for axis, (images, title) in zip(ax, panels):
            axis.imshow(images[i])
            axis.set_title(title)
        plt.show()


def test_on_batch(combined_model, masked_imgs_valid, photos_valid, valid, epoch):
    """Evaluate the combined model on the whole validation set and record its mean total loss.

    The validation set is processed in chunks of BATCH_SIZE, including a shorter final chunk.
    The per-sample mean of the total loss is appended to ``valid_total_loss`` (NaN for an empty
    validation set). The loss curves are plotted every 5 epochs.

    Args:
        combined_model: compiled model returned by init_contextEncoder.
        masked_imgs_valid: masked validation inputs.
        photos_valid: ground-truth validation photos (reconstruction targets).
        valid: discriminator targets for one batch (train passes its ``real`` array). It is
            sliced to the length of each chunk; assumed to have at least BATCH_SIZE rows.
        epoch: index of the current epoch (used for logging and the plot schedule).
    """
    global valid_total_loss
    samples_amount = len(masked_imgs_valid)
    sum_total_loss_valid = 0
    # The former `int(len / BATCH_SIZE) - 1` skipped the last full batch and divided by zero
    # when the validation set held fewer than two batches.
    for batch_size_start in range(0, samples_amount, BATCH_SIZE):
        batch_size_end = batch_size_start + BATCH_SIZE
        masked_imgs_valid_batch = masked_imgs_valid[batch_size_start:batch_size_end]
        photos_valid_batch = photos_valid[batch_size_start:batch_size_end]
        current_batch = len(masked_imgs_valid_batch)
        g_valid_loss = combined_model.test_on_batch(masked_imgs_valid_batch,
                                                    [photos_valid_batch, valid[:current_batch]])
        # Weight by batch length so a shorter final batch does not skew the mean
        sum_total_loss_valid += g_valid_loss[0] * current_batch

        print("%d , valid_loss_1: %f , valid_loss_2: %f" % (epoch, g_valid_loss[0], g_valid_loss[1]))

    avg_total_loss = sum_total_loss_valid / samples_amount if samples_amount else float("nan")
    valid_total_loss.append(avg_total_loss)
    # Plot the validation progress
    if epoch % 5 == 0:
        show_train_and_valid_graph()


def show_train_and_valid_graph():
    """Plot the per-epoch mean total loss of validation and training on one figure."""
    for set_text, text in ((plt.title, 'Model Total Loss'),
                           (plt.ylabel, 'Total loss'),
                           (plt.xlabel, 'epoch')):
        set_text(text)
    for history, curve_name in ((valid_total_loss, "valid"), (train_total_loss, "train")):
        plt.plot(history, label=curve_name)
    plt.legend()
    plt.show()


def train(combined_model, generator, discriminator, masked_imgs_train, missing_parts_train, masked_imgs_valid,
          missing_parts_valid, photos_train, photos_valid, masks_train, masks_valid, epochs, batch_size):
    """Train the context encoder adversarially; validate and checkpoint after every epoch.

    Each epoch shuffles the training data. For every full batch it:
    - predicts the inpainted images;
    - trains the discriminator on real photos (label 1) and on generated images (label 0);
    - trains the generator through ``combined_model`` to reconstruct the photo and to make the
      discriminator output 1.

    After each epoch it:
    - plots sample predictions (test);
    - records the validation loss (test_on_batch);
    - saves the generator when that loss improves (save_best_model).

    Args:
        combined_model, generator, discriminator: models returned by init_contextEncoder.
        masked_imgs_train, missing_parts_train, photos_train: aligned training arrays.
        masked_imgs_valid, missing_parts_valid, photos_valid, masks_valid: aligned validation arrays.
        masks_train: unused; kept for interface compatibility.
        epochs: number of epochs to run.
        batch_size: samples per training step; used both for slicing and for the label arrays.

    Raises:
        ValueError: if there are fewer training samples than ``batch_size``.
    """
    global train_total_loss
    # Only full batches are used, because the label arrays below have exactly batch_size rows.
    # The former `int(len / BATCH_SIZE) - 1` also dropped the last full batch. It divided by
    # zero for datasets smaller than two batches, and mixed BATCH_SIZE with batch_size.
    batches_amount = len(masked_imgs_train) // batch_size
    if batches_amount == 0:
        raise ValueError("need at least %d training samples (one batch), got %d"
                         % (batch_size, len(masked_imgs_train)))
    real = np.ones((batch_size, DISC_OUTPUT_SIZE, DISC_OUTPUT_SIZE, 1))  # e.g. (64, 14, 14, 1)
    fake = np.zeros((batch_size, DISC_OUTPUT_SIZE, DISC_OUTPUT_SIZE, 1))  # e.g. (64, 14, 14, 1)
    for epoch in range(epochs):
        sum_total_loss_train = 0
        masked_imgs_train_shuff, _, photos_train_shuff = shuffle_data(masked_imgs_train,
                                                                      missing_parts_train,
                                                                      photos_train)
        for i in range(batches_amount):
            batch_size_start = batch_size * i
            batch_size_end = batch_size_start + batch_size
            masked_imgs = masked_imgs_train_shuff[batch_size_start:batch_size_end]
            photos_train_batch = photos_train_shuff[batch_size_start:batch_size_end]

            # the prediction result
            gen_pred = generator.predict(masked_imgs)

            # Train the discriminator: real photos -> 1, generated images -> 0
            d_loss_real = discriminator.train_on_batch(photos_train_batch, real)
            d_loss_fake = discriminator.train_on_batch(gen_pred, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train the generator (the discriminator is frozen inside combined_model)
            g_loss = combined_model.train_on_batch(masked_imgs, [photos_train_batch, real])

            sum_total_loss_train += g_loss[0]

            print(" Epoch:%d -- Discrimnator loss: %f -- Accuracy: %.2f%% -- Generator loss: %f -- Mse: %f" % (
                epoch, d_loss[0], 100 * d_loss[1], g_loss[0], g_loss[1]))

        train_total_loss.append(sum_total_loss_train / batches_amount)

        test(generator, masked_imgs_valid, missing_parts_valid, photos_valid, masks_valid, batch_size)
        test_on_batch(combined_model, masked_imgs_valid, photos_valid, real, epoch)
        save_best_model(generator, epoch)


def load_best_model(masked_imgs_valid, photos_valid, masks_valid):
    """Load "best_model.h5" and plot its predictions for up to 7 validation samples.

    Only the first BATCH_SIZE samples are predicted. Each figure shows five panels:
    - the photo;
    - the masked input;
    - the raw prediction;
    - the prediction restricted to the mask;
    - the masked input with the hole filled by that prediction.
    """
    model = keras.models.load_model("best_model.h5")
    masked_imgs_valid = masked_imgs_valid[:BATCH_SIZE]
    photos_valid = photos_valid[:BATCH_SIZE]
    masks_valid = masks_valid[:BATCH_SIZE]
    gen_pred_valid = model.predict(masked_imgs_valid)
    y_hat_multi_mask = np.multiply(gen_pred_valid, masks_valid)
    filled_imgs_valid = np.add(y_hat_multi_mask, masked_imgs_valid)
    print("inside test: ")

    panels = (
        (photos_valid, "photo_valid"),
        (masked_imgs_valid, "masked_imgs"),
        (gen_pred_valid, "gen_pred_valid"),
        (y_hat_multi_mask, "y_hat_multi_mask"),
        (filled_imgs_valid, "filled_imgs_valid2"),
    )
    # Never index past the available samples (the former fixed range(7) raised IndexError)
    for i in range(min(7, len(masked_imgs_valid))):
        _, ax = plt.subplots(1, 5, figsize=(25, 25))
        for axis, (images, title) in zip(ax, panels):
            axis.imshow(images[i])
            axis.set_title(title)
        plt.show()


def save_best_model(model, epoch):
    """Save ``model`` to "best_model.h5" when the latest validation loss is a new minimum.

    The latest entry of ``valid_total_loss`` is used rather than ``valid_total_loss[epoch]``.
    The globals survive across notebook runs, so indexing by epoch could compare a stale loss
    from an earlier training run. Nothing happens while no validation loss has been recorded.

    Args:
        model: model to save (train passes the generator).
        epoch: current epoch; kept for interface compatibility.
    """
    global MIN_VALID_TOTAL_LOSS
    if not valid_total_loss:
        return
    latest_loss = valid_total_loss[-1]
    if latest_loss < MIN_VALID_TOTAL_LOSS:
        MIN_VALID_TOTAL_LOSS = latest_loss
        print("----------------------------------------- saved model ------------------------------------------")
        model.save("best_model.h5")


def test_data():
    """Fill the holes of the paired test images in DATA_PATH with the model TRAINED_MODEL_NAME.

    DATA_PATH holds photo/mask image pairs that are adjacent once the non-hidden file names are
    sorted. Which file of each pair is the photo depends on the directory:
    - "./test/photos_blocks", "./test/monet_central_block" and "./test/monet_blocks": the second
      file is the photo and the first is the mask;
    - every other directory: the first file is the photo.

    A trailing unpaired file is ignored. For every pair a figure with the masked input and the
    filled result is shown.

    Raises:
        ValueError: if DATA_PATH contains no photo/mask pair.
    """
    dir_name = DATA_PATH
    # Drop hidden files (e.g. macOS ".DS_Store") *before* pairing. The former loop skipped them
    # while still pairing by even index over the unfiltered list. That was correct only when
    # exactly one hidden file sorted first; otherwise the first "pair" wrapped around to the
    # last file.
    file_names = sorted(name for name in os.listdir(dir_name) if not name.startswith("."))

    photo_second_dirs = {os.path.normpath(path) for path in
                         ("./test/photos_blocks", "./test/monet_central_block", "./test/monet_blocks")}
    photo_is_second = os.path.normpath(dir_name) in photo_second_dirs

    resize_and_rescale = tf.keras.Sequential([
        layers.experimental.preprocessing.Resizing(IMG_SIZE, IMG_SIZE),
        layers.experimental.preprocessing.Rescaling(1. / 255),
    ])

    masked_imgs_valid = []
    masks_valid = []
    for first, second in zip(file_names[0::2], file_names[1::2]):
        photo_name, mask_name = (second, first) if photo_is_second else (first, second)
        edited_photo = resize_and_rescale(plt.imread(os.path.join(dir_name, photo_name)))
        edited_mask = resize_and_rescale(plt.imread(os.path.join(dir_name, mask_name)))
        # NOTE(review): the model input is photo - mask here, whereas the training masks zero the
        # hole (create_center_mask); confirm this matches how the test masks were drawn.
        masked_imgs_valid.append(edited_photo - edited_mask)
        masks_valid.append(edited_mask)

    if not masked_imgs_valid:
        raise ValueError("no photo/mask pairs found in %s" % dir_name)

    masked_imgs_valid = np.asarray(masked_imgs_valid)
    masks_valid = np.asarray(masks_valid)

    model = keras.models.load_model(TRAINED_MODEL_NAME)
    gen_pred_valid = np.asarray(model.predict(masked_imgs_valid))
    y_hat_multi_mask = np.multiply(gen_pred_valid, masks_valid)
    filled_imgs_valid = np.add(y_hat_multi_mask, masked_imgs_valid)
    for masked, filled in zip(masked_imgs_valid, filled_imgs_valid):
        _, ax = plt.subplots(1, 2, figsize=(10, 10))
        ax[0].imshow(masked)
        ax[1].imshow(filled)
        plt.show()


**Code for 'Monet' dataset**

In [None]:
# imports:
import os
import sys
import random
import numpy as np
from random import shuffle
from PIL import Image
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Conv2DTranspose, Conv2D, \
    MaxPooling2D, Activation, BatchNormalization, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from tensorflow.keras.layers.experimental.preprocessing import Resizing
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import layers
from tensorflow import keras

# globals (shared by every function in this cell):
IMG_SIZE = 128  # every image is resized to IMG_SIZE x IMG_SIZE
MASK_SIZE = 128  # side length of the discriminator input (see MISS_SHAPE)
RGB_DIM = 3  # number of colour channels
EPOCHS = 300  # number of training epochs (presumably passed to train() from the main cell - confirm)
IMAGE_SHAPE = (IMG_SIZE, IMG_SIZE, RGB_DIM)  # generator input shape
MISS_SHAPE = (MASK_SIZE, MASK_SIZE, RGB_DIM)  # discriminator input shape
LAMBDA_RECON = 0.999  # weight of the reconstruction (mse) loss in the combined model
LAMBDA_ADVR = 0.001  # weight of the adversarial (binary cross-entropy) loss in the combined model
LAMBDA_STYLE = 0.01  # weight of the Gram-matrix style loss (style_loss) in the combined model
OPTIMIZER = Adam(0.0002, 0.5)  # learning rate 0.0002, beta_1 0.5; used by init_contextEncoder
BATCH_SIZE = 64
DISC_OUTPUT_SIZE = 14  # spatial size of the PatchGAN discriminator output for 128x128 inputs
MIN_RANDOM_BLOCKS_AMOUNT = 3  # minimum number of square holes per image ("blocks" masks)
MAX_RANDOM_BLOCKS_AMOUNT = 9  # maximum number of square holes per image ("blocks" masks)
RANDOM_BLOCK_SIZE = 21  # side length in pixels of each random square hole
MIN_VALID_TOTAL_LOSS = sys.maxsize  # lowest validation loss seen so far; "infinity" before the first epoch
arbitrary_masks = []  # arbitrary hole masks (1 = hole), filled by get_arbitrary_mask()
inverse_arbitrary_masks = []  # complements of arbitrary_masks (1 = pixel kept)
valid_total_loss = []  # mean validation total loss, one entry per epoch
train_total_loss = []  # mean training total loss, one entry per epoch
layers_output = dict()  # encoder layer name -> symbolic output; set by init_generator, read by style_loss
feature_extractor = None  # Model exposing layers_output; set by init_generator, used by style_loss


def get_gram_matrix(input_x):
    """Return the channel-by-channel Gram matrix of one feature map.

    ``input_x`` is a rank-3 tensor whose last axis holds the channels (it is transposed with
    perm (2, 0, 1)). The result has shape (channels, channels).
    """
    by_channel = tf.transpose(input_x, (2, 0, 1))
    channel_count = tf.shape(by_channel)[0]
    flattened = tf.reshape(by_channel, (channel_count, -1))
    return tf.matmul(flattened, tf.transpose(flattened))


def get_style_loss(style_features, pred_features):
    """Style loss between two single feature maps: squared Gram-matrix difference.

    The squared difference is summed and normalised by 4 * RGB_DIM^2 * (IMG_SIZE^2)^2.
    """
    gram_difference = get_gram_matrix(style_features) - get_gram_matrix(pred_features)
    pixels = IMG_SIZE * IMG_SIZE
    normaliser = 4.0 * (RGB_DIM ** 2) * (pixels ** 2)
    return tf.reduce_sum(tf.square(gram_difference)) / normaliser


def style_loss(y_true, y_pred):
    """Average Gram-matrix style loss between y_true and y_pred over the encoder layers.

    Both batches go through ``feature_extractor`` in a single call, concatenated along the batch
    axis. The first ``batch`` rows of every feature map therefore belong to y_true and the
    remaining rows to y_pred.

    Each true sample is compared with its own prediction, and the per-sample losses are averaged.
    The former version compared rows 0 and 1, which are two y_true samples whenever the batch
    holds more than one image, so the prediction never entered the loss.

    Returns:
        scalar tensor: mean over layers of the batch-mean style loss.
    """
    batch = tf.shape(y_true)[0]
    combined_y = tf.concat([y_true, y_pred], axis=0)
    features = feature_extractor(combined_y)
    style_loss_val = tf.zeros(shape=())

    # extract the style loss from the generator's encoder layers
    for layer_name in layers_output:
        layer_features = features[layer_name]
        true_features = layer_features[:batch]
        pred_features = layer_features[batch:]
        per_sample_loss = tf.map_fn(
            lambda pair: get_style_loss(pair[0], pair[1]),
            (true_features, pred_features),
            fn_output_signature=tf.float32)
        style_loss_val += (1 / len(layers_output)) * tf.reduce_mean(per_sample_loss)

    return style_loss_val


def init_discriminator_patch_gan():
    """Build the PatchGAN discriminator.

    The network is built in four stages:
    - a MISS_SHAPE image goes through three stride-2 4x4 convolutions (64, 128 and 256 filters);
    - a zero-padded stride-1 512-filter convolution follows;
    - a final zero-padded 1-filter sigmoid convolution produces the scores.

    For 128x128 inputs the output is a 14x14x1 map of per-patch "real" scores
    (DISC_OUTPUT_SIZE).

    Returns:
        keras Model: functional wrapper (Input -> Sequential layer stack).
    """
    net = Sequential()

    # Down-sampling 128 -> 64 -> 32 -> 16; the first block has no batch normalization.
    for position, filters in enumerate((64, 128, 256)):
        first_layer_kwargs = {"input_shape": MISS_SHAPE} if position == 0 else {}
        net.add(Conv2D(filters, kernel_size=4, strides=2, padding="same", **first_layer_kwargs))
        net.add(LeakyReLU(alpha=0.2))
        if position > 0:
            net.add(BatchNormalization(momentum=0.8))

    # 16 -> pad to 18 -> valid 4x4 conv -> 15
    net.add(tf.keras.layers.ZeroPadding2D())
    net.add(Conv2D(512, kernel_size=4, strides=1, padding="valid"))
    net.add(LeakyReLU(alpha=0.2))
    net.add(BatchNormalization(momentum=0.8))

    # 15 -> pad to 17 -> valid 4x4 conv -> 14 sigmoid patch scores
    net.add(tf.keras.layers.ZeroPadding2D())
    net.add(Conv2D(1, kernel_size=4, strides=1, padding="valid", activation='sigmoid'))

    net.summary()

    image_in = Input(shape=MISS_SHAPE)
    return Model(image_in, net(image_in))


def init_generator():
    """Build the generator: encoder -> 4000-channel bottleneck -> decoder.

    Side effect: sets the globals ``layers_output`` (encoder layer name -> symbolic output) and
    ``feature_extractor`` (a Model exposing those outputs), which style_loss uses. They are
    captured right after the encoder is built, so only encoder layers are included.

    Returns:
        keras Model: maps an IMAGE_SHAPE masked image to the generated image.
    """
    global layers_output, feature_extractor
    network = Sequential()
    init_encoder(network)

    layers_output = {layer.name: layer.output for layer in network.layers}
    feature_extractor = Model(inputs=network.inputs, outputs=layers_output)

    # Bottleneck: for 128x128 inputs the 4x4 valid convolution collapses the 4x4 encoder map to 1x1x4000
    for bottleneck_layer in (Conv2D(4000, kernel_size=4, strides=1, padding="valid"),
                             LeakyReLU(alpha=0.2),
                             BatchNormalization(momentum=0.8)):
        network.add(bottleneck_layer)

    init_decoder(network)

    network.summary()

    masked_input = Input(shape=IMAGE_SHAPE)
    return Model(masked_input, network(masked_input))


def init_encoder(model):
    """Append six 4x4 convolution blocks (Conv2D -> LeakyReLU(0.2) [-> BN] [-> Dropout]) to ``model``.

    Five stride-2 blocks take a 128x128 input down to 4x4; the last block keeps the size.

    Args:
        model: empty keras Sequential model extended in place (the first layer declares IMAGE_SHAPE).
    """
    # (filters, strides, batch normalization?, dropout?) for each block, in order
    block_specs = (
        (64, 2, True, False),
        (64, 2, True, False),
        (128, 2, True, False),
        (256, 2, True, True),
        (512, 2, False, True),
        (512, 1, False, True),
    )
    for position, (filters, stride, with_batch_norm, with_dropout) in enumerate(block_specs):
        first_layer_kwargs = {"input_shape": IMAGE_SHAPE} if position == 0 else {}
        model.add(Conv2D(filters, kernel_size=4, strides=stride, padding="same", **first_layer_kwargs))
        model.add(LeakyReLU(alpha=0.2))
        if with_batch_norm:
            model.add(BatchNormalization(momentum=0.8))
        if with_dropout:
            model.add(Dropout(0.5))


def init_decoder(model):
    """Append seven stride-2 transposed convolutions (1 -> 2 -> ... -> 128 for 1x1 input) to ``model``.

    The first up-convolution has no activation. The next five are followed by ReLU and batch
    normalization. The last one maps to RGB_DIM channels with a tanh activation.

    Args:
        model: keras Sequential model extended in place (after the bottleneck).
    """
    model.add(Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"))

    for n_filters in (512, 512, 256, 128, 64):
        model.add(Conv2DTranspose(n_filters, kernel_size=4, strides=2, padding="same"))
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))

    model.add(Conv2DTranspose(RGB_DIM, kernel_size=4, strides=2, padding="same"))
    model.add(Activation('tanh'))


def init_contextEncoder():
    """Build and compile the generator, the discriminator and the combined training model.

    The combined model maps a masked image to [generated image, discriminator score on it]. Its
    total loss is the sum of three weighted terms:
    - LAMBDA_RECON * mse on the generated image;
    - LAMBDA_STYLE * style_loss on the generated image;
    - LAMBDA_ADVR * binary cross-entropy on the discriminator score.

    The per-output loss reported for the image output therefore includes the style term.

    Returns:
        tuple: (generator, discriminator, combined).
    """
    # init and compile the discriminator (it stays trainable when trained on its own)
    discriminator = init_discriminator_patch_gan()
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=OPTIMIZER,
                          metrics=['accuracy'])
    # init the generator
    generator = init_generator()
    masked_img = Input(shape=IMAGE_SHAPE)
    gen_pred = generator(masked_img)

    # Freeze the discriminator inside the combined model only
    discriminator.trainable = False

    real = discriminator(gen_pred)

    combined = Model(masked_img, [gen_pred, real])

    # The combined model has two outputs, so Keras takes exactly two losses. The style term was
    # previously passed as a third list entry, which has no output to attach to. It is now folded
    # into the loss of the image output. Dividing by LAMBDA_RECON keeps its effective weight at
    # LAMBDA_STYLE once Keras applies the LAMBDA_RECON output weight.
    style_weight = LAMBDA_STYLE / LAMBDA_RECON

    def reconstruction_and_style_loss(y_true, y_pred):
        return tf.keras.losses.mean_squared_error(y_true, y_pred) + style_weight * style_loss(y_true, y_pred)

    # Separate optimizer instance (same hyper-parameters) so the discriminator's and the combined
    # model's Adam states are not mixed.
    combined_optimizer = type(OPTIMIZER).from_config(OPTIMIZER.get_config())
    combined.compile(loss=[reconstruction_and_style_loss, 'binary_crossentropy'],
                     loss_weights=[LAMBDA_RECON, LAMBDA_ADVR],
                     optimizer=combined_optimizer)

    return generator, discriminator, combined


def create_center_mask(photo):
    """Cut the central square (IMG_SIZE/4 .. 3*IMG_SIZE/4 on both axes) out of ``photo``.

    Args:
        photo: object with a ``.numpy()`` method (a tf tensor in load_data) holding an
            (H, W, C) image.

    Returns:
        tuple: (masked_img, missing_part, photo, mask) as NumPy arrays of the photo's dtype.
        - masked_img is the photo with the square set to 0.
        - missing_part is 1 everywhere except the square, which keeps the photo's pixels.
        - photo is the input converted to an array.
        - mask is 1 inside the square and 0 elsewhere.
    """
    image = photo.numpy()
    low, high = IMG_SIZE // 4, 3 * IMG_SIZE // 4  # 32 and 96 for 128x128 images
    square = (slice(low, high), slice(low, high), slice(None))

    mask = np.zeros_like(image)
    mask[square] = 1

    missing_part = np.ones_like(image)
    missing_part[square] = image[square]

    masked_img = image.copy()
    masked_img[square] = 0

    return masked_img, missing_part, image, mask


def create_random_blocks_mask(photo):
    """Cut several random RANDOM_BLOCK_SIZE x RANDOM_BLOCK_SIZE squares out of ``photo``.

    The number of squares is drawn with random.randint(MIN_RANDOM_BLOCKS_AMOUNT,
    MAX_RANDOM_BLOCKS_AMOUNT). Squares may overlap and always lie fully inside the
    IMG_SIZE x IMG_SIZE image.

    Args:
        photo: object with a ``.numpy()`` method (a tf tensor in load_data) holding an
            (H, W, C) image.

    Returns:
        tuple: (masked_img, missing_part, photo, mask) as NumPy arrays of the photo's dtype.
        - masked_img is the photo with the squares set to 0.
        - missing_part is 1 everywhere except the squares, which keep the photo's pixels.
        - mask is 1 inside the squares and 0 elsewhere.
    """
    image = photo.numpy()
    masked_img = image.copy()
    missing_part = np.ones_like(image)
    mask = np.zeros_like(image)

    highest_corner = IMG_SIZE - RANDOM_BLOCK_SIZE  # 107 for 128x128 images and 21-pixel blocks
    squares_amount = random.randint(MIN_RANDOM_BLOCKS_AMOUNT, MAX_RANDOM_BLOCKS_AMOUNT)
    for _ in range(squares_amount):
        top = random.randint(0, highest_corner)
        left = random.randint(0, highest_corner)
        region = (slice(top, top + RANDOM_BLOCK_SIZE), slice(left, left + RANDOM_BLOCK_SIZE), slice(None))

        # `image` is never modified, so overlapping squares still copy original pixels
        missing_part[region] = image[region]
        mask[region] = 1
        masked_img[region] = 0

    return masked_img, missing_part, image, mask


def create_arbitrary_mask(photo):
    """Mask ``photo`` with a randomly chosen mask loaded by get_arbitrary_mask.

    Args:
        photo: image with the same shape as the loaded masks (IMG_SIZE x IMG_SIZE x 3). It may be
            a tf tensor (as produced in load_data) or a NumPy array.

    Returns:
        tuple: (masked_img, missing_part, photo, mask) as NumPy arrays of the photo's dtype,
        matching create_center_mask / create_random_blocks_mask.
        - masked_img keeps only the pixels outside the hole.
        - missing_part keeps only the pixels inside the hole (0 elsewhere).
        - mask is 1 inside the hole and 0 elsewhere.

    Raises:
        ValueError: if no arbitrary masks have been loaded yet.
    """
    if len(arbitrary_masks) == 0:
        raise ValueError("no arbitrary masks loaded - call get_arbitrary_mask() first")

    # Convert to NumPy like the sibling mask functions do (they call photo.numpy()), so load_data
    # collects arrays of a single type and dtype instead of a mix of tensors and arrays.
    photo = np.asarray(photo)

    arbi_mask_idx = np.random.randint(0, len(arbitrary_masks))
    arbitrary_mask = arbitrary_masks[arbi_mask_idx]
    inverse_arbitrary_mask = inverse_arbitrary_masks[arbi_mask_idx]

    masked_img = np.multiply(inverse_arbitrary_mask, photo).astype(photo.dtype)
    missing_part = np.multiply(arbitrary_mask, photo).astype(photo.dtype)
    # astype copies, so callers never alias the shared global mask array
    mask = np.asarray(arbitrary_mask).astype(photo.dtype)

    return masked_img, missing_part, photo, mask


def get_arbitrary_mask():
    """Load every mask image in ARBITRARY_MASK_PATH into the arbitrary-mask globals.

    Each file goes through the same steps:
    - it is converted to grayscale and resized to IMG_SIZE x IMG_SIZE;
    - get_inverse_mask binarises it, so non-zero pixels become the hole;
    - the result is replicated over RGB_DIM channels.

    The results replace ``arbitrary_masks`` / ``inverse_arbitrary_masks`` as float64 arrays of
    shape (num_masks, IMG_SIZE, IMG_SIZE, RGB_DIM). Hidden files such as ".DS_Store" are skipped.
    """
    global arbitrary_masks, inverse_arbitrary_masks
    masks_3d = []
    inverse_masks_3d = []
    for mask_file in sorted(os.listdir(ARBITRARY_MASK_PATH)):
        if mask_file.startswith("."):
            continue
        # `with` closes the underlying file handle once the pixels are loaded
        with Image.open(os.path.join(ARBITRARY_MASK_PATH, mask_file)) as image:
            # LANCZOS is the filter formerly exposed as Image.ANTIALIAS, which Pillow 10 removed.
            gray = image.convert("L").resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
        arbi_mask, inverse_arbi_mask = get_inverse_mask(np.asarray(gray))

        masks_3d.append(np.stack([arbi_mask] * RGB_DIM, axis=2).astype(np.float64))
        inverse_masks_3d.append(np.stack([inverse_arbi_mask] * RGB_DIM, axis=2).astype(np.float64))

    # Rebind the globals instead of appending to them, so calling this function twice works (the
    # previous version turned the lists into arrays, which have no append()).
    arbitrary_masks = np.asarray(masks_3d)
    inverse_arbitrary_masks = np.asarray(inverse_masks_3d)



def shuffle_data(train_x, train_y, train_z):
    """Shuffle three index-aligned arrays with one shared random permutation.

    Args:
        train_x, train_y, train_z: arrays indexable by an integer index array (e.g. NumPy
            arrays) with at least len(train_x) rows each.

    Returns:
        tuple: the three arrays reordered by the same permutation, so rows stay aligned.
    """
    order = np.random.permutation(len(train_x))
    return tuple(array[order] for array in (train_x, train_y, train_z))


# The function applies the mask type selected by MASKS_TO_TRAIN to one photo
def _apply_training_mask(edited_photo):
    """Return (masked_img, missing_part, photo, mask) for the configured mask type.

    "all" picks center / blocks / arbitrary uniformly at random per photo;
    "center" and "blocks" use those masks; any other value ("region") uses the
    arbitrary masks.
    """
    mask_type = MASKS_TO_TRAIN
    if mask_type == "all":
        mask_type = ("center", "blocks", "region")[random.randint(0, 2)]
    if mask_type == "center":
        return create_center_mask(edited_photo)
    if mask_type == "blocks":
        return create_random_blocks_mask(edited_photo)
    return create_arbitrary_mask(edited_photo)


# The function loads data
def load_data(photos_path):
    """Load, resize, rescale and mask every image in ``photos_path``.

    The file list is shuffled with random.shuffle. Each image is read with
    plt.imread, resized to IMG_SIZE x IMG_SIZE, rescaled by 1/255 and masked by
    _apply_training_mask(). ".DS_Store" entries are skipped.

    Args:
        photos_path: folder with the images, e.g. "./train_photos".

    Returns:
        (masked_imgs, missing_parts, photos, masks) as numpy arrays stacked along
        axis 0 in the shuffled file order.
    """
    photos_list = os.listdir(photos_path)

    resize_and_rescale = tf.keras.Sequential([
        layers.experimental.preprocessing.Resizing(IMG_SIZE, IMG_SIZE)
        , layers.experimental.preprocessing.Rescaling(1. / 255)
    ])

    masked_imgs = []
    missing_parts = []
    photos = []
    masks = []

    shuffle(photos_list)

    for file_name in photos_list:
        # macOS metadata file, not an image (skipped the same way in test_data).
        if file_name == ".DS_Store":
            continue
        sample_photo = plt.imread(os.path.join(photos_path, file_name))
        edited_photo = resize_and_rescale(sample_photo)
        masked_img, missing_part, photo, mask = _apply_training_mask(edited_photo)
        masked_imgs.append(masked_img)
        missing_parts.append(missing_part)
        photos.append(photo)
        masks.append(mask)

    masked_imgs = np.asarray(masked_imgs)
    missing_parts = np.asarray(missing_parts)
    photos = np.asarray(photos)
    masks = np.asarray(masks)

    return masked_imgs, missing_parts, photos, masks


# The function binarises a mask image and creates its inverse
def get_inverse_mask(arbitrary_mask):
    """Binarise a 2-D mask image and build its complement.

    Args:
        arbitrary_mask: 2-D array-like, e.g. a grayscale mask image.

    Returns:
        (arbitrary_mask_2d, inverse_arbitrary_mask_2d): new arrays with the
        input's shape and dtype. ``arbitrary_mask_2d`` is 1 wherever the input is
        non-zero and 0 elsewhere; ``inverse_arbitrary_mask_2d`` is the opposite.
        The input is not modified.
    """
    mask = np.asarray(arbitrary_mask)
    # Vectorised over the whole array, so every pixel is handled regardless of
    # IMG_SIZE (a per-pixel IMG_SIZE loop misses or overruns other shapes).
    nonzero = mask != 0
    arbitrary_mask_2d = nonzero.astype(mask.dtype)
    inverse_arbitrary_mask_2d = np.logical_not(nonzero).astype(mask.dtype)
    return arbitrary_mask_2d, inverse_arbitrary_mask_2d


# The function plots the generator's predictions on one validation batch
def test(generator, masked_imgs_valid, missing_parts_valid, photos_valid, masks_valid, batch_size, num_examples=7):
    """Predict the first validation batch and plot some examples.

    Only the first ``batch_size`` validation samples are predicted. For each of the
    first ``num_examples`` of them (capped at the number available), six panels are
    shown: photo, masked input, raw prediction, missing part, prediction * mask,
    and the filled image (prediction * mask + masked input).
    """
    masked_imgs_valid = masked_imgs_valid[:batch_size]
    missing_parts_valid = missing_parts_valid[:batch_size]
    photos_valid = photos_valid[:batch_size]
    masks_valid = masks_valid[:batch_size]
    gen_pred_valid = generator.predict(masked_imgs_valid)
    y_hat_multi_mask = np.multiply(gen_pred_valid, masks_valid)
    filled_imgs_valid = np.add(y_hat_multi_mask, masked_imgs_valid)
    panels = (
        (photos_valid, "photo_valid"),
        (masked_imgs_valid, "masked_imgs"),
        (gen_pred_valid, "gen_pred valid"),
        (missing_parts_valid, "missing_part"),
        (y_hat_multi_mask, "y_hat_multi_mask"),
        (filled_imgs_valid, "filled_imgs_valid2"),
    )
    # Cap at the number of samples so small batches/validation sets don't IndexError.
    for i in range(min(num_examples, len(masked_imgs_valid))):
        _, ax = plt.subplots(1, len(panels), figsize=(25, 25))
        for axis, (images, title) in zip(ax, panels):
            axis.imshow(images[i])
            axis.set_title(title)
        plt.show()


# The function computes the combined model's average loss on the validation set
def test_on_batch(combined_model, masked_imgs_valid, photos_valid, valid, epoch):
    """Evaluate the combined model on the validation set in full batches.

    The batch size is ``len(valid)``. ``valid`` is passed as the second target with
    every photo batch, so inputs and targets must share that leading dimension.
    The mean of the first value returned by ``combined_model.test_on_batch`` is
    appended to ``valid_total_loss``. On every epoch divisible by 5 (including 0),
    the loss curves are plotted.

    Raises:
        ValueError: if the validation set holds fewer samples than one batch.
    """
    batch_size = len(valid)
    # All full batches (the former "- 1" silently dropped one of them).
    batches_amount = len(masked_imgs_valid) // batch_size if batch_size else 0
    if batches_amount == 0:
        raise ValueError("validation set has %d samples, fewer than the batch size %d"
                         % (len(masked_imgs_valid), batch_size))
    sum_total_loss_valid = 0
    for i in range(batches_amount):
        start = i * batch_size
        end = start + batch_size
        g_valid_loss = combined_model.test_on_batch(masked_imgs_valid[start:end], [photos_valid[start:end], valid])
        sum_total_loss_valid += g_valid_loss[0]

        print("%d , valid_loss_1: %f , valid_loss_2: %f" % (epoch, g_valid_loss[0], g_valid_loss[1]))

    valid_total_loss.append(sum_total_loss_valid / batches_amount)
    # Plot the validation progress
    if epoch % 5 == 0:
        show_train_and_valid_graph()


# The function plots the recorded train/validation loss history
def show_train_and_valid_graph():
    """Plot the per-epoch total losses stored in valid_total_loss and train_total_loss."""
    curves = (("valid", valid_total_loss), ("train", train_total_loss))
    plt.title('Model Total Loss')
    plt.ylabel('Total loss')
    plt.xlabel('epoch')
    for name, history in curves:
        plt.plot(history, label=name)
    plt.legend()
    plt.show()


# The function trains the model
def train(combined_model, generator, discriminator, masked_imgs_train, missing_parts_train, masked_imgs_valid,
          missing_parts_valid, photos_train, photos_valid, masks_train, masks_valid, epochs, batch_size):
    """Adversarially train the generator / discriminator pair.

    Each epoch the training arrays are reshuffled together with shuffle_data() and
    split into full batches of ``batch_size``. For every batch:
    - the discriminator is trained on the real photos (target 1) and on the
      generator's predictions (target 0);
    - the combined model is trained on the masked images with targets [photo, real].
    After the epoch:
    - the mean combined loss is appended to ``train_total_loss``;
    - validation examples are plotted with test();
    - the validation loss is computed with test_on_batch();
    - save_best_model() is called.

    ``masks_train`` is accepted for interface compatibility but is not used.

    Raises:
        ValueError: if the training set holds fewer than ``batch_size`` samples.
    """
    # All full batches (the former "- 1" dropped one); labels below have a fixed batch dimension.
    batches_amount = len(masked_imgs_train) // batch_size if batch_size > 0 else 0
    if batches_amount == 0:
        raise ValueError("training set has %d samples, fewer than batch_size=%d"
                         % (len(masked_imgs_train), batch_size))
    real = np.ones((batch_size, DISC_OUTPUT_SIZE, DISC_OUTPUT_SIZE, 1))  # e.g. (64, 14, 14, 1)
    fake = np.zeros((batch_size, DISC_OUTPUT_SIZE, DISC_OUTPUT_SIZE, 1))  # e.g. (64, 14, 14, 1)
    for epoch in range(epochs):
        sum_total_loss_train = 0
        masked_imgs_shuffled, _, photos_shuffled = shuffle_data(masked_imgs_train, missing_parts_train, photos_train)
        for i in range(batches_amount):
            # Slice with the same batch_size the label arrays were built with.
            start = i * batch_size
            end = start + batch_size
            masked_imgs = masked_imgs_shuffled[start:end]
            photos_train_batch = photos_shuffled[start:end]

            # the prediction result
            gen_pred = generator.predict(masked_imgs)

            # Train the discriminator
            d_loss_real = discriminator.train_on_batch(photos_train_batch, real)
            d_loss_fake = discriminator.train_on_batch(gen_pred, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train the Generator
            g_loss = combined_model.train_on_batch(masked_imgs, [photos_train_batch, real])

            sum_total_loss_train += g_loss[0]

            print(" Epoch:%d -- Discrimnator loss: %f -- Accuracy: %.2f%% -- Generator loss: %f -- Mse: %f" % (
                epoch, d_loss[0], 100 * d_loss[1], g_loss[0], g_loss[1]))

        train_total_loss.append(sum_total_loss_train / batches_amount)

        test(generator, masked_imgs_valid, missing_parts_valid, photos_valid, masks_valid, batch_size)
        test_on_batch(combined_model, masked_imgs_valid, photos_valid, real, epoch)
        save_best_model(generator, epoch)


# The function loads the saved best generator and plots validation examples
def load_best_model(masked_imgs_valid, photos_valid, masks_valid, num_examples=7):
    """Load ``best_model.h5`` and plot its predictions on the first validation batch.

    Only the first BATCH_SIZE validation samples are predicted. For each of the
    first ``num_examples`` of them (capped at the number available), five panels are
    shown: photo, masked input, raw prediction, prediction * mask, and the filled
    image (prediction * mask + masked input).
    """
    model = keras.models.load_model("best_model.h5")
    masked_imgs_valid = masked_imgs_valid[:BATCH_SIZE]
    photos_valid = photos_valid[:BATCH_SIZE]
    masks_valid = masks_valid[:BATCH_SIZE]
    gen_pred_valid = model.predict(masked_imgs_valid)
    y_hat_multi_mask = np.multiply(gen_pred_valid, masks_valid)
    filled_imgs_valid = np.add(y_hat_multi_mask, masked_imgs_valid)
    panels = (
        (photos_valid, "photo_valid"),
        (masked_imgs_valid, "masked_imgs"),
        (gen_pred_valid, "gen_pred_valid"),
        (y_hat_multi_mask, "y_hat_multi_mask"),
        (filled_imgs_valid, "filled_imgs_valid2"),
    )
    # Cap at the number of samples so a small validation set doesn't IndexError.
    for i in range(min(num_examples, len(masked_imgs_valid))):
        _, ax = plt.subplots(1, len(panels), figsize=(25, 25))
        for axis, (images, title) in zip(ax, panels):
            axis.imshow(images[i])
            axis.set_title(title)
        plt.show()


# The function saves the model whenever the latest validation loss is a new minimum
def save_best_model(model, epoch):
    """Save ``model`` to ``best_model.h5`` if the latest validation loss improved.

    Compares the most recent entry of ``valid_total_loss`` with the global
    ``MIN_VALID_TOTAL_LOSS`` and updates that global on improvement. The latest
    entry is used rather than ``valid_total_loss[epoch]``: the global list persists
    across re-runs of main_train() in the same kernel, while ``epoch`` restarts at 0.

    Args:
        model: object with a Keras-style ``save(path)`` method.
        epoch: current epoch; kept for interface compatibility.
    """
    global MIN_VALID_TOTAL_LOSS
    if not valid_total_loss:
        return
    latest_loss = valid_total_loss[-1]
    if latest_loss < MIN_VALID_TOTAL_LOSS:
        MIN_VALID_TOTAL_LOSS = latest_loss
        print("----------------------------------------- saved model ------------------------------------------")
        model.save("best_model.h5")


def _test_image_pairs(dir_name):
    """Return (photo_path, mask_path) tuples for the test folder ``dir_name``.

    ".DS_Store" is ignored. The remaining names are sorted and paired
    consecutively: (0, 1), (2, 3), ...; a trailing unpaired file is dropped. For
    the blocks / central-block test folders, the first file of each pair is the
    mask and the second is the photo. For every other folder it is the reverse.
    """
    # Filter before pairing so pairs stay aligned whether or not a .DS_Store exists
    # (pairing on the raw sorted list only worked when .DS_Store sorted first).
    names = sorted(name for name in os.listdir(dir_name) if name != ".DS_Store")
    mask_first_dirs = {os.path.normpath(path) for path in
                       ("./test/photos_blocks", "./test/monet_central_block", "./test/monet_blocks")}
    mask_first = os.path.normpath(dir_name) in mask_first_dirs
    pairs = []
    for first, second in zip(names[0::2], names[1::2]):
        photo_name, mask_name = (second, first) if mask_first else (first, second)
        pairs.append((os.path.join(dir_name, photo_name), os.path.join(dir_name, mask_name)))
    return pairs


def test_data():
    """Run the trained model TRAINED_MODEL_NAME on the test pairs in DATA_PATH and plot them.

    Every photo and mask is resized to IMG_SIZE x IMG_SIZE and rescaled by 1/255.
    The model input is photo - mask. The filled image is
    prediction * mask + input. Each example is shown as (input, filled image).

    Raises:
        ValueError: if DATA_PATH contains no complete photo/mask pair.
    """
    resize_and_rescale = tf.keras.Sequential([
        layers.experimental.preprocessing.Resizing(IMG_SIZE, IMG_SIZE)
        , layers.experimental.preprocessing.Rescaling(1. / 255)
    ])

    masked_imgs_valid = []
    masks_valid = []
    for photo_path, mask_path in _test_image_pairs(DATA_PATH):
        edited_photo = resize_and_rescale(plt.imread(photo_path))
        edited_mask = resize_and_rescale(plt.imread(mask_path))
        # NOTE(review): subtracting the mask assumes the hole is marked in the mask
        # (and presumably white in the test photo) — confirm against the test data.
        masked_imgs_valid.append(edited_photo - edited_mask)
        masks_valid.append(edited_mask)

    if not masked_imgs_valid:
        raise ValueError("no photo/mask pairs found in %r" % (DATA_PATH,))

    masked_imgs_valid = np.asarray(masked_imgs_valid)
    masks_valid = np.asarray(masks_valid)

    model = keras.models.load_model(TRAINED_MODEL_NAME)
    gen_pred_valid = np.asarray(model.predict(masked_imgs_valid))
    y_hat_multi_mask = np.multiply(gen_pred_valid, masks_valid)
    filled_imgs_valid = np.add(y_hat_multi_mask, masked_imgs_valid)
    for masked_img, filled_img in zip(masked_imgs_valid, filled_imgs_valid):
        _, ax = plt.subplots(1, 2, figsize=(10, 10))
        ax[0].imshow(masked_img)
        ax[1].imshow(filled_img)
        plt.show()

**Run the desired main function: `main_test()` to evaluate a trained model, or `main_train()` to train a new one**


In [None]:
# Run the desired main - FOR EXAMPLE:
# main_test() for testing trained model
# or
# main_train() for training
def main_train():
    """Train a new model with the settings from the globals cell.

    Loads the images from TRAIN_PATH and VALID_PATH with load_data() and builds
    the models with init_contextEncoder(). Then trains for EPOCHS epochs with
    BATCH_SIZE.

    The arbitrary masks in ARBITRARY_MASK_PATH are read only when MASKS_TO_TRAIN
    can use them. That is every value except "center" and "blocks": load_data
    applies arbitrary masks for "all" and for any other value. This way
    ARBITRARY_MASK_PATH may stay blank for center/blocks training.
    """
    if MASKS_TO_TRAIN not in ("center", "blocks"):
        get_arbitrary_mask()
    masked_imgs_train, missing_parts_train, photos_train, masks_train = load_data(TRAIN_PATH)
    masked_imgs_valid, missing_parts_valid, photos_valid, masks_valid = load_data(VALID_PATH)
    generator, discriminator, combined_model = init_contextEncoder()
    train(combined_model, generator, discriminator, masked_imgs_train, missing_parts_train, masked_imgs_valid,
          missing_parts_valid, photos_train, photos_valid, masks_train, masks_valid, epochs=EPOCHS,
          batch_size=BATCH_SIZE)


def main_test():
    """Evaluate a trained model by running test_data() on DATA_PATH with TRAINED_MODEL_NAME."""
    return test_data()


In [None]:
# Evaluate the trained model (uses DATA_PATH and TRAINED_MODEL_NAME)
main_test()

In [None]:
# Train a new model (uses TRAIN_PATH, VALID_PATH, MASKS_TO_TRAIN and ARBITRARY_MASK_PATH)
main_train()