# Colab-RN
Original repo: [geekyutao/RN](https://github.com/geekyutao/RN)

Differentiable Augmentation: [mit-han-lab/data-efficient-gans](https://github.com/mit-han-lab/data-efficient-gans)

My fork: [styler00dollar/Colab-RN](https://github.com/styler00dollar/Colab-RN)

In [None]:
!nvidia-smi

In [None]:
#@title git clone and install
%cd /content/
!git clone https://github.com/geekyutao/RN
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-4.5.4-Linux-x86_64.sh
!chmod +x Miniconda3-4.5.4-Linux-x86_64.sh
!bash ./Miniconda3-4.5.4-Linux-x86_64.sh -b -f -p /usr/local
!conda install pytorch==1.1 cudatoolkit torchvision -c pytorch -y
!sudo apt-get install imagemagick imagemagick-doc
!pip install scipy==1.1
!pip install tensorboardX
!pip install scikit-image
!pip install opencv-python
!pip install torchvision

In [None]:
#@title Differentiable Augmentation (experimental)
%%writefile /content/RN/main.py

# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F


def DiffAugment(x, policy='', channels_first=True):
    """Apply the differentiable augmentations named in *policy* to batch *x*.

    Args:
        x: 4-D image batch. NCHW when ``channels_first`` is True, otherwise
            NHWC (it is permuted to NCHW for the augmentations and back).
        policy: comma-separated keys of ``AUGMENT_FNS``; an empty string
            disables augmentation and returns *x* untouched.
        channels_first: layout of *x* as described above.

    Returns:
        The augmented batch (contiguous), or *x* itself when policy is empty.
    """
    if not policy:
        return x
    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()


def rand_brightness(x):
    """Shift every sample of the batch by one random offset in [-0.5, 0.5)."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each sample's deviation from its channel mean by a factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its global mean by a factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean


def rand_translation(x, ratio=0.125):
    """Translate each NCHW sample by a random integer shift, filling with zeros.

    The maximum shift along each spatial axis is ``round(size * ratio)``.
    Pixels shifted in from outside the image come from a one-pixel zero
    border, so they are 0.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    max_dh = int(height * ratio + 0.5)
    max_dw = int(width * ratio + 0.5)
    shift_h = torch.randint(-max_dh, max_dh + 1, size=[batch, 1, 1], device=x.device)
    shift_w = torch.randint(-max_dw, max_dw + 1, size=[batch, 1, 1], device=x.device)
    idx_b, idx_h, idx_w = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(height, dtype=torch.long, device=x.device),
        torch.arange(width, dtype=torch.long, device=x.device),
    )
    # "+ 1" moves into the coordinates of the padded tensor; clamping to the
    # border rows/columns makes out-of-range sources read the zero padding.
    idx_h = torch.clamp(idx_h + shift_h + 1, 0, height + 1)
    idx_w = torch.clamp(idx_w + shift_w + 1, 0, width + 1)
    padded = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    channels_last = padded.permute(0, 2, 3, 1).contiguous()
    return channels_last[idx_b, idx_h, idx_w].permute(0, 3, 1, 2)


def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per NCHW sample (same for all channels).

    The rectangle is ``round(H * ratio) x round(W * ratio)`` and is clipped
    to the image borders, so it may end up smaller near the edges.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    cut_h = int(height * ratio + 0.5)
    cut_w = int(width * ratio + 0.5)
    center_h = torch.randint(0, height + (1 - cut_h % 2), size=[batch, 1, 1], device=x.device)
    center_w = torch.randint(0, width + (1 - cut_w % 2), size=[batch, 1, 1], device=x.device)
    idx_b, idx_h, idx_w = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(cut_h, dtype=torch.long, device=x.device),
        torch.arange(cut_w, dtype=torch.long, device=x.device),
    )
    idx_h = torch.clamp(idx_h + center_h - cut_h // 2, min=0, max=height - 1)
    idx_w = torch.clamp(idx_w + center_w - cut_w // 2, min=0, max=width - 1)
    keep = torch.ones(batch, height, width, dtype=x.dtype, device=x.device)
    keep[idx_b, idx_h, idx_w] = 0
    return x * keep.unsqueeze(1)


# Policy name -> augmentation functions that DiffAugment applies, in order.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)


# DiffAugment policy string passed to DiffAugment by train().
policy = ','.join(['color', 'translation', 'cutout'])

# NOTE: `from __future__ import print_function` used to be the first line
# here. A __future__ import is only legal at the very top of a module, and
# this one came after the DiffAugment code, which made main.py a
# SyntaxError. On Python 3 print is already a function, so it is dropped.
import argparse
from math import log10
import numpy as np

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.utils as vutils

from module_util import *
from dataset import build_dataloader
import pdb
import socket
import time
from skimage import io
from skimage.measure import compare_psnr

from models import InpaintingModel

from tensorboardX import SummaryWriter


def _str2bool(value):
    """Convert a command-line string such as "True"/"false"/"1"/"0" to bool.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is True, so any
    non-empty value would switch the option on.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


# Training settings
parser = argparse.ArgumentParser(description='Region Normalization for Image Inpainting')
parser.add_argument('--bs', type=int, default=14, help='training batch size')
parser.add_argument('--input_size', type=int, default=256, help='input image size')
parser.add_argument('--start_epoch', type=int, default=1, help='Starting epoch for continuing training')
parser.add_argument('--nEpochs', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--snapshots', type=int, default=1, help='checkpoint interval in epochs')
parser.add_argument('--lr', type=float, default=0.0001, help='Learning Rate. Default=0.0001')
parser.add_argument('--gpu_mode', type=_str2bool, default=True, help='train on the GPU (True/False)')
parser.add_argument('--threads', type=int, default=2, help='number of threads for data loader to use')
parser.add_argument('--seed', type=int, default=67454, help='random seed to use. Default=67454')
parser.add_argument('--gpus', default=1, type=int, help='number of gpu')
parser.add_argument('--img_flist', type=str, default='shuffled_train.flist')
parser.add_argument('--mask_flist', type=str, default='all.flist')
parser.add_argument('--model_type', type=str, default='RN')
parser.add_argument('--threshold', type=float, default=0.8)
parser.add_argument('--pretrained_sr', default='../weights/xx.pth', help='pretrained base model')
parser.add_argument('--pretrained', type=_str2bool, default=False, help='load --pretrained_sr before training (True/False)')
parser.add_argument('--save_folder', default='/data/yutao/Project/weights/', help='Location to save checkpoint models')
parser.add_argument('--prefix', default='0p1GAN0p8thre', help='run name; appended to save_folder and used in checkpoint file names')
parser.add_argument('--print_interval', type=int, default=100, help='how many steps to print the results out')
parser.add_argument('--render_interval', type=int, default=10000, help='how many steps between rendering samples (and testing with --with_test)')
parser.add_argument('--l1_weight', type=float, default=1.0)
parser.add_argument('--gan_weight', type=float, default=0.1)
parser.add_argument('--update_weight_interval', type=int, default=5000, help='how many steps to update losses weighing')
parser.add_argument('--with_test', default=False, action='store_true', help='Train with testing?')
parser.add_argument('--test', default=False, action='store_true', help='Test model')
parser.add_argument('--test_mask_flist', type=str, default='mask1k.flist')
parser.add_argument('--test_img_flist', type=str, default='val1k.flist')
parser.add_argument('--tb', default=False, action='store_true', help='Use tensorboardX?')

opt = parser.parse_args()
gpus_list = list(range(opt.gpus))  # the list of gpu
hostname = str(socket.gethostname())
# Checkpoints go to <save_folder><prefix>/ (plain concatenation, no separator).
opt.save_folder += opt.prefix
cudnn.benchmark = True
os.makedirs(opt.save_folder, exist_ok=True)
print(opt)


def train(epoch):
    """Run one training epoch of the RN inpainting GAN.

    Reads module-level globals set up in ``__main__``: ``model``,
    ``training_data_loader``, ``cuda``, ``opt`` and ``policy``. It also reads
    ``writer`` when --tb is set and ``test_data_loader`` when --with_test is set.

    Args:
        epoch: epoch number, used in log lines and render file names.
    """
    iteration, avg_g_loss, avg_d_loss, avg_l1_loss, avg_gan_loss = 0, 0, 0, 0, 0
    # Only consumed by the disabled dynamic loss weighting at the bottom.
    last_l1_loss, last_gan_loss, cur_l1_loss, cur_gan_loss = 0, 0, 0, 0
    model.train()
    t0 = time.time()
    t_io1 = time.time()
    for batch in training_data_loader:
        gt, mask, index = batch
        t_io2 = time.time()
        if cuda:
            gt = gt.cuda()
            mask = mask.cuda()

        prediction = model.generator(gt, mask)
        # Generated pixels where mask == 1, ground truth elsewhere.
        merged_result = prediction * mask + gt * (1 - mask)

        # ----- Discriminator step -----
        # DiffAugment augments the *images* fed to D, with the same policy for
        # real and fake. Augmenting D's outputs would only perturb the logits.
        d_real, _ = model.discriminator(DiffAugment(gt, policy=policy))
        d_fake, _ = model.discriminator(DiffAugment(prediction.detach(), policy=policy))
        d_real_loss = model.adversarial_loss(d_real, True, True)
        d_fake_loss = model.adversarial_loss(d_fake, False, True)
        d_loss = (d_real_loss + d_fake_loss) / 2

        # The previous generator backward also wrote gradients into D's
        # parameters. Clear them so D is updated from d_loss alone.
        model.dis_optimizer.zero_grad()
        d_loss.backward()
        model.dis_optimizer.step()

        # ----- Generator step -----
        # Query D only after its update. Building this graph before
        # dis_optimizer.step() and back-propagating afterwards would use D
        # parameters that step() modified in place; newer PyTorch versions reject that.
        g_fake, _ = model.discriminator(DiffAugment(prediction, policy=policy))
        g_gan_loss = model.adversarial_loss(g_fake, True, False)
        # L1 on the composited image, normalised by the fraction of masked pixels.
        g_l1_loss = model.l1_loss(gt, merged_result) / torch.mean(mask)
        # g_l1_loss = model.l1_loss(gt, prediction) / torch.mean(mask)
        g_loss = model.gan_weight * g_gan_loss + model.l1_weight * g_l1_loss

        model.gen_optimizer.zero_grad()
        g_loss.backward()
        model.gen_optimizer.step()

        # Record
        l1_value = g_l1_loss.item()
        gan_value = g_gan_loss.item()
        cur_l1_loss += l1_value
        cur_gan_loss += gan_value
        avg_l1_loss += l1_value
        avg_gan_loss += gan_value
        avg_g_loss += g_loss.item()
        avg_d_loss += d_loss.item()

        model.global_iter += 1
        iteration += 1
        t1 = time.time()
        td, t0 = t1 - t0, t1

        if iteration % opt.print_interval == 0:
            print("=> Epoch[{}]({}/{}): Avg L1 loss: {:.6f} | G loss: {:.6f} | Avg D loss: {:.6f} || Timer: {:.4f} sec. | IO: {:.4f}".format(
                epoch, iteration, len(training_data_loader), avg_l1_loss/opt.print_interval, avg_g_loss/opt.print_interval, avg_d_loss/opt.print_interval, td, t_io2-t_io1), flush=True)

            if opt.tb:
                writer.add_scalar('scalar/G_loss', avg_g_loss/opt.print_interval, model.global_iter)
                writer.add_scalar('scalar/D_loss', avg_d_loss/opt.print_interval, model.global_iter)
                writer.add_scalar('scalar/G_l1_loss', avg_l1_loss/opt.print_interval, model.global_iter)
                writer.add_scalar('scalar/G_gan_loss', avg_gan_loss/opt.print_interval, model.global_iter)

            avg_g_loss, avg_d_loss, avg_l1_loss, avg_gan_loss = 0, 0, 0, 0
        t_io1 = time.time()

        if iteration % opt.render_interval == 0:
            render(epoch, iteration, mask, merged_result.detach(), gt)
            if opt.with_test:
                print("Testing 1000 images...")
                test_psnr = test(model, test_data_loader)
                # test() puts the model in eval mode; resume training in train mode.
                model.train()
                print("PSNR: ", test_psnr)
                if opt.tb:
                    writer.add_scalar('scalar/test_PSNR', test_psnr, model.global_iter)

        # Disabled dynamic loss re-weighting (see dynamic_weigh):
        # if iteration % opt.update_weight_interval == 0:
        #     if last_l1_loss == 0:
        #         last_l1_loss, last_gan_loss = cur_l1_loss, cur_gan_loss
        #     weights = dynamic_weigh([last_l1_loss, last_gan_loss], [cur_l1_loss, cur_gan_loss], T=1)
        #     model.l1_weight, model.gan_weight = weights[0], weights[1]
        #     print("===> losses weights changing: [l1, gan] = {:.4f}, {:.4f}".format(model.l1_weight, model.gan_weight))
        #     last_l1_loss, last_gan_loss = cur_l1_loss, cur_gan_loss



def dynamic_weigh(last_losses, cur_losses, T=20):
    """Compute loss weights from the ratio of current to previous losses.

    Each weight is ``exp((cur / last) / T)``, renormalised so the weights sum
    to the number of losses.

    Args:
        last_losses: sequence of previous loss values (must be non-zero).
        cur_losses: sequence of current loss values, same length.
        T: softmax temperature; larger values give more uniform weights.

    Returns:
        1-D float tensor of weights. It is on the GPU when CUDA is available,
        otherwise on the CPU (the previous unconditional ``.cuda()`` crashed
        on CPU-only machines).
    """
    last_losses, cur_losses = torch.Tensor(last_losses), torch.Tensor(cur_losses)
    w = torch.exp((cur_losses / last_losses) / T)
    weights = last_losses.size(0) * w / torch.sum(w)
    return weights.cuda() if torch.cuda.is_available() else weights

def render(epoch, iter, mask, output, gt):
    """Save the first sample of a batch as PNGs under ``render/``.

    Writes ``render/<epoch>_<iter>_{input,mask,output,gt}.png``. The
    ``render`` directory is created if it does not exist.

    Args:
        epoch, iter: used only in the file names.
        mask: (bs, 1, H, W) tensor; 1 marks the region filled by the generator.
        output: (bs, 3, H, W) tensor to save as the result (detached).
        gt: (bs, 3, H, W) ground-truth tensor.
    Values are assumed to be in [0, 1]; they are clipped to that range
    before the x255 uint8 conversion so out-of-range values saturate
    instead of overflowing.
    """
    os.makedirs('render', exist_ok=True)
    name_pre = 'render/' + str(epoch) + '_' + str(iter) + '_'

    def _to_uint8(arr):
        # Clip first: astype(np.uint8) does not saturate out-of-range values.
        return (np.clip(arr, 0, 1) * 255).astype(np.uint8)

    # Generator input as displayed: known pixels from gt, masked region white.
    masked_input = gt * (1 - mask) + mask
    io.imsave(name_pre + 'input.png', _to_uint8(masked_input[0].permute(1, 2, 0).cpu().numpy()))
    io.imsave(name_pre + 'mask.png', _to_uint8(mask[0, 0].cpu().numpy()))
    io.imsave(name_pre + 'output.png', _to_uint8(output[0].permute(1, 2, 0).cpu().numpy()))
    io.imsave(name_pre + 'gt.png', _to_uint8(gt[0].permute(1, 2, 0).cpu().numpy()))

def test(gen, dataloader):
    """Return the mean per-image PSNR of ``gen.generator`` over *dataloader*.

    PSNR is computed on the raw generator output (not composited with the
    known pixels) with ``data_range=1``. The module's previous train/eval
    mode is restored before returning, so calling this mid-training does not
    leave the model in eval mode. Uses the global ``cuda`` flag.
    """
    was_training = gen.training
    net = gen.eval()
    psnr = 0
    count = 0
    try:
        for batch in dataloader:
            gt_batch, mask_batch, index = batch
            if cuda:
                gt_batch = gt_batch.cuda()
                mask_batch = mask_batch.cuda()
            with torch.no_grad():
                pred_batch = net.generator(gt_batch, mask_batch)
            for i in range(gt_batch.size(0)):
                gt, pred = gt_batch[i], pred_batch[i]
                psnr += compare_psnr(pred.permute(1, 2, 0).cpu().numpy(),
                                     gt.permute(1, 2, 0).cpu().numpy(),
                                     data_range=1)
                count += 1
    finally:
        gen.train(was_training)
    return psnr / count

def checkpoint(epoch):
    """Save ``model.state_dict()`` for *epoch* into ``opt.save_folder``.

    File name: ``x_<hostname><model_type>_<prefix>_bs_<bs>_epoch_<epoch>.pth``.
    """
    file_name = "x_{}{}_{}_bs_{}_epoch_{}.pth".format(
        hostname, opt.model_type, opt.prefix, opt.bs, epoch)
    model_out_path = opt.save_folder + '/' + file_name
    torch.save(model.state_dict(), model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

if __name__ == '__main__':
    if opt.tb:
        writer = SummaryWriter()

    # Set the GPU mode
    cuda = opt.gpu_mode
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    # Set the random seed
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed_all(opt.seed)

    # Model
    model = InpaintingModel(g_lr=opt.lr, d_lr=(0.1 * opt.lr), l1_weight=opt.l1_weight, gan_weight=opt.gan_weight, iter=0, threshold=opt.threshold)
    print('---------- Networks architecture -------------')
    print("Generator:")
    print_network(model.generator)
    print("Discriminator:")
    print_network(model.discriminator)
    print('----------------------------------------------')
    initialize_weights(model, scale=0.1)

    if cuda:
        model = model.cuda()
        if opt.gpus > 1:
            model.generator = torch.nn.DataParallel(model.generator, device_ids=gpus_list)
            model.discriminator = torch.nn.DataParallel(model.discriminator, device_ids=gpus_list)

    # Load the pretrain model.
    if opt.pretrained:
        model_name = opt.pretrained_sr
        print('pretrained model: %s' % model_name)
        if os.path.exists(model_name):
            pretained_model = torch.load(model_name, map_location=lambda storage, loc: storage)
            model.load_state_dict(pretained_model)
            print('Pre-trained model is loaded.')
            print(' Current: G learning rate:', model.g_lr, ' | L1 loss weight:', model.l1_weight,
                  ' | GAN loss weight:', model.gan_weight)
        else:
            # Do not fall back to training from scratch silently.
            print('WARNING: pre-trained model not found, training from scratch: %s' % model_name)

    # Datasets
    print('===> Loading datasets')
    training_data_loader = build_dataloader(
        flist=opt.img_flist,
        mask_flist=opt.mask_flist,
        augment=True,
        training=True,
        input_size=opt.input_size,
        batch_size=opt.bs,
        num_workers=opt.threads,
        shuffle=True
    )
    print('===> Loaded datasets')

    if opt.test or opt.with_test:
        test_data_loader = build_dataloader(
            flist=opt.test_img_flist,
            mask_flist=opt.test_mask_flist,
            augment=False,
            training=False,
            input_size=opt.input_size,
            batch_size=64,
            num_workers=opt.threads,
            shuffle=False
        )
        print('===> Loaded test datasets')

    if opt.test:
        test_psnr = test(model, test_data_loader)
        print("PSNR: ", test_psnr, flush=True)
        if opt.tb:
            writer.close()
        # os._exit skips normal interpreter cleanup, hence the explicit
        # flush/close above.
        os._exit(0)

    # Start training
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        # Set this epoch's learning rate *before* training it:
        # lr = base * 0.8**(epoch - 1). Done up front, the schedule also holds
        # when resuming with --start_epoch (it used to restart at the base
        # lr), and the first decay is no longer a no-op.
        decay = 0.8 ** (epoch - 1)
        for param_group in model.gen_optimizer.param_groups:
            param_group['lr'] = model.g_lr * decay
            print('===> Current G learning rate: ', param_group['lr'])
        for param_group in model.dis_optimizer.param_groups:
            param_group['lr'] = model.d_lr * decay
            print('===> Current D learning rate: ', param_group['lr'])

        train(epoch)

        # Save every `snapshots` epochs, and always after the last epoch.
        if epoch % opt.snapshots == 0 or epoch == opt.nEpochs:
            checkpoint(epoch)

    if opt.tb:
        writer.close()


Training

In [None]:
#@title training
%cd /content/RN
!python main.py --bs 1 --gpus 1 --prefix rn --img_flist /content/train/train.tflist --mask_flist /content/mask_train/mask_train.tflist

In [None]:
#@title adding image output to eval.py
%%writefile /content/RN/eval.py
from __future__ import print_function
import argparse
from math import log10
import numpy as np
import math

import torchvision
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.utils as vutils

from module_util import initialize_weights
from dataset import build_dataloader
import pdb
import socket
import time
import skimage
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr

from models import InpaintingModel
import cv2

# Evaluation settings
parser = argparse.ArgumentParser(description='Region Normalization for Image Inpainting (evaluation)')
parser.add_argument('--bs', type=int, default=64, help='test batch size')
# --lr / --l1_weight / --gan_weight / --threshold are only used to construct InpaintingModel.
parser.add_argument('--lr', type=float, default=0.0001, help='Learning Rate. Default=0.0001')
parser.add_argument('--cpu', default=False, action='store_true', help='Use CPU to test')
parser.add_argument('--threads', type=int, default=1, help='number of threads for data loader to use')
parser.add_argument('--seed', type=int, default=67454, help='random seed to use. Default=67454')
parser.add_argument('--gpus', default=1, type=int, help='number of gpu')
parser.add_argument('--threshold', type=float, default=0.8)
parser.add_argument('--img_flist', type=str, default='/data/dataset/places2/flist/val.flist')
parser.add_argument('--mask_flist', type=str, default='/data/dataset/places2/flist/3w_all.flist')
parser.add_argument('--model', default='/data/yutao/Project/weights/BGNet/x_admin.cluster.localRN-0.8BGNet_bs_14_epoch_9.pth', help='path to the trained checkpoint (.pth) to evaluate')
parser.add_argument('--save', default=False, action='store_true', help='If save test images')
parser.add_argument('--save_path', type=str, default='./test_results')
parser.add_argument('--input_size', type=int, default=256, help='input image size')
parser.add_argument('--l1_weight', type=float, default=1.0)
parser.add_argument('--gan_weight', type=float, default=0.1)


opt = parser.parse_args()


def eval():
    """Evaluate the loaded model on ``testing_data_loader``.

    For each batch:
    - composite the generator output with the known pixels;
    - save the composite as ``output_<batch>.png`` (grid, 4 per row) in the
      working directory;
    - score it with ``evaluate_batch``;
    - print running averages of PSNR / SSIM / L1 over batches.

    Uses the module-level globals ``model``, ``testing_data_loader``,
    ``cuda`` and ``opt`` set up in ``__main__``.
    """
    model.eval()
    model.generator.eval()
    count = 1          # 1-based batch counter
    num_images = 0     # images evaluated so far
    avg_psnr, avg_ssim, avg_l1 = 0., 0., 0.
    for batch in testing_data_loader:
        gt, mask, index = batch
        if cuda:
            gt = gt.cuda()
            mask = mask.cuda()

        with torch.no_grad():
            prediction = model.generator(gt, mask)
            # Generated pixels inside the mask, ground truth elsewhere.
            prediction = prediction * mask + gt * (1 - mask)

        filename = "output_" + str(count) + ".png"
        torchvision.utils.save_image(prediction, filename, nrow=4)

        # The final batch may hold fewer than opt.bs images. Pass the real size
        # so evaluate_batch does not index past the end of the batch.
        cur_bs = gt.size(0)
        batch_avg_psnr, batch_avg_ssim, batch_avg_l1 = evaluate_batch(
            batch_size=cur_bs,
            gt_batch=gt,
            pred_batch=prediction,
            mask_batch=mask,
            save=opt.save,
            path=opt.save_path,
            count=count,
            index=index
            )
        num_images += cur_bs

        # Incremental mean over batches: avg_k = avg_{k-1} + (x_k - avg_{k-1}) / k
        avg_psnr = avg_psnr + ((batch_avg_psnr - avg_psnr) / count)
        avg_ssim = avg_ssim + ((batch_avg_ssim - avg_ssim) / count)
        avg_l1 = avg_l1 + ((batch_avg_l1 - avg_l1) / count)

        print(
            "Number: %05d" % (num_images),
            " | Average: PSNR: %.4f" % (avg_psnr),
            " SSIM: %.4f" % (avg_ssim),
            " L1: %.4f" % (avg_l1),
            "| Current batch:", count,
            " PSNR: %.4f" % (batch_avg_psnr),
            " SSIM: %.4f" % (batch_avg_ssim),
            " L1: %.4f" % (batch_avg_l1), flush=True
        )

        count += 1




def save_img(path, name, img):
    """Write *img* ((H,W,C) or (H,W) uint8 array) to ``<path>/<name>.png``."""
    # `import skimage` alone does not guarantee the `skimage.io` submodule is
    # loaded, so import it explicitly before use.
    import skimage.io
    skimage.io.imsave(path + '/' + name + '.png', img)

def PSNR(pred, gt, shave_border=0):
    """Return the PSNR of *pred* against *gt*, assuming a 0..255 value range.

    ``shave_border`` is accepted for call compatibility but is not used.
    """
    score = compare_psnr(pred, gt, data_range=255)
    return score

def L1(pred, gt):
    """Mean absolute difference of the channel-averaged images, scaled by 1/255."""
    pred_gray = pred.mean(axis=2)
    gt_gray = gt.mean(axis=2)
    return np.abs((pred_gray - gt_gray) / 255).mean()

def SSIM(pred, gt, data_range=255, win_size=11, multichannel=True):
    """Return skimage's structural similarity between *pred* and *gt*."""
    options = dict(data_range=data_range, multichannel=multichannel, win_size=win_size)
    return compare_ssim(pred, gt, **options)

def evaluate_batch(batch_size, gt_batch, pred_batch, mask_batch, save=False, path=None, count=None, index=None):
    """Score a batch of predictions and optionally save per-image PNGs.

    Args:
        batch_size: number of images to evaluate. It is clamped to the real
            batch length, so a short final batch cannot cause an IndexError.
        gt_batch, pred_batch, mask_batch: NCHW tensors. Values are assumed to
            be in [0, 1]; they are clipped to that range, scaled by 255 and
            cast to uint8.
        save: if True, write input/mask/output/gt PNGs to *path*.
        path: output directory, created (including parents) when *save* is set.
        count: batch number, used as the file-name prefix.
        index: per-image ids from the dataloader, used in file names.

    Returns:
        (mean PSNR, mean SSIM, mean L1) over the evaluated images.
    """
    batch_size = min(batch_size, gt_batch.size(0))
    # Keep known pixels from gt; only the masked region comes from the prediction.
    pred_batch = pred_batch * mask_batch + gt_batch * (1 - mask_batch)

    def _to_uint8(t):
        # NCHW tensor -> NHWC uint8 array. Clip first, because astype(np.uint8)
        # does not saturate out-of-range values.
        return (np.clip(t.detach().permute(0, 2, 3, 1).cpu().numpy(), 0, 1) * 255).astype(np.uint8)

    if save:
        # Network input as displayed: gt with the masked region filled white.
        input_batch = _to_uint8(gt_batch * (1 - mask_batch) + mask_batch)
        mask_batch = _to_uint8(mask_batch)[:, :, :, 0]
        os.makedirs(path, exist_ok=True)

    gt_batch = _to_uint8(gt_batch)
    pred_batch = _to_uint8(pred_batch)

    psnr, ssim, l1 = 0., 0., 0.
    for i in range(batch_size):
        gt, pred, name = gt_batch[i], pred_batch[i], index[i].item()

        psnr += PSNR(pred, gt)
        ssim += SSIM(pred, gt)
        l1 += L1(pred, gt)

        if save:
            prefix = str(count) + '_' + str(name)
            save_img(path, prefix + '_input', input_batch[i])
            save_img(path, prefix + '_mask', mask_batch[i])
            save_img(path, prefix + '_output', pred_batch[i])
            save_img(path, prefix + '_gt', gt_batch[i])

    return psnr / batch_size, ssim / batch_size, l1 / batch_size



def print_network(net):
    """Print *net* followed by its total number of parameters."""
    total = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)


if __name__ == '__main__':
    if opt.cpu:
        print("===== Use CPU to Test! =====")
    else:
        print("===== Use GPU to Test! =====")

    ## Set the GPU mode
    gpus_list = range(opt.gpus)
    cuda = not opt.cpu
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run with --cpu")

    # Model
    model = InpaintingModel(g_lr=opt.lr, d_lr=(0.1 * opt.lr), l1_weight=opt.l1_weight, gan_weight=opt.gan_weight, iter=0, threshold=opt.threshold)
    print('---------- Networks architecture -------------')
    print("Generator:")
    print_network(model.generator)
    print("Discriminator:")
    print_network(model.discriminator)
    print('----------------------------------------------')

    pretained_model = torch.load(opt.model, map_location=lambda storage, loc: storage)

    # A checkpoint may or may not carry the 'module.' key prefix that
    # nn.DataParallel adds (main.py only wraps the networks when --gpus > 1).
    # Strip it and load into the unwrapped model so both kinds load, then
    # wrap for GPU use. Keys missing from the checkpoint keep the model's
    # current values, as the CPU path did before.
    new_state_dict = model.state_dict()
    for k, v in pretained_model.items():
        new_state_dict[k.replace('module.', '')] = v
    model.load_state_dict(new_state_dict)

    if cuda:
        model = model.cuda()
        model.generator = torch.nn.DataParallel(model.generator, device_ids=gpus_list)
        model.discriminator = torch.nn.DataParallel(model.discriminator, device_ids=gpus_list)

    print('Pre-trained G model is loaded.')

    # Datasets
    print('===> Loading datasets')
    testing_data_loader = build_dataloader(
        flist=opt.img_flist,
        mask_flist=opt.mask_flist,
        augment=False,
        training=False,
        input_size=opt.input_size,
        batch_size=opt.bs,
        num_workers=opt.threads,
        shuffle=False
    )
    print('===> Loaded datasets')

    ## Eval Start!!!!
    eval()


Testing


Side note: masks use black instead of white to mark the areas to inpaint.

In [None]:
#@title Image and mask sizes must be divisible by 4; this cell resizes both to fit
import cv2
import numpy
path_inpainting = '/content/val/0.jpg'
path_mask = '/content/mask_val/0.png'

image = cv2.imread(path_mask)
if image is None:
    raise FileNotFoundError(path_mask)
# Largest height/width that are multiples of 4, taken from the mask and
# applied to both files so they stay the same size.
new_h = image.shape[0] // 4 * 4
new_w = image.shape[1] // 4 * 4
# cv2.imread returns BGR, so convert with BGR2GRAY.
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Binarize: only pure white (255) stays white.
ret, image = cv2.threshold(image, 254, 255, cv2.THRESH_BINARY)
# cv2.resize's 3rd positional argument is `dst`, not the interpolation, so
# pass it by keyword. Nearest-neighbour keeps the mask strictly binary.
image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
cv2.imwrite(path_mask, image)

image = cv2.imread(path_inpainting)
if image is None:
    raise FileNotFoundError(path_inpainting)
image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
cv2.imwrite(path_inpainting, image)

!convert /content/mask_val/0.png -channel RGB -negate /content/mask_val/0.png

In [None]:
!python eval.py --bs 1 --model /content/RN/pretrained_model/x_admin.cluster.localRN-0.8RN-Net_bs_14_epoch_3.pth \
--img_flist /content/val/val.tflist --mask_flist /content/mask_val/mask_val.tflist 

In [None]:
#@title removing alpha from image (Example: val/0.jpg)
import cv2
filename = '/content/val/0.jpg'
# IMREAD_COLOR (also cv2.imread's default) loads 3 channels, dropping any alpha.
image = cv2.imread(filename, cv2.IMREAD_COLOR)
if image is None:
    # Fail with a clear message instead of an obscure cv2.error from imwrite.
    raise FileNotFoundError(filename)
cv2.imwrite(filename, image)