In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data 
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import random
import cv2
import time

from model import Completion_Network,Context_Discriminators
from train_model import train

In [None]:
# Instantiate the completion network and the context discriminators.
completion_network = Completion_Network()
context_discriminators = Context_Discriminators()

# NOTE(review): these criteria are not referenced by any visible cell; `train`
# is imported from train_model and does not receive them as arguments, so it
# cannot be using these objects. Consider removing them.
mseloss = nn.MSELoss()
bceloss = nn.BCELoss()

# Loss histories. The training cell overwrites them with the values returned by
# `train`; they are pre-initialised (with the same names the training and plot
# cells use) so the plot cells still run if training is skipped.
completion_losses = []
discrimination_losses = []
joint_losses = []
test_loss_C = []

In [None]:
# Resize every face image to 160x160 and convert it to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize(size=(160, 160)),
    transforms.ToTensor(),
])

SEED = 0                # fixes the train/test split and the shuffling order
sep = 190000            # number of images assigned to the train split
batch_size = 32
N_TRAIN_SUBSET = 10000  # work on a subset of each split to keep runs short
N_TEST_SUBSET = 1000

celeba_data = datasets.ImageFolder('./data_faces', transform=transform)
if len(celeba_data) < sep:
    raise ValueError(
        f"dataset has {len(celeba_data)} images, fewer than the train size sep={sep}"
    )

# A seeded generator makes the split identical across kernel restarts.
trainset, testset = torch.utils.data.random_split(
    celeba_data,
    lengths=(sep, len(celeba_data) - sep),
    generator=torch.Generator().manual_seed(SEED),
)

# Subset to work on fewer images; never ask for more indices than exist.
trainset = Subset(trainset, np.arange(min(N_TRAIN_SUBSET, len(trainset))))
testset = Subset(testset, np.arange(min(N_TEST_SUBSET, len(testset))))

# Data loaders: shuffle training batches each epoch (seeded for reproducibility);
# keep the test order fixed so the same test batches are evaluated every time.
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size)

In [None]:
# Train both networks. `train` returns four loss histories, which the plot
# cells display.
# NOTE(review): tc / td presumably set the lengths of the completion-only and
# discriminator-only phases (T_C / T_D in GLCIC) — confirm in train_model.
(completion_losses,
 discrimination_losses,
 joint_losses,
 test_loss_C) = train(
    completion_network=completion_network,
    context_discriminators=context_discriminators,
    train_loader=train_loader,
    test_loader=test_loader,
    n_epoch=3,
    tc=1,
    td=1,
    test_period=1,
)

# Save and reload the completion model

In [None]:
# Save: persist only the completion network's parameters (its state_dict),
# not the whole module object.
model_save_name = 'completion_network.pt'
path = f"{model_save_name}"
torch.save(completion_network.state_dict(), path)

In [None]:
# Load: rebuild a fresh completion network and restore the saved weights.
model_save_name = 'completion_network.pt'
path = model_save_name
completion_network_load = Completion_Network()
# map_location="cpu" lets a checkpoint written from a GPU run be loaded on a
# CPU-only machine; move the model with `.to(device)` afterwards if needed.
# torch.load unpickles the file, so only load checkpoints you produced/trust.
completion_network_load.load_state_dict(torch.load(path, map_location="cpu"))

# Loss plots

In [None]:
# Joint-training loss history returned by `train`.
x = np.arange(len(joint_losses))
fig, ax = plt.subplots()
ax.plot(x, joint_losses)
ax.set(title="Joint loss", xlabel="logged step", ylabel="loss")
plt.show()

In [None]:
# Discriminator loss history returned by `train`.
x = np.arange(len(discrimination_losses))
fig, ax = plt.subplots()
ax.plot(x, discrimination_losses)
ax.set(title="Discriminator loss", xlabel="logged step", ylabel="loss")
plt.show()

In [None]:
# Completion-network (training) loss history returned by `train`.
x = np.arange(len(completion_losses))
fig, ax = plt.subplots()
ax.plot(x, completion_losses)
ax.set(title="Completion loss (train)", xlabel="logged step", ylabel="loss")
plt.show()

In [None]:
# Completion-network loss on the test loader, recorded every `test_period`
# (as passed to `train`).
x = np.arange(len(test_loss_C))
fig, ax = plt.subplots()
ax.plot(x, test_loss_C)
ax.set(title="Completion loss (test)", xlabel="test evaluation", ylabel="loss")
plt.show()

# Postprocessing: fast-marching inpainting followed by Poisson blending

In [None]:
def postprocessing(masked_images, output_completion, mask, radius=3):
    """Refine network completions with classical inpainting and Poisson blending.

    For each sample in the batch:
      1. Fill the holes of the masked input with OpenCV's fast-marching
         inpainting (``cv2.INPAINT_TELEA``), giving a background image ``dst``.
      2. Poisson-blend (``cv2.seamlessClone`` with ``NORMAL_CLONE``) the
         network's completion into ``dst`` over the masked region.

    Args:
        masked_images: batch of masked inputs, (B, C, H, W); values assumed in
            [0, 1]. seamlessClone needs C == 3 (TODO confirm for all callers).
        output_completion: network output for the same batch, same layout.
        mask: (B, 1, H, W) tensor, 1 inside holes and 0 elsewhere.
        radius: neighbourhood radius (pixels) passed to ``cv2.inpaint``.

    Returns:
        Float tensor (B, C, H, W) in [0, 1] with the post-processed images.
    """

    def _to_uint8_hwc(image):
        # CHW float tensor in [0, 1] -> HWC uint8 array. Clip before casting so
        # values slightly outside [0, 1] saturate instead of wrapping around.
        array = image.detach().cpu().numpy().transpose(1, 2, 0)
        return np.clip(array * 255., 0, 255).astype(np.uint8)

    result = []
    for i in range(masked_images.shape[0]):
        # Fast marching method: fill the holes from their surroundings.
        src = _to_uint8_hwc(masked_images[i])
        mask_i = _to_uint8_hwc(mask[i])  # (H, W, 1), 255 inside the holes
        dst = cv2.inpaint(src, mask_i, radius, cv2.INPAINT_TELEA)

        if (mask_i == 255).any():
            # Poisson blending of the network completion into the inpainted image.
            completion = _to_uint8_hwc(output_completion[i])
            center = mask_center(mask_i)
            mask_3c = np.repeat(mask_i, repeats=3, axis=2)
            blend = cv2.seamlessClone(completion, dst, mask_3c, center, cv2.NORMAL_CLONE)
        else:
            # No hole in this sample: nothing to blend (seamlessClone and
            # mask_center both fail on an empty mask).
            blend = dst

        result.append(transforms.ToTensor()(blend).unsqueeze(0))

    if not result:
        # Empty batch: torch.cat([]) would raise.
        return torch.empty((0,) + tuple(masked_images.shape[1:]))
    return torch.cat(result)

def mask_center(mask):
    """Return the centre of the bounding box of the hole pixels in ``mask``.

    Args:
        mask: uint8 array of shape (H, W, C), of which only channel 0 is read,
            or a 2-D (H, W) array. Hole pixels are exactly 255.

    Returns:
        ``(x, y)`` tuple of Python ints in (column, row) order, OpenCV's point
        convention, as passed to ``cv2.seamlessClone``.

    Raises:
        ValueError: if no pixel equals 255.
    """
    plane = mask[..., 0] if mask.ndim == 3 else mask
    rows, cols = np.nonzero(plane == 255)
    if rows.size == 0:
        raise ValueError("mask_center: mask contains no pixel equal to 255")
    # int() keeps the result as plain Python ints, which OpenCV point arguments accept.
    center_x = (int(cols.min()) + int(cols.max())) // 2
    center_y = (int(rows.min()) + int(rows.max())) // 2
    return (center_x, center_y)

In [None]:
# Qualitative check: mask the second test batch, complete it with the network,
# post-process, then show [masked input | post-processed output | ground truth]
# for the first image of that batch.
#
# NOTE(review): `generate_mask`, `mpv` and `imshow_2` are neither defined nor
# imported in the visible cells of this notebook; they presumably live in the
# training code (e.g. train_model). Import them explicitly so that
# Restart & Run All works.
completion_network.eval()  # inference mode (affects dropout / batch-norm layers, if any)
device = next(completion_network.parameters()).device  # follow the model instead of hard-coding CUDA

batch_iter = iter(test_loader)
next(batch_iter, None)  # skip batch 0: the demo uses batch index 1
sample = next(batch_iter, None)

with torch.no_grad():
    if sample is None:
        print("test_loader has fewer than 2 batches; nothing to show")
    else:
        test_images, _ = sample
        mask, _ = generate_mask(
            input_shape=(test_images.shape[0], 1, test_images.shape[2], test_images.shape[3]),
            patch=(48, 48),
            size=([24, 48], [24, 48]),
            n_holes=2,
        )
        # Replace hole pixels with mpv; keep the rest of the image.
        masked_images = test_images - test_images * mask + mpv * mask
        # Network input = masked image with the mask appended as an extra channel.
        net_input = torch.cat((masked_images, mask), dim=1).to(device)
        output = completion_network(net_input).cpu()
        output_final = postprocessing(masked_images, output, mask)
        inputs = torch.cat((masked_images[:1], output_final[:1], test_images[:1]))
        imshow_2(torchvision.utils.make_grid(inputs))