# Global-and-Local-Attention-Based-Free-Form-Image-Inpainting

[SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting](https://github.com/SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting) combined with [mit-han-lab/data-efficient-gans](https://github.com/mit-han-lab/data-efficient-gans).

Currently there are no pre-trained models.

In [None]:
!nvidia-smi

# Training

In [None]:
#@title Install miniconda and dependencies
!git clone https://github.com/SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting
%cd /content/
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-4.5.4-Linux-x86_64.sh
!chmod +x Miniconda3-4.5.4-Linux-x86_64.sh
!bash ./Miniconda3-4.5.4-Linux-x86_64.sh -b -f -p /usr/local
!conda install pytorch==1.1 cudatoolkit torchvision -c pytorch -y
!conda install ipykernel -y
#!pip install tensorboardX
!pip install PyYAML
!pip install OpenCV-python
!pip install scipy==1.1
!pip install tensorboardX==1.4
!pip install tensorboard==1.11.0

In [None]:
#@title config file
%%writefile /content/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting/configs/config.yaml 
# data parameters
dataset_name: Test
data_with_subfolder: False # apparently this needs to be False, otherwise no files will be detected
train_data_path: /content/data
resume: False
checkpoint_dir: /content/checkpoints-training
batch_size: 1
image_shape: [256, 256, 3]
mask_shape: [128, 128]
mask_batch_same: True
max_delta_shape: [32, 32]
margin: [0, 0]
discounted_mask: True
spatial_discounting_gamma: 0.9
random_crop: True
mask_type: hole     # hole | mosaic
mosaic_unit_size: 12
save_image: 500

# training parameters
expname: v1
cuda: True
gpu_ids: [0]  # set the GPU ids to use, e.g. [0] or [1, 2]
num_workers: 4
lr: 0.0001
beta1: 0.5
beta2: 0.9
niter: 1000000
print_iter: 100
viz_iter: 100
viz_max_out: 16
snapshot_save_iter: 500

# loss weight
coarse_l1_alpha: 1.2
l1_loss_alpha: 1.2
ae_loss_alpha: 1.2
global_wgan_loss_alpha: 1.
gan_loss_alpha: 0.001
wgan_gp_lambda: 10

# network parameters
netG:
  input_dim: 5
  ngf: 32

netD:
  input_dim: 3
  ndf: 64

In [None]:
#@title Modifying scripts/trainer.py with differentiable augmentation (experimental)
%%writefile /content/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting/scripts/trainer.py
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F

# Comma-separated list of augmentation names handed to DiffAugment; every
# name must be a key of AUGMENT_FNS.
policy = 'color,translation,cutout'

def DiffAugment(x, policy='', channels_first=True):
    """Run the augmentations selected by ``policy`` over the batch ``x``.

    Args:
        x: Batch tensor. The augmentation functions index dimensions 2 and 3
            as the spatial axes, so channels-last input is permuted first.
        policy: Comma-separated keys of ``AUGMENT_FNS``. An empty string
            disables augmentation and returns ``x`` untouched.
        channels_first: Set to ``False`` when ``x`` is laid out as
            (batch, height, width, channels); the layout is restored before
            returning.

    Returns:
        The augmented tensor (contiguous), or ``x`` itself when ``policy`` is
        empty.

    Raises:
        KeyError: If ``policy`` names an augmentation missing from
            ``AUGMENT_FNS``.
    """
    if not policy:
        return x

    augmented = x if channels_first else x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            augmented = augment(augmented)
    if not channels_first:
        augmented = augmented.permute(0, 2, 3, 1)
    return augmented.contiguous()


def rand_brightness(x):
    """Add one random offset per sample, drawn uniformly from [-0.5, 0.5).

    The offset has shape (batch, 1, 1, 1) and is broadcast over the
    remaining dimensions of ``x``.
    """
    batch = x.size(0)
    offset = torch.rand(batch, 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each sample's deviation from its per-pixel channel mean.

    The scale factor is drawn per sample, uniformly from [0, 2); the mean
    over dimension 1 is left unchanged.
    """
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    deviation = x - channel_mean
    return deviation * factor + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its overall mean.

    The mean is taken over dimensions 1-3 of every sample and the scale
    factor is drawn per sample, uniformly from [0.5, 1.5).
    """
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    deviation = x - sample_mean
    return deviation * factor + sample_mean


def rand_translation(x, ratio=0.125):
    """Shift every sample by a random integer offset, filling gaps with zeros.

    Args:
        x: 4-D batch; dimensions 2 and 3 are the axes being shifted.
        ratio: Maximum shift as a fraction of each shifted axis (rounded to
            the nearest integer). ``ratio=0`` returns a copy equal to ``x``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    max_shift_h = int(height * ratio + 0.5)
    max_shift_w = int(width * ratio + 0.5)

    # One offset per sample and axis; drawn in this order to keep the RNG
    # stream unchanged.
    shift_h = torch.randint(-max_shift_h, max_shift_h + 1, size=[batch, 1, 1], device=x.device)
    shift_w = torch.randint(-max_shift_w, max_shift_w + 1, size=[batch, 1, 1], device=x.device)

    idx_b, idx_h, idx_w = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(height, dtype=torch.long, device=x.device),
        torch.arange(width, dtype=torch.long, device=x.device),
    )
    # The "+ 1" accounts for the one-pixel zero border added below; indices
    # that leave the image are clamped onto that border and thus read zeros.
    idx_h = (idx_h + shift_h + 1).clamp(0, height + 1)
    idx_w = (idx_w + shift_w + 1).clamp(0, width + 1)

    padded = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    channels_last = padded.permute(0, 2, 3, 1).contiguous()
    return channels_last[idx_b, idx_h, idx_w].permute(0, 3, 1, 2)


def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per sample, identically on all channels.

    Args:
        x: 4-D batch; dimensions 2 and 3 are the axes the rectangle spans.
        ratio: Rectangle side length as a fraction of each axis (rounded to
            the nearest integer).

    Returns:
        Tensor with the same shape as ``x``.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    cut_h = int(height * ratio + 0.5)
    cut_w = int(width * ratio + 0.5)

    # Random rectangle centre per sample; drawn in this order to keep the RNG
    # stream unchanged.
    center_h = torch.randint(0, height + (1 - cut_h % 2), size=[batch, 1, 1], device=x.device)
    center_w = torch.randint(0, width + (1 - cut_w % 2), size=[batch, 1, 1], device=x.device)

    idx_b, idx_h, idx_w = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(cut_h, dtype=torch.long, device=x.device),
        torch.arange(cut_w, dtype=torch.long, device=x.device),
    )
    # Rectangles hanging over the edge are clamped, so they shrink rather
    # than wrap around.
    idx_h = (idx_h + center_h - cut_h // 2).clamp(min=0, max=height - 1)
    idx_w = (idx_w + center_w - cut_w // 2).clamp(min=0, max=width - 1)

    keep = torch.ones(batch, height, width, dtype=x.dtype, device=x.device)
    keep[idx_b, idx_h, idx_w] = 0
    return x * keep.unsqueeze(1)


# Registry used by DiffAugment: policy name -> augmentation functions, applied
# in list order.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.network import Generator, GlobalDis
from utils.logger import get_logger
from torch.autograd import Variable
from math import exp

# Module-level logger obtained from the project's get_logger helper.
logger = get_logger()


class Trainer(nn.Module):
    """Bundle the inpainting generator, the global discriminator and their losses.

    ``forward`` runs one generator pass and returns the individual loss terms;
    weighting the terms and stepping the optimizers is left to the caller.

    Args:
        config: Parsed configuration mapping. The keys read here are ``cuda``,
            ``gpu_ids``, ``netG``, ``netD``, ``lr``, ``beta1``, ``beta2`` and
            ``coarse_l1_alpha``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.use_cuda = self.config['cuda']
        self.device_ids = self.config['gpu_ids']

        self.netG = Generator(self.config['netG'], self.use_cuda, self.device_ids)
        self.globalD = GlobalDis(self.config['netD'], self.use_cuda, self.device_ids)

        adam_betas = (self.config['beta1'], self.config['beta2'])
        self.optimizer_g = torch.optim.Adam(self.netG.parameters(), lr=self.config['lr'],
                                            betas=adam_betas)
        d_params = list(self.globalD.parameters())
        self.optimizer_d = torch.optim.Adam(d_params, lr=self.config['lr'],
                                            betas=adam_betas)

        if self.use_cuda:
            # Both networks live on the first configured device.
            self.netG.to(self.device_ids[0])
            self.globalD.to(self.device_ids[0])

        self.ssim = SSIM()

    def forward(self, x, masks, ground_truth):
        """Compute all discriminator and generator loss terms for one batch.

        Args:
            x: Input images handed to the generator together with ``masks``.
            masks: Mask tensor; generator output is used where it is 1 and
                ``x`` is kept where it is 0.
            ground_truth: Target images, same size as the generator output.

        Returns:
            Tuple ``(losses, x1_inpaint, x2_inpaint)``. ``losses`` is a dict
            with the keys ``d_loss_loren``, ``d_loss_rel``, ``l1``,
            ``g_loss_loren`` and ``g_loss_rel``; the two tensors are the
            composited results of the generator's first and second output.
        """
        self.train()
        losses = {}

        x1, x2 = self.netG(x, masks)
        x1_inpaint = x1 * masks + x * (1. - masks)
        x2_inpaint = x2 * masks + x * (1. - masks)

        ## D part
        # The fake batch is detached so this pass does not back-propagate
        # into the generator.
        refine_real, refine_fake = self.dis_forward(self.globalD, ground_truth, x2_inpaint.detach())
        losses['d_loss_loren'] = torch.mean(torch.log(1.0 + torch.abs(refine_real - refine_fake)))
        losses['d_loss_rel'] = (torch.mean(
            F.relu(1.0 - (refine_real - torch.mean(refine_fake)))) + torch.mean(
            F.relu(1.0 + (refine_fake - torch.mean(refine_real))))) / 2

        ## G part
        # L1 terms are evaluated on the region where ``masks`` is 0; only the
        # first generator output is scaled by ``coarse_l1_alpha``.
        l1 = F.l1_loss(x1 * (1. - masks), ground_truth * (1. - masks)) * self.config['coarse_l1_alpha'] \
             + F.l1_loss(x2 * (1. - masks), ground_truth * (1. - masks))
        ssim = ((1. - self.ssim(ground_truth, x1_inpaint)) + (1.0 - self.ssim(ground_truth, x2_inpaint))) / 2.0
        losses['l1'] = l1 * 0.75 + ssim * 0.25

        # Second discriminator pass without detach so gradients reach netG.
        refine_real, refine_fake = self.dis_forward(self.globalD, ground_truth, x2_inpaint)
        losses['g_loss_loren'] = torch.mean(torch.log(1.0 + torch.abs(refine_fake - refine_real)))
        losses['g_loss_rel'] = (torch.mean(
            F.relu(1.0 + (refine_real - torch.mean(refine_fake)))) + torch.mean(
            F.relu(1.0 - (refine_fake - torch.mean(refine_real))))) / 2

        return losses, x1_inpaint, x2_inpaint

    def dis_forward(self, netD, ground_truth, x_inpaint):
        """Score real and inpainted images with ``netD`` in a single pass.

        Both image batches are concatenated, augmented with ``DiffAugment``
        (module-level ``policy``) and then fed to the discriminator.

        Args:
            netD: Discriminator to evaluate.
            ground_truth: Real images.
            x_inpaint: Inpainted images; must have the same size as
                ``ground_truth``.

        Returns:
            Tuple ``(real_pred, fake_pred)`` of discriminator outputs.

        Raises:
            AssertionError: If the two batches differ in size.
        """
        assert ground_truth.size() == x_inpaint.size()
        batch_size = ground_truth.size(0)
        batch_data = torch.cat([ground_truth, x_inpaint], dim=0)
        # DiffAugment has to transform the images the discriminator sees
        # (D(T(x))), not the scores it returns: the augmentations index
        # dimensions 2 and 3 as image axes, which score tensors do not have.
        batch_output = netD(DiffAugment(batch_data, policy=policy))
        real_pred, fake_pred = torch.split(batch_output, batch_size, dim=0)
        return real_pred, fake_pred

    def save_model(self, checkpoint_dir):
        """Write generator, discriminator and generator-optimizer state dicts.

        Files written inside ``checkpoint_dir``: ``gen.pt``,
        ``global_dis.pt`` and ``gen_optimizer.pt``.
        """
        # NOTE(review): the discriminator optimizer state is not saved.
        gen_name = os.path.join(checkpoint_dir, 'gen.pt')
        global_dis_name = os.path.join(checkpoint_dir, 'global_dis.pt')
        gen_opt_name = os.path.join(checkpoint_dir, 'gen_optimizer.pt')
        torch.save(self.netG.state_dict(), gen_name)
        torch.save(self.globalD.state_dict(), global_dis_name)
        torch.save(self.optimizer_g.state_dict(), gen_opt_name)

    def resume(self, checkpoint_dir, iteration=1):
        """Load generator and discriminator weights from ``checkpoint_dir``.

        Args:
            checkpoint_dir: Directory containing ``gen.pt`` and
                ``global_dis.pt``.
            iteration: Value handed back unchanged to the caller.

        Returns:
            ``iteration``.
        """
        # NOTE(review): optimizer state is not restored here.
        g_checkpoint = torch.load(f'{checkpoint_dir}/gen.pt')
        global_dis_checkpoint = torch.load(f'{checkpoint_dir}/global_dis.pt')
        # strict=False tolerates missing/unexpected keys in the generator file.
        self.netG.load_state_dict(g_checkpoint, strict=False)
        self.globalD.load_state_dict(global_dis_checkpoint)
        print("Model loaded")
        return iteration



def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to 1.

    The kernel is centred on index ``window_size // 2`` and uses standard
    deviation ``sigma``.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = torch.Tensor(
        [exp(-(idx - center) ** 2 / two_sigma_sq) for idx in range(window_size)]
    )
    return weights / weights.sum()


def create_window(window_size, channel):
    """Build a Gaussian filter bank of shape (channel, 1, window_size, window_size).

    The 2-D kernel is the outer product of a 1-D Gaussian (sigma 1.5) with
    itself, repeated once per channel.
    """
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
    bank = kernel_4d.expand(channel, 1, window_size, window_size).contiguous()
    return Variable(bank)


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    """Module wrapper around ``_ssim`` that caches its Gaussian window.

    Args:
        window_size: Side length of the Gaussian window.
        size_average: Passed through to ``_ssim``; when true a single mean
            value is returned, otherwise one value per sample.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Start with a single-channel window; forward() rebuilds it on demand.
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM between ``img1`` and ``img2`` (4-D batches)."""
        _, channel, _, _ = img1.size()

        cache_valid = channel == self.channel and self.window.data.type() == img1.data.type()
        if not cache_valid:
            # Channel count or tensor type changed: rebuild the window to
            # match ``img1`` and remember it for subsequent calls.
            rebuilt = create_window(self.window_size, channel)
            if img1.is_cuda:
                rebuilt = rebuilt.cuda(img1.get_device())
            self.window = rebuilt.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM: build a fresh window matching ``img1`` and evaluate.

    Args:
        img1: First 4-D image batch; its channel count, device and type
            determine the window.
        img2: Second image batch.
        window_size: Side length of the Gaussian window.
        size_average: Passed through to ``_ssim``.

    Returns:
        Whatever ``_ssim`` returns for these inputs.
    """
    _, channel, _, _ = img1.size()

    kernel = create_window(window_size, channel)
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)

    return _ssim(img1, img2, kernel, window_size, channel, size_average)

In [None]:
#@title Training
%cd /content/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting
!python train.py --config /content/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting/configs/config.yaml

In [None]:
#@title Show files
%cd /content/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting/checkpoints/Test/hole_v1
!ls

In [None]:
#@title [Warning] Deleting Training Files
%cd /content/
!sudo rm -rf /content/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting/checkpoints

# Testing

In [None]:
#@title Fixing test_single.py
%%writefile /content/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting/test_single.py
def mask_image(x, config):
    """Mask every image of a batch with a freshly generated brush-stroke mask.

    Args:
        x: 4-D image batch.
        config: Configuration mapping; ``config['image_shape']`` provides the
            height and width handed to the mask generator, so it has to match
            the spatial size of ``x``.

    Returns:
        Tuple ``(result, mask)``: ``result`` is ``x`` multiplied by
        ``1 - mask`` per sample and ``mask`` has shape
        (batch, 1, x.shape[2], x.shape[3]). Both live on ``x``'s device.
    """
    # NOTE(review): ``brush_stroke_mask`` is not imported in the visible part
    # of this file - confirm where it is defined and import it.
    height, width, _ = config['image_shape']
    batch = x.shape[0]
    result = torch.ones_like(x)
    # Allocate on x's device; a CPU mask cannot be combined with a CUDA batch.
    mask = torch.ones(size=[batch, 1, x.shape[2], x.shape[3]], device=x.device)
    for i in range(batch):
        mask_temp = brush_stroke_mask().generate_mask(height, width)
        # ``Tensor.cuda()`` returns a copy and is not in-place, so the tensor
        # is created directly on the right device instead.
        mask_temp_tensor = torch.tensor(mask_temp, dtype=torch.float32, device=x.device)
        result[i, :, :, :] = x[i, :, :, :] * (1. - mask_temp_tensor)
        mask[i, :, :, :] = mask[i, :, :, :] * mask_temp_tensor
    return result, mask


import os
import random
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.utils as vutils

from model.network import Generator
from utils.tools import get_config, is_image_file, default_loader, normalize


# Command-line options of the single-image inference script.
parser = ArgumentParser()
parser.add_argument('--config', type=str, default='configs/test.yaml',
                    help="path to the configuration file")
# The default is an int literal so it already has the declared type.
parser.add_argument('--seed', type=int, default=2019, help='manual seed')
parser.add_argument('--image', type=str, default='example/image/image.jpg',
                    help='image to inpaint')
parser.add_argument('--mask', type=str, default='example/mask/mask.png',
                    help="mask image; pass '' to use a random mask instead")
parser.add_argument('--output', type=str, default='output/output.png',
                    help='where the inpainted result is written')
# --flow and --iter are accepted but not read by main().
parser.add_argument('--flow', type=str, default='')
parser.add_argument('--checkpoint_path', type=str, default='',
                    help='directory containing gen.pt; derived from the config when empty')
parser.add_argument('--iter', type=int, default=0)

def main():
    """Inpaint one image with a trained generator and save the result.

    Reads the command-line options from the module-level ``parser``, loads
    the configuration, prepares the masked input (either from ``--mask`` or
    from a random mask when ``--mask ''`` is given), restores ``gen.pt`` and
    writes the composited output to ``--output``.

    Raises:
        TypeError: If ``--image`` or a non-empty ``--mask`` is not an image
            file.
        Exception: Any other error is printed and re-raised unchanged.
    """
    args = parser.parse_args()
    config = get_config(args.config)

    # CUDA configuration
    cuda = config['cuda']
    device_ids = config['gpu_ids']
    if cuda:
        # Restrict visibility to the configured GPUs, then renumber from 0.
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in device_ids)
        device_ids = list(range(len(device_ids)))
        config['gpu_ids'] = device_ids
        cudnn.benchmark = True

    print("Arguments: {}".format(args))

    # Set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random seed: {}".format(args.seed))
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed_all(args.seed)

    print("Configuration: {}".format(config))

    try:  # for unexpected error logging
        with torch.no_grad():   # enter no grad context
            if not is_image_file(args.image):
                raise TypeError("{} is not an image file.".format(args.image))

            image_size = config['image_shape'][:-1]
            if args.mask and is_image_file(args.mask):
                # Test a single masked image with a given mask
                x = default_loader(args.image)
                mask = default_loader(args.mask)
                x = transforms.Resize(image_size)(x)
                mask = transforms.Resize(image_size)(mask)
                x = transforms.ToTensor()(x)
                # Only the first channel of the mask image is used.
                mask = transforms.ToTensor()(mask)[0].unsqueeze(dim=0)
                x = normalize(x)
                x = x * (1. - mask)
                x = x.unsqueeze(dim=0)
                mask = mask.unsqueeze(dim=0)
            elif args.mask:
                raise TypeError("{} is not an image file.".format(args.mask))
            else:
                # Test a single ground-truth image with a random mask
                ground_truth = default_loader(args.image)
                # mask_image generates masks at config['image_shape'], so the
                # image has to be brought to that size first.
                ground_truth = transforms.Resize(image_size)(ground_truth)
                ground_truth = transforms.ToTensor()(ground_truth)
                ground_truth = normalize(ground_truth)
                ground_truth = ground_truth.unsqueeze(dim=0)
                x, mask = mask_image(ground_truth, config)

            # Set checkpoint path
            if not args.checkpoint_path:
                checkpoint_path = os.path.join('checkpoints',
                                               config['dataset_name'],
                                               config['mask_type'] + '_' + config['expname'])
            else:
                checkpoint_path = args.checkpoint_path

            # Define the generator
            netG = Generator(config['netG'], cuda, device_ids)
            # Resume weight; map to CPU when CUDA is disabled so checkpoints
            # saved from a GPU run can still be loaded.
            g_checkpoint = torch.load(f'{checkpoint_path}/gen.pt',
                                      map_location=None if cuda else 'cpu')
            netG.load_state_dict(g_checkpoint, strict=False)
            print("Model resumed from {}".format(checkpoint_path))

            if cuda:
                netG = nn.parallel.DataParallel(netG, device_ids=device_ids)
                x = x.cuda()
                mask = mask.cuda()

            # Inference: keep the known pixels, fill the hole with the second
            # generator output.
            x1, x2 = netG(x, mask)
            inpainted_result = x2 * mask + x * (1. - mask)

            # The default output path lives in a sub-directory that may not
            # exist yet.
            output_dir = os.path.dirname(args.output)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            vutils.save_image(inpainted_result, args.output, padding=0, normalize=True)
            print("Saved the inpainted result to {}".format(args.output))
        # exit no grad context
    except Exception as e:  # for unexpected error logging
        print("Error: {}".format(e))
        raise


# Run inference only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

In [None]:
#@title Test model
%cd /content/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting
!python test_single.py --config /content/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting/configs/config.yaml \
--image /content/0.jpg --mask /content/mask.png --output /content/output.png