# This notebook offers some experiments on the paper:


"Globally and Locally Consistent Image Completion", SATOSHI IIZUKA, EDGAR SIMO-SERRA,
HIROSHI ISHIKAWA

The code used here is taken from https://github.com/akmtn/pytorch-siggraph2017-inpainting

It requires a PyTorch version below 1.0.

In [None]:
%load_ext autoreload
%autoreload 2

from src.models import _NetCompletion, _NetContext, completionnet_places2
from src.ablation import completionnet_ablation, copy_weights
from src.masking import run_draw
from src.inpaint import inpainting, inpainting2, load_network, random_mask
from src.inpaint import load_mask, load_data, post_processing
from src.train import load_dataset, train_random_mask, train_discriminator, get_networks
import torch
import torchvision.transforms as transforms
from torch.nn.modules.loss import BCELoss, MSELoss
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
import cv2
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import os
from moviepy.editor import *
import urllib.request
import numpy as np

# Pre-trained completion network weights (Torch7 format), fetched on first run.
filename = "completionnet_places2.t7"
url = "http://hi.cs.waseda.ac.jp/~iizuka/data/completionnet_places2.t7"

if not os.path.isfile(filename):
    # Download under a temporary name and rename only on success: an
    # interrupted transfer must not leave a truncated file that the
    # existence check above would accept as valid on every later run.
    partial = filename + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, filename)
    finally:
        # After a successful rename this path no longer exists; after a
        # failure it removes the incomplete download.
        if os.path.exists(partial):
            os.remove(partial)

# NOTE(review): load_network() is called without arguments; presumably it
# reads the weights file downloaded above -- confirm in src/inpaint.py.
model, datamean = load_network()

# 1. Simple tests on an image
The mask-drawing code is adapted from https://stackoverflow.com/a/36382158/4986615

Press ESC to close the drawing window.

In [None]:
# Open an interactive window to paint the region to inpaint; the resulting
# mask is written to mask.png.
image_path = "images/bridge.jpg"
img = cv2.imread(image_path)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising, which would otherwise surface as an obscure error later.
    raise FileNotFoundError(f"Could not read image: {image_path}")
run_draw(img, "mask.png")

In [None]:
# Display the mask image stored in mask.png, without axes.
mask = cv2.imread("mask.png")
fig, ax = plt.subplots()
ax.imshow(mask)
ax.axis('off')
ax.set_title("Mask")
plt.show()

In [None]:
# Inpaint one image with the hand-drawn mask and save the result to out.png.
# NOTE(review): the mask was drawn on images/bridge.jpg but is applied to
# images/flower.jpg here -- confirm this is intended.
target_shape = (600, 400)
M = load_mask("mask.png", output_shape=target_shape)
I = load_data("images/flower.jpg", output_shape=target_shape)
out = inpainting(model, datamean, I, M, postproc=False, skip=False)
vutils.save_image(out, "out.png", normalize=True)

In [None]:
# Reload the saved result and display it; OpenCV reads images as BGR, so the
# channel axis is reversed to get the RGB order matplotlib expects.
out_bgr = cv2.imread("out.png")
out_im = out_bgr[..., ::-1]
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(out_im)
ax.axis('off')
ax.set_title("Output")
plt.show()

# 2. Computing on a short video

In [None]:
from IPython.display import display

# Download the source video from YouTube on first run only.
if not os.path.exists("wwf_forest.mp4"):
    status = os.system("youtube-dl gpzuVt_mkKs -o wwf_forest.mp4")
    if status != 0:
        # Fail here with a clear message instead of letting VideoFileClip
        # fail later on a missing file.
        raise RuntimeError(
            f"youtube-dl exited with status {status}; could not fetch wwf_forest.mp4"
        )

# Keep a 0.2 s excerpt (from 6.0 s to 6.2 s) and resize it to w x h.
clip = VideoFileClip("wwf_forest.mp4").subclip((0,6.0),(0,6.2))
w, h = 600, 400
clip = clip.resize( (w, h) )
# ipython_display only returns a displayable object; since it is not the last
# expression of the cell it must be passed to display() to actually render.
display(clip.ipython_display(fps=20, loop=True, autoplay=True))

# Binary mask with one random rectangular hole (sides drawn from [60, 100))
# placed fully inside the frame. The RNG is not seeded, so the hole differs
# between runs.
M = torch.FloatTensor(1, h, w).fill_(0.)
mask_w, mask_h = np.random.randint(60,100, 2)
px = np.random.randint(0, w-mask_w)
py = np.random.randint(0, h-mask_h)
M[:, py:py+mask_h, px:px+mask_w] = 1.


def inpainting_video(clip, model, datamean, M):
    """Return a copy of ``clip`` in which every frame has been inpainted.

    Parameters
    ----------
    clip : object exposing ``fl(func)`` (a moviepy clip in this notebook).
    model, datamean : forwarded unchanged to ``inpainting``.
    M : mask forwarded unchanged to ``inpainting`` for every frame
        (assumed to match the frame size -- TODO confirm against
        ``src.inpaint.inpainting``).

    Returns
    -------
    Whatever ``clip.fl`` returns for the per-frame filter defined below.
    """

    def fl(gf, t):
        # Frame at time t, laid out height x width x channels.
        frame = gf(t)
        # Channels first and scaled by 1/255 before entering the network.
        chw = frame.transpose((2,0,1)).astype(np.float64)
        I = torch.from_numpy(chw/255.).float()
        out = inpainting(model, datamean, I, M, postproc=False, skip=False).data.numpy()
        # Clip before converting: network outputs slightly outside [0, 1]
        # would otherwise wrap around (e.g. 260 -> 4) when the frame is cast
        # to 8-bit for encoding. Return uint8 frames directly.
        out = np.clip(out*255., 0., 255.).transpose((1,2,0))
        return out.astype(np.uint8)

    return clip.fl(fl)

# Apply the per-frame inpainting filter to the clip and encode the result.
clip_inpainted = inpainting_video(clip, model, datamean, M)
# Inline preview alternative:
# clip_inpainted.ipython_display(fps=20, loop=True, autoplay=True)
clip_inpainted.write_videofile("inpainted_forest.mp4", bitrate="3000k")

In [None]:
# Save the frame at t = 26.0 s of wwf_forest.mp4 as wood.jpg. The video file
# must already exist on disk (it is downloaded by the video cell).
VideoFileClip("wwf_forest.mp4").save_frame("wood.jpg", t=26.0)

# 3. Computation of the loss function on a set of random masks
Here we want to define a metric for the quality of the reconstruction.
We use a weighted sum of an MSE term and a binary cross-entropy term, as the reference paper does when training the generator.

In [None]:
# Weight of the MSE term; the BCE term gets the complement (1 - wtl2).
wtl2 = 0.5
bce_loss = BCELoss()
mse_loss = MSELoss()


def _weighted_error(prediction, target):
    """Return wtl2 * MSE + (1 - wtl2) * BCE between prediction and target."""
    return wtl2*mse_loss(prediction, target) + (1 - wtl2)*bce_loss(prediction, target)


M = random_mask(output_shape=(600, 400))
I = load_data("images/bridge.jpg", output_shape=(600, 400))

# Error of the raw network output against the original image.
out = inpainting(model, datamean, I, M, postproc=False)
out2 = out.float()
error = _weighted_error(out2, I)
print("Normal:", error)

# Same metric after applying post_processing to the raw output.
out_proc = post_processing(I, M, out)
out_proc2 = out_proc.float()
error = _weighted_error(out_proc2, I)
print("Post-processing:", error)

# Influence of the hole size

In [None]:
# Side lengths (in pixels) of the square holes to test.
size = [20,30,50,80,130,210]
N = len(size)
w, h = 600, 400
# One all-zero mask per hole size; the hole is filled in inside the loop.
masks = torch.FloatTensor(N, h, w).fill_(0.)
# Random hole centre, shared by every size; the bounds keep the largest hole
# (half-width 105) inside the w x h frame.
px = np.random.randint(110, 490)
py = np.random.randint(110, 290)
image = "forest"
I = load_data(f"images/{image}.jpg", output_shape=(w, h))

# Output directory for the figures of this image.
res_dir = os.path.join("results", "hole", image)
os.makedirs(res_dir, exist_ok=True)

for i, hole_size in enumerate(size):
    half = hole_size // 2
    masks[i, py-half:py+half, px-half:px+half] = 1.

    # Two runs per hole: the default call, then the same call with skip=True
    # (saved under the "sanity" name).
    runs = (
        ({}, f'hole-{i}.png'),
        ({'skip': True}, f'hole-sanity-{i}.png'),
    )
    for extra_kwargs, png_name in runs:
        out = inpainting(model, datamean, I, masks[i:i+1], postproc=False, **extra_kwargs)
        # Channels-first tensor -> channels-last array for matplotlib.
        out = out.data.numpy().transpose((1,2,0))
        plt.figure(figsize=(20,10))
        plt.imshow(out)
        plt.axis('off')
        plt.savefig(os.path.join(res_dir, png_name), bbox_inches='tight')


# Influence of local context

# Neural networks

The local and global discriminators were not open-sourced, so they are implemented here in `src/models.py`.

In [None]:
# Print a layer-by-layer summary of the completion network for a
# 4-channel 512 x 512 input.
completion_input_size = (4, 512, 512)
completion = _NetCompletion()
summary(completion, input_size=completion_input_size)

In [None]:
# Print a layer-by-layer summary of the context network, which takes two
# inputs: a 3 x 128 x 128 one and a 3 x 256 x 256 one.
# (_NetContext is already imported in the notebook's first cell, so the
# redundant mid-notebook re-import was dropped.)
context = _NetContext()
summary(context, [(3, 128, 128), (3, 256, 256)])

# 4. Training

## 4.1. Visualize mask and patch
(Violet = global, green = local, yellow = hole)

In [None]:
# Draw a batch of 2 random masks with their patches and display the first one.
# (train_random_mask is already imported in the notebook's first cell, so the
# redundant mid-notebook re-import was dropped.)
mask, patch = train_random_mask(2)
m, p = mask[0,0], patch[0]
# Add 1 inside the patch rectangle so the patch and the hole get distinct
# values in the plot. Note: m is a view, so this also modifies `mask`.
m[p[0]:p[2], p[1]:p[3]] += 1
plt.imshow(m.numpy())
plt.axis('off')
plt.show()

## 4.2. Training the discriminator

In [None]:
# Build the networks on CPU and keep only the second returned one (the first
# is discarded).
_, context = get_networks(cuda=False)
# Reload the completion model and its data mean; this rebinds the notebook-wide
# `model` and `datamean`.
model, datamean = load_network()
# CIFAR-10 batches of 2 images, read from dataset/cifar10.
dataloader = load_dataset(dataset="cifar10", dataroot="dataset/cifar10", batch_size=2)
train_discriminator(model, context, dataloader)

# 5. Removing neurons in the pre-trained model
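A rough ablation study: dropout layers are kept active at evaluation time so that units of the pre-trained model are randomly disabled.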

In [None]:
# Dropout value passed to the ablated network.
dropout = 0.1

# A: the pre-trained completion network, restored from a saved state dict.
# NOTE(review): completionnet_places2.pth is not created anywhere in this
# notebook -- confirm where it comes from. torch.load unpickles the file, so
# only load checkpoints from a trusted source.
A = completionnet_places2
A.load_state_dict(torch.load("completionnet_places2.pth"))

# B: the ablation variant, initialised with A's weights.
B = completionnet_ablation(dropout)
copy_weights(A, B)

# Put the whole network in eval mode, then switch only the dropout layers
# back to training mode so they stay active at inference time.
B.eval()
dropout_layers = [
    layer for layer in B.modules()
    if type(layer).__name__.startswith('Dropout')
]
for layer in dropout_layers:
    layer.train()