# Partial Convolutions for Image Inpainting using PyTorch

This is a PyTorch implementation of "*Image Inpainting for Irregular Holes Using Partial Convolutions*", https://arxiv.org/abs/1804.07723 by Guilin Liu, Fitsum A. Reda, Kevin J. Shih, Ting-Chun Wang, Andrew Tao and Bryan Catanzaro from NVIDIA. 

## Imports

In [None]:
!pip install oyaml

In [None]:
from torchvision import transforms, utils
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn.functional as F
from torchvision import models
from random import randint
import torch.nn as nn
from PIL import Image
from glob import glob
import oyaml as yaml
import numpy as np
import datetime
import random
import torch
import cv2
import os

## Generating random masks

The paper creates random irregular masks by estimating occlusion/dis-occlusion between two consecutive video frames. We instead use a simple mask-generator class that draws random irregular shapes (lines, circles and ellipses) with OpenCV and uses them as masks.

In [None]:
class MaskGenerator(object):
    """Produce random irregular masks, generated with OpenCV or loaded from disk."""

    def __init__(self, height, width, channels=3,
                 filepath=None):
        """Function for generating masks
        Arguments:
            height {int} -- Mask height
            width {int} -- Mask width
        Keyword Arguments:
            channels {int} -- Channels to output (default: {3})
            filepath {str} -- Directory to load masks from. If None, or if the
                directory holds no .jpeg/.png/.jpg files, masks are generated
                with OpenCV instead (default: {None})
        """

        self.height = height
        self.width = width
        self.channels = channels
        self.filepath = filepath

        # If filepath supplied, load the list of masks within the directory
        self.mask_files = []
        if self.filepath:
            filenames = [f for f in os.listdir(self.filepath)]
            self.mask_files = [f for f in filenames
                               if any(filetype in f.lower()
                                      for filetype
                                      in ['.jpeg', '.png', '.jpg'])]
            print("Found {} masks in {}".format(len(self.mask_files),
                                                   self.filepath))

    def _generate_mask(self):
        """Draw random lines, circles and ellipses and return the inverted canvas.

        Returns:
            A (height, width, channels) uint8 array that is 0 on the drawn
            shapes and 1 elsewhere.

        Raises:
            Exception: if width or height is below 64.
        """
        # Validate before allocating or drawing anything.
        if self.width < 64 or self.height < 64:
            raise Exception("Width and Height of mask must be at least 64!")

        img = np.zeros((self.height, self.width, self.channels), np.uint8)

        # Shape size scale: ~3% of (width + height).
        size = int((self.width + self.height) * 0.03)

        # Draw random lines
        for _ in range(randint(1, 20)):
            x1, x2 = randint(1, self.width), randint(1, self.width)
            y1, y2 = randint(1, self.height), randint(1, self.height)
            thickness = randint(3, size)
            cv2.line(img, (x1, y1), (x2, y2), (1, 1, 1), thickness)

        # Draw random filled circles
        for _ in range(randint(1, 20)):
            x1, y1 = randint(1, self.width), randint(1, self.height)
            radius = randint(3, size)
            cv2.circle(img, (x1, y1), radius, (1, 1, 1), -1)

        # Draw random ellipse arcs
        for _ in range(randint(1, 20)):
            x1, y1 = randint(1, self.width), randint(1, self.height)
            s1, s2 = randint(1, self.width), randint(1, self.height)
            a1, a2, a3 = randint(3, 180), randint(3, 180), randint(3, 180)
            thickness = randint(3, size)
            cv2.ellipse(img, (x1, y1), (s1, s2), a1, a2, a3,
                        (1, 1, 1), thickness)

        # Shapes were drawn with value 1; invert so shapes become 0.
        return 1 - img

    def _load_mask(self, rotation=True, dilation=True, cropping=True):
        """Load a random mask file from ``filepath`` and randomly augment it.

        Returns:
            uint8 array that is 1 where the processed image is > 1 and 0
            elsewhere (cropped to height x width when ``cropping`` is True).

        Raises:
            IOError: if the chosen file cannot be read.
            ValueError: if cropping is requested but the mask is smaller than
                the target size.
        """
        # Read image
        mask_file = np.random.choice(self.mask_files, 1, replace=False)[0]
        mask_path = os.path.join(self.filepath, mask_file)
        mask = cv2.imread(mask_path)
        if mask is None:
            raise IOError("Could not read mask file {}".format(mask_path))

        # Random rotation (the 1.5 scale factor also zooms in)
        if rotation:
            rand = np.random.randint(-180, 180)
            M = cv2.getRotationMatrix2D((mask.shape[1]/2, mask.shape[0]/2),
                                        rand, 1.5)
            mask = cv2.warpAffine(mask, M, (mask.shape[1], mask.shape[0]))

        # Random "dilation": implemented with cv2.erode, which grows the dark
        # regions of the image.
        if dilation:
            rand = np.random.randint(5, 47)
            kernel = np.ones((rand, rand), np.uint8)
            mask = cv2.erode(mask, kernel, iterations=1)

        # Random cropping
        if cropping:
            max_x = mask.shape[1] - self.width
            max_y = mask.shape[0] - self.height
            if max_x < 0 or max_y < 0:
                raise ValueError(
                    "Mask {} ({}x{}) is smaller than the requested {}x{}".format(
                        mask_path, mask.shape[1], mask.shape[0],
                        self.width, self.height))
            # randint's upper bound is exclusive; +1 lets a mask that is
            # exactly the target size be cropped at offset 0.
            x = np.random.randint(0, max_x + 1)
            y = np.random.randint(0, max_y + 1)
            mask = mask[y:y+self.height, x:x+self.width]

        return (mask > 1).astype(np.uint8)

    def sample(self):
        """Return one mask.

        Loads a file from ``filepath`` when mask files were found there;
        otherwise the mask is generated with OpenCV.
        """
        if self.filepath and self.mask_files:
            return self._load_mask()
        return self._generate_mask()

def main():
    # Dump NUM_MASK randomly generated 256x256 masks as PNG files into
    # DIR_NAME; drawn shapes are written as 0 and the background as 255.
    NUM_MASK = 19000
    DIR_NAME = 'val_mask'

    if not os.path.exists(DIR_NAME):
        os.mkdir(DIR_NAME)

    generator = MaskGenerator(256, 256, channels=3, filepath=None)
    for idx in range(NUM_MASK):
        cv2.imwrite('{}/{}.png'.format(DIR_NAME, idx), generator.sample() * 255)


if __name__ == '__main__':
    main()

## Loading the train and validation datasets along with the masks

In [None]:
class InitDataset(Dataset):
    """Inpainting dataset whose items are ``(img * mask, mask, img)``.

    For ``data='train'`` images are ``*.jpg`` files found recursively under
    ``<data_root>/train`` and masks are ``<data_root>/mask/*.png``; for any
    other value images are ``<data_root>/val/*.jpg`` and masks are
    ``<data_root>/val_mask/*.png``. Each item pairs its image with a mask
    chosen uniformly at random.
    """

    def __init__(self, data_root, img_transform, mask_transform, data='train'):
        super(InitDataset, self).__init__()
        self.img_transform = img_transform
        self.mask_transform = mask_transform

        if data == 'train':
            self.paths = glob('{}/train/**/*.jpg'.format(data_root),
                              recursive=True)
            self.mask_paths = glob('{}/mask/*.png'.format(data_root))
        else:
            self.paths = glob('{}/val/*.jpg'.format(data_root))
            self.mask_paths = glob('{}/val_mask/*.png'.format(data_root))

        self.N_mask = len(self.mask_paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(img * mask, mask, img)`` after both transforms are applied.

        Raises:
            ValueError: if no mask files were found.
        """
        if self.N_mask == 0:
            raise ValueError("No mask files were found for this dataset split")
        img = self._load_img(self.paths[index])
        img = self.img_transform(img.convert('RGB'))
        mask = Image.open(self.mask_paths[random.randint(0, self.N_mask - 1)])
        mask = self.mask_transform(mask.convert('RGB'))
        return img * mask, mask, img

    def _load_img(self, path):
        """Open the image at ``path``, falling back to sibling files on failure.

        Some images fail to load (e.g. the file has no data). In that case the
        last character of the file stem is replaced by each digit 0-9 in turn,
        and the first of those files that loads is returned.

        Raises:
            OSError: if neither ``path`` nor any fallback could be loaded.
        """
        try:
            return self._open_image(path)
        except Exception as first_error:  # best-effort fallback below
            # splitext (not str.split('.')) so dots in directory names such
            # as './data' do not truncate the path.
            stem, extension = os.path.splitext(path)
            for digit in range(10):
                try:
                    return self._open_image(stem[:-1] + str(digit) + extension)
                except Exception:
                    continue
            raise OSError("Could not load {} or any fallback image".format(
                path)) from first_error

    @staticmethod
    def _open_image(path):
        # Image.open is lazy; load() forces decoding so that unreadable files
        # fail here, where _load_img can fall back, instead of in convert().
        img = Image.open(path)
        img.load()
        return img


## Defining the model

## Partial Convolution Layer
The key element here is, of course, the partial convolutional layer. Given the convolutional filter **W** and the corresponding bias *b*, the following partial convolution is applied instead of a normal convolution:

<img src='https://raw.githubusercontent.com/MathiasGruber/PConv-Keras/master/data/images/eq1.PNG' />

where ⊙ is element-wise multiplication and **M** is a binary mask of 0s and 1s. Importantly, the mask is also updated after each partial convolution: if the convolution was able to condition its output on at least one valid input, the mask is removed (set to 1) at that location, i.e.

<img src='https://raw.githubusercontent.com/MathiasGruber/PConv-Keras/master/data/images/eq2.PNG' />

As a result, with a sufficiently deep network the mask eventually becomes all ones (i.e. it disappears).

## UNet Architecture
The architecture is based on a UNet-like structure in which all normal convolutional layers are replaced with partial convolutional layers, so that the image is always passed through the network alongside its mask.

<img src='https://raw.githubusercontent.com/MathiasGruber/PConv-Keras/master/data/images/architecture.png' />

In [None]:
class PartialConvolution(nn.Conv2d):
    """Partial convolution layer (Liu et al., 2018).

    The convolution sees only valid pixels (``img * mask``). Each output is
    rescaled by ``sum(1) / sum(M)`` over its window, and windows without any
    valid input are zeroed. ``forward`` returns the output together with the
    updated mask.

    Note: ``forward`` calls ``F.conv2d`` directly with zero padding, so a
    ``padding_mode`` other than ``'zeros'`` is accepted by the constructor
    but not applied.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros'):
        super(PartialConvolution, self).__init__(in_channels, out_channels,
                                                 kernel_size, stride=stride,
                                                 padding=padding,
                                                 dilation=dilation,
                                                 groups=groups, bias=bias,
                                                 padding_mode=padding_mode)
        # All-ones kernel that counts the valid mask pixels in each window
        # (summed over all input channels).
        self.mask_kernel = torch.ones(self.out_channels, self.in_channels,
                                      self.kernel_size[0], self.kernel_size[1])
        # sum(1): number of inputs in one window, used for renormalization.
        self.sum1 = self.mask_kernel.shape[1] * self.mask_kernel.shape[2] \
            * self.mask_kernel.shape[3]
        # Updated mask and mask ratio (sum(1) / sum(M)) from the latest forward().
        self.update_mask = None
        self.mask_ratio = None
        # Initialize the weights for image convolution
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, img, mask):
        """Apply the partial convolution.

        Arguments:
            img: input batch.
            mask: mask with ``in_channels`` channels (the mask kernel is built
                for that many); 1 marks valid pixels and 0 marks holes.

        Returns:
            ``(output, update_mask)``. For a 0/1 mask, ``update_mask`` is 1 at
            every output position whose window contained at least one valid
            pixel and 0 elsewhere.
        """
        with torch.no_grad():
            # mask_kernel is a plain attribute (not a buffer), so Module.to()
            # does not move it. Follow the input lazily, comparing device and
            # dtype explicitly because Tensor.type() strings omit the CUDA
            # device index.
            if (self.mask_kernel.device != img.device
                    or self.mask_kernel.dtype != img.dtype):
                self.mask_kernel = self.mask_kernel.to(img)
            # sum(M): number of valid pixels under each window.
            self.update_mask = F.conv2d(mask, self.mask_kernel,
                                        bias=None, stride=self.stride,
                                        padding=self.padding,
                                        dilation=self.dilation,
                                        groups=1)
            # sum(1) / sum(M); the epsilon avoids dividing by zero for
            # windows with no valid pixels.
            self.mask_ratio = self.sum1 / (self.update_mask + 1e-8)
            self.update_mask = torch.clamp(self.update_mask, 0, 1)
            # Zero the ratio where no valid input was seen.
            self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        # W^T . (X * M) (+ b)
        conved = F.conv2d(torch.mul(img, mask), self.weight, self.bias,
                          self.stride, self.padding, self.dilation, self.groups)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            # Rescale only the convolution term, then add the bias back.
            output = torch.mul(conved - bias_view, self.mask_ratio) + bias_view
            # Zero windows with no valid input using the binary update_mask.
            # mask_ratio was already applied above; multiplying by it again
            # would rescale the output a second time.
            output = torch.mul(output, self.update_mask)
        else:
            # mask_ratio is already zero wherever update_mask is zero.
            output = torch.mul(conved, self.mask_ratio)

        return output, self.update_mask


class UpsampleData(nn.Module):
    """Decoder skip-connection helper: upsample, then concatenate.

    Doubles the spatial size of the decoder feature map and of its mask with
    nearest-neighbour interpolation, then concatenates each one with the
    matching encoder tensor along dim 1 (channels).
    """

    def __init__(self):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, dec_feature, enc_feature, dec_mask, enc_mask):
        up_feature = self.upsample(dec_feature)
        up_mask = self.upsample(dec_mask)
        merged_feature = torch.cat((up_feature, enc_feature), dim=1)
        merged_mask = torch.cat((up_mask, enc_mask), dim=1)
        return merged_feature, merged_mask


class PConvLayer(nn.Module):
    """One U-Net stage: [upsample + concat] -> partial conv -> [BN] -> [activation].

    Arguments:
        in_ch, out_ch: input/output channels of the partial convolution. For
            decoder layers ``in_ch`` is the channel count after concatenation.
        sample: 'down-7', 'down-5' or 'down-3' select a stride-2 convolution
            with kernel 7/5/3; any other value gives a 3x3 stride-1 conv.
        dec: if True, inputs are first upsampled and concatenated with the
            encoder skip tensors (``UpsampleData``).
        bn: add a BatchNorm2d after the convolution.
        active: 'relu', 'leaky' (LeakyReLU, slope 0.2), or anything else for
            no activation.
        conv_bias: whether the partial convolution has a bias term.
    """

    def __init__(self, in_ch, out_ch, sample='none-3', dec=False,
                 bn=True, active='relu', conv_bias=False):
        super().__init__()
        # Convolution geometry for each sampling mode.
        if sample == 'down-7':
            params = {"kernel_size": 7, "stride": 2, "padding": 3}
        elif sample == 'down-5':
            params = {"kernel_size": 5, "stride": 2, "padding": 2}
        elif sample == 'down-3':
            params = {"kernel_size": 3, "stride": 2, "padding": 1}
        else:
            params = {"kernel_size": 3, "stride": 1, "padding": 1}
        self.conv = PartialConvolution(in_ch, out_ch,
                                       params["kernel_size"],
                                       params["stride"],
                                       params["padding"],
                                       bias=conv_bias)

        if dec:
            self.upcat = UpsampleData()
        if bn:
            # Must be stored on self: forward() applies batch norm only when
            # the ``bn`` attribute exists.
            self.bn = nn.BatchNorm2d(out_ch)
        if active == 'relu':
            self.activation = nn.ReLU()
        elif active == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, img, mask, enc_img=None, enc_mask=None):
        """Return ``(features, updated_mask)``; decoder layers also take the skip tensors."""
        if hasattr(self, 'upcat'):
            out, update_mask = self.upcat(img, enc_img, mask, enc_mask)
            out, update_mask = self.conv(out, update_mask)
        else:
            out, update_mask = self.conv(img, mask)
        if hasattr(self, 'bn'):
            out = self.bn(out)
        if hasattr(self, 'activation'):
            out = self.activation(out)
        return out, update_mask


class PConvUNet(nn.Module):
    """U-Net built from partial-convolution layers for image inpainting.

    Arguments:
        finetune: truthy to keep encoder BatchNorm layers in eval mode while
            training (see ``train``). The training script passes the
            fine-tuning checkpoint path here, or a falsy value.
        in_ch: channels of the input image (also the width of the outermost
            skip connection).
        layer_size: number of encoder/decoder stages used by ``forward``; all
            eight stages are always constructed.
    """

    def __init__(self, finetune=False, in_ch=3, layer_size=6):
        super().__init__()
        self.freeze_enc_bn = bool(finetune)
        self.layer_size = layer_size

        self.enc_1 = PConvLayer(in_ch, 64, 'down-7', bn=False)
        self.enc_2 = PConvLayer(64, 128, 'down-5')
        self.enc_3 = PConvLayer(128, 256, 'down-5')
        self.enc_4 = PConvLayer(256, 512, 'down-3')
        self.enc_5 = PConvLayer(512, 512, 'down-3')
        self.enc_6 = PConvLayer(512, 512, 'down-3')
        self.enc_7 = PConvLayer(512, 512, 'down-3')
        self.enc_8 = PConvLayer(512, 512, 'down-3')

        # Decoder input = upsampled decoder features + encoder skip features.
        self.dec_8 = PConvLayer(512 + 512, 512, dec=True, active='leaky')
        self.dec_7 = PConvLayer(512 + 512, 512, dec=True, active='leaky')
        self.dec_6 = PConvLayer(512 + 512, 512, dec=True, active='leaky')
        self.dec_5 = PConvLayer(512 + 512, 512, dec=True, active='leaky')
        self.dec_4 = PConvLayer(512 + 256, 256, dec=True, active='leaky')
        self.dec_3 = PConvLayer(256 + 128, 128, dec=True, active='leaky')
        self.dec_2 = PConvLayer(128 + 64,   64, dec=True, active='leaky')
        # The outermost skip is the raw input image, which has in_ch channels.
        self.dec_1 = PConvLayer(64 + in_ch,  3, dec=True, bn=False,
                                active=None, conv_bias=True)

    def forward(self, img, mask):
        """Inpaint ``img`` given ``mask``.

        Returns:
            ``(output, mask)``. The second element is the input ``mask``
            unchanged, not the final updated mask.
        """
        # Encoder: remember each stage's input (and mask) for skip connections.
        enc_f, enc_m = [img], [mask]
        for layer_num in range(1, self.layer_size+1):
            if layer_num == 1:
                feature, update_mask = \
                    getattr(self, 'enc_{}'.format(layer_num))(img, mask)
            else:
                enc_f.append(feature)
                enc_m.append(update_mask)
                feature, update_mask = \
                    getattr(self, 'enc_{}'.format(layer_num))(feature,
                                                              update_mask)

        assert len(enc_f) == self.layer_size

        # Decoder: walk back up, consuming the skips deepest-first.
        for layer_num in reversed(range(1, self.layer_size+1)):
            feature, update_mask = getattr(self, 'dec_{}'.format(layer_num))(
                    feature, update_mask, enc_f.pop(), enc_m.pop())

        return feature, mask

    def train(self, mode=True):
        """Set training mode; with ``freeze_enc_bn``, keep encoder BN in eval.

        Returns ``self``, like ``nn.Module.train``, so ``model.train()`` and
        ``model.eval()`` keep working and can be chained.
        """
        super().train(mode)
        if self.freeze_enc_bn:
            for name, module in self.named_modules():
                if isinstance(module, nn.BatchNorm2d) and 'enc' in name:
                    module.eval()
        return self

def main():
    # Smoke test: partial-conv gradients must be NaN-free, and the U-Net must
    # run a forward pass on a 512x512 input with a square hole.
    shape = (1, 3, 512, 512)
    img = torch.ones(shape)
    mask = torch.ones(shape)
    # Zero the central square (rows and columns 128 .. -128).
    mask[:, :, 128:-128, 128:-128] = 0

    conv = PartialConvolution(3, 3, 3, 1, 1)
    criterion = nn.L1Loss()
    img.requires_grad = True

    output, out_mask = conv(img, mask)
    target = torch.randn(shape)
    loss = criterion(output, target)
    loss.backward()

    for grad in (conv.weight.grad, conv.bias.grad):
        assert not torch.isnan(grad).any()

    model = PConvUNet()
    before = model.enc_5.conv.weight[0][0]
    print(before)
    output, out_mask = model(img, mask)


if __name__ == '__main__':
    main()

## Defining the loss function

This technique uses quite an intense loss function. The highlights of it are:

* Per-pixel losses for both masked and un-masked regions
* Perceptual loss based on ImageNet pre-trained VGG-16 (*pool1, pool2 and pool3 layers*)
* Style loss on VGG-16 features for both the predicted image and the composited image (non-hole pixels set to ground truth)
* Total variation loss for a 1-pixel dilation of the hole region

The loss terms are weighted as follows:
<img src='https://raw.githubusercontent.com/MathiasGruber/PConv-Keras/master/data/images/eq7.PNG' />

### VGG16 model for feature extraction
The authors of the paper implemented the model in PyTorch and chose VGG16 for feature extraction. The [VGG16 model in PyTorch](https://pytorch.org/docs/stable/torchvision/models.html) was trained with the following image pre-processing:
1. Divide the image by 255,
2. Subtract [0.485, 0.456, 0.406] from the RGB channels, respectively,
3. Divide the RGB channels by [0.229, 0.224, 0.225], respectively.

In [None]:
class InpaintingLoss(nn.Module):
    """Unweighted loss terms for partial-convolution inpainting.

    Arguments:
        extractor: module returning a sequence of at least three feature maps
            for an image batch; used by the perceptual and style terms.
        tv_loss: reduction forwarded to ``total_variation_loss``.

    ``forward`` returns a dict with keys 'valid', 'hole', 'perc', 'style' and
    'tv'. Weighting the terms is left to the caller.
    """

    def __init__(self, extractor, tv_loss='mean'):
        super(InpaintingLoss, self).__init__()
        self.tv_loss = tv_loss
        self.l1 = nn.L1Loss()
        self.extractor = extractor

    def forward(self, input, mask, output, gt):
        holes = 1 - mask
        # Composited image: known pixels from ``input``, predictions in the holes.
        comp = mask * input + holes * output

        tv_loss = total_variation_loss(comp, mask, self.tv_loss)

        # Per-pixel L1 on the hole region and on the valid region.
        hole_loss = self.l1(holes * output, holes * gt)
        valid_loss = self.l1(mask * output, mask * gt)

        # Perceptual and style terms on the first three feature maps.
        feats_out, feats_comp, feats_gt = (
            self.extractor(batch) for batch in (output, comp, gt))
        perc_loss = 0.0
        style_loss = 0.0
        for i in range(3):
            target_feat = feats_gt[i]
            target_gram = gram_matrix(target_feat)
            perc_loss += self.l1(feats_out[i], target_feat)
            perc_loss += self.l1(feats_comp[i], target_feat)
            style_loss += self.l1(gram_matrix(feats_out[i]), target_gram)
            style_loss += self.l1(gram_matrix(feats_comp[i]), target_gram)

        return {'valid': valid_loss,
                'hole': hole_loss,
                'perc': perc_loss,
                'style': style_loss,
                'tv': tv_loss}


# The network of extracting the feature for perceptual and style loss
class VGG16FeatureExtractor(nn.Module):
    """Frozen slices of a pretrained VGG-16 returning three feature maps.

    The slices end at ``features[4]``, ``features[9]`` and ``features[16]``
    (pool1, pool2 and pool3). ImageNet mean/std normalization is the first
    module of the first slice, so inputs are presumably expected in [0, 1]
    (ToTensor range) -- confirm against callers.
    """
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self):
        super().__init__()
        layers = models.vgg16(pretrained=True).features
        self.enc_1 = nn.Sequential(Normalization(self.MEAN, self.STD),
                                   *layers[:5])
        self.enc_2 = nn.Sequential(*layers[5:10])
        self.enc_3 = nn.Sequential(*layers[10:17])

        # The extractor is only used inside the loss; never train it.
        for stage in (self.enc_1, self.enc_2, self.enc_3):
            for param in stage.parameters():
                param.requires_grad = False

    def forward(self, input):
        features = []
        current = input
        for stage in (self.enc_1, self.enc_2, self.enc_3):
            current = stage(current)
            features.append(current)
        return features


# Normalization Layer for VGG
class Normalization(nn.Module):
    """Channel-wise ``(input - mean) / std``.

    ``mean`` and ``std`` are stored with shape (C, 1, 1) so they broadcast
    over the trailing (C, H, W) dimensions of the input.
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def forward(self, input):
        # mean/std are plain attributes (not buffers), so Module.to() does not
        # move them; follow the input lazily. Compare device and dtype
        # explicitly: Tensor.type() strings omit the CUDA device index, so a
        # move from cuda:0 to cuda:1 would otherwise go unnoticed.
        if self.mean.device != input.device or self.mean.dtype != input.dtype:
            self.mean = self.mean.to(input)
            self.std = self.std.to(input)
        return (input - self.mean) / self.std


# Gram Matrix of feature maps
def gram_matrix(feat):
    """Return per-sample Gram matrices of ``feat``, normalized by C*H*W.

    ``feat`` is unpacked as (B, C, H, W); the result has shape (B, C, C).
    """
    batch, channels, height, width = feat.size()
    flat = feat.view(batch, channels, height * width)
    return (flat @ flat.transpose(1, 2)) / (channels * height * width)


def dialation_holes(hole_mask):
    """Dilate a hole mask by one pixel (3x3 neighbourhood).

    ``hole_mask`` is unpacked as (B, C, H, W). An output pixel is 1.0 when the
    sum of ``hole_mask`` over its 3x3 neighbourhood, across all channels, is
    non-zero (for a 0/1 mask: any hole pixel nearby). Every channel of the
    float32 result is identical.
    """
    _, channels, _, _ = hole_mask.shape
    # Fixed all-ones kernel. Using F.conv2d directly avoids building a
    # randomly initialised nn.Conv2d on every call, which also consumed the
    # global RNG stream.
    ones_kernel = torch.ones(channels, channels, 3, 3).to(hole_mask)
    with torch.no_grad():
        neighbourhood_sum = F.conv2d(hole_mask, ones_kernel, padding=1)
    return (neighbourhood_sum != 0).float()


def total_variation_loss(image, mask, method):
    """Total-variation loss restricted to a 1-pixel dilation of the holes.

    Arguments:
        image: 4-D (B, C, H, W) image batch.
        mask: valid-pixel mask (1 = known, 0 = hole) that broadcasts against
            ``image``; its holes are dilated by one pixel to form region P.
        method: 'sum' to sum the absolute differences; any other value
            averages them.

    Returns:
        Scalar tensor: horizontal-neighbour term + vertical-neighbour term,
        where a neighbour pair counts only when both pixels lie in P.
    """
    hole_mask = 1 - mask
    dilated_holes = dialation_holes(hole_mask)
    columns_in_pset = dilated_holes[:, :, :, 1:] * dilated_holes[:, :, :, :-1]
    rows_in_pset = dilated_holes[:, :, 1:, :] * dilated_holes[:, :, :-1, :]

    # Column j+1 minus column j, and row i+1 minus row i.
    horizontal = torch.abs(columns_in_pset * (
        image[:, :, :, 1:] - image[:, :, :, :-1]))
    vertical = torch.abs(rows_in_pset * (
        image[:, :, 1:, :] - image[:, :, :-1, :]))

    reduce = torch.sum if method == 'sum' else torch.mean
    return reduce(horizontal) + reduce(vertical)

def main():
    # Smoke test for InpaintingLoss on random data with a square hole.
    vgg = VGG16FeatureExtractor()
    # Pass the extractor once. The second parameter of InpaintingLoss is the
    # TV-loss reduction ('mean'/'sum'), not another network.
    criterion = InpaintingLoss(vgg)

    img = torch.randn(1, 3, 500, 500)
    mask = torch.ones((1, 1, 500, 500))
    # Hole in the bottom-right quadrant.
    mask[:, :, 250:, 250:] = 0
    masked_img = img * mask
    out = torch.randn(1, 3, 500, 500)
    loss = criterion(masked_img, mask, out, img)
    return loss


if __name__ == '__main__':
    main()

## Evaluation of the model

In [None]:
def evaluate(model, dataset, device, filename, num_samples=8):
    """Inpaint the first dataset items and save a comparison grid.

    The grid written to ``filename`` has one row per group, in this order:
    masked input, mask, raw model output, composited output (known pixels
    from the input, holes from the model) and ground truth. At most
    ``num_samples`` items are used, fewer if the dataset is smaller.
    ``model`` is left in eval mode.

    Raises:
        ValueError: if no samples are available.
    """
    print('Start the evaluation')
    model.eval()
    count = min(num_samples, len(dataset))
    if count <= 0:
        raise ValueError("evaluate() needs at least one sample")
    image, mask, gt = zip(*[dataset[i] for i in range(count)])
    image = torch.stack(image)
    mask = torch.stack(mask)
    gt = torch.stack(gt)
    with torch.no_grad():
        output, _ = model(image.to(device), mask.to(device))
    output = output.to(torch.device('cpu'))
    output_comp = mask * image + (1 - mask) * output

    # nrow=count keeps each group of tensors on its own row.
    grid = make_grid(torch.cat([image, mask, output, output_comp, gt], dim=0),
                     nrow=count)
    save_image(grid, filename)

## Utility Functions

In [None]:
def create_ckpt_dir(ckpt_dir="ckpt"):
    """Ensure the checkpoint directory layout exists and return its path.

    Creates ``ckpt_dir`` with ``val_vis`` (validation images) and ``models``
    (saved checkpoints) sub-directories. Existing directories are reused, and
    missing sub-directories are created even when ``ckpt_dir`` itself
    already exists.
    """
    for sub_dir in ("val_vis", "models"):
        os.makedirs(os.path.join(ckpt_dir, sub_dir), exist_ok=True)
    return ckpt_dir


def to_items(dic):
    """Return a new dict with each (tensor) value replaced by ``value.item()``."""
    return {key: value.item() for key, value in dic.items()}


def _to_item(item):
    return item[0], item[1].item()


class Config(dict):
    """Attribute-style access to the keys of a YAML configuration file.

    ``config.some_key`` returns the value of ``some_key`` from the file, or
    None when the key is missing or null. Attributes assigned on the
    instance (e.g. ``config.ckpt = ...``) take precedence, because
    ``__getattr__`` only runs when normal attribute lookup fails.

    Note: although this subclasses ``dict``, the parsed values live in
    ``self._conf``, not in the dict itself.
    """

    def __init__(self, conf_file):
        with open(conf_file, "r") as f:
            config = yaml.safe_load(f)
        # safe_load returns None for an empty file; treat that as no keys.
        self._conf = config if config is not None else {}

    def __getattr__(self, name):
        # Dunder probes (copy, pickle, hasattr) and a not-yet-set ``_conf``
        # must raise AttributeError. Otherwise reading ``self._conf`` below
        # would re-enter __getattr__ forever.
        if name == "_conf" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        return self._conf.get(name)


def conf_to_param(config: dict) -> dict:
    """Flatten a config dict into scalar parameters, in place.

    Top-level nested dicts are merged into the top level (their keys
    overwrite existing ones). Top-level values whose type is not exactly
    float, int, bool or str are removed. Values coming from nested dicts are
    merged as-is, without that filtering.

    Returns:
        The same ``config`` object, modified.
    """
    nested_keys = []
    drop_keys = []
    for key, val in config.items():
        if isinstance(val, dict):
            nested_keys.append(key)
        elif type(val) not in (float, int, bool, str):
            drop_keys.append(key)

    # Mutate only after iteration has finished.
    for target in nested_keys:
        config.update(config.pop(target))
    for target in drop_keys:
        del config[target]

    return config


def get_state_dict_on_cpu(obj):
    """Return ``obj.state_dict()`` with every entry moved to the CPU."""
    state_dict = obj.state_dict()
    target = torch.device("cpu")
    for name in list(state_dict):
        state_dict[name] = state_dict[name].to(target)
    return state_dict


def save_ckpt(ckpt_name, models, optimizers, n_iter):
    """Write a checkpoint dict to ``ckpt_name`` with ``torch.save``.

    Arguments:
        ckpt_name: destination path.
        models: iterable of ``(key, module)`` pairs; their state dicts are
            stored on the CPU.
        optimizers: iterable of ``(key, optimizer)`` pairs; their state
            dicts are stored as-is.
        n_iter: iteration counter, stored under ``"n_iter"``.
    """
    checkpoint = {"n_iter": n_iter}
    checkpoint.update(
        (key, get_state_dict_on_cpu(module)) for key, module in models)
    checkpoint.update(
        (key, optim.state_dict()) for key, optim in optimizers)
    torch.save(checkpoint, ckpt_name)


def load_ckpt(ckpt_name, models, optimizers=None, map_location=None):
    """Restore models (and optionally optimizers) from a checkpoint file.

    Arguments:
        ckpt_name: path to a dict saved with ``torch.save`` (as by save_ckpt).
        models: iterable of ``(key, nn.Module)`` pairs. State dicts are
            loaded with ``strict=False``, so missing or unexpected keys are
            ignored.
        optimizers: optional iterable of ``(key, optimizer)`` pairs.
        map_location: forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to
            load a checkpoint whose optimizer state was saved on a GPU on a
            machine without one. The default None keeps torch.load's
            default behaviour.

    Returns:
        The stored ``"n_iter"`` value.
    """
    ckpt_dict = torch.load(ckpt_name, map_location=map_location)
    for prefix, model in models:
        assert isinstance(model, nn.Module)
        model.load_state_dict(ckpt_dict[prefix], strict=False)
    if optimizers is not None:
        for prefix, optimizer in optimizers:
            optimizer.load_state_dict(ckpt_dict[prefix])
    return ckpt_dict["n_iter"]

## The place where the training actually happens

In [None]:
class Trainer(object):
    """Training loop: optimization steps, logging, visualization, checkpoints.

    Arguments:
        step: offset added to the loader index (e.g. when resuming).
        config: object exposing batch_size, log_interval, vis_interval,
            save_model_interval, max_iter, ckpt, and one ``<name>_coef``
            attribute per loss term returned by ``criterion``.
        device: device that batches are moved to.
        model, criterion, optimizer: network, loss module and optimizer.
        dataset_train, dataset_val: datasets yielding ``(input, mask, gt)``.
    """

    def __init__(self, step, config, device, model, dataset_train,
                 dataset_val, criterion, optimizer):
        self.stepped = step
        self.config = config
        self.device = device
        self.model = model
        self.dataloader_train = DataLoader(dataset_train,
                                           batch_size=config.batch_size,
                                           shuffle=True)
        self.dataset_val = dataset_val
        self.criterion = criterion
        self.optimizer = optimizer
        self.evaluate = evaluate

    def iterate(self):
        """Run at most ``config.max_iter`` updates in one pass over the loader.

        Logs every ``log_interval`` steps. Saves a validation grid every
        ``vis_interval`` steps and on the first step. Saves a checkpoint
        every ``save_model_interval`` steps and on the last step. Fewer
        updates run if the loader has fewer than ``max_iter`` batches.
        """
        print('Start the training')
        for step, (input, mask, gt) in enumerate(self.dataloader_train):
            global_step = step + self.stepped
            loss_dict = self.train(global_step, input, mask, gt)
            # report the loss
            if step % self.config.log_interval == 0:
                self.report(global_step, loss_dict)

            # evaluation
            if (global_step + 1) % self.config.vis_interval == 0 \
                    or step == 0 or global_step == 0:
                self.model.eval()
                self.evaluate(self.model, self.dataset_val, self.device,
                              '{}/val_vis/{}.png'.format(self.config.ckpt,
                                                         global_step))

            # save the model
            if (global_step + 1) % self.config.save_model_interval == 0 \
                    or (step + 1) == self.config.max_iter:
                print('Saving the model...')
                save_ckpt('{}/models/{}.pth'.format(self.config.ckpt,
                                                    global_step + 1),
                          [('model', self.model)],
                          [('optimizer', self.optimizer)],
                          global_step + 1)

            # Stop once max_iter updates have run, immediately after the
            # final checkpoint above, so the last update is never unsaved.
            if step + 1 >= self.config.max_iter:
                break

    def train(self, step, input, mask, gt):
        """Run one optimization step.

        Returns:
            The loss terms plus ``'total'`` (their coefficient-weighted sum),
            as Python numbers.
        """
        self.model.train()

        input = input.to(self.device)
        mask = mask.to(self.device)
        gt = gt.to(self.device)

        # forward
        output, _ = self.model(input, mask)
        loss_dict = self.criterion(input, mask, output, gt)
        loss = 0.0
        for key, val in loss_dict.items():
            coef = getattr(self.config, '{}_coef'.format(key))
            loss += coef * val

        # updates the model's params
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss_dict['total'] = loss
        return to_items(loss_dict)

    def report(self, step, loss_dict):
        """Print one line with every loss term for ``step``."""
        print('[STEP: {:>6}] | Valid Loss: {:.6f} | Hole Loss: {:.6f}'\
              '| TV Loss: {:.6f} | Perc Loss: {:.6f}'\
              '| Style Loss: {:.6f} | Total Loss: {:.6f}'.format(
                        step, loss_dict['valid'], loss_dict['hole'],
                        loss_dict['tv'], loss_dict['perc'],
                        loss_dict['style'], loss_dict['total']))

## Training Startpoint

In [None]:
config = Config("location of the config file")
config.ckpt = create_ckpt_dir()
print("Check Point is '{}'".format(config.ckpt))


device = torch.device("cuda:{}".format(config.cuda_id)
                      if torch.cuda.is_available() else "cpu")


print("Loading the Model...")
model = PConvUNet(finetune=config.finetune,
                  layer_size=config.layer_size)

if config.finetune:
    model.load_state_dict(torch.load(config.finetune)['model'])
model.to(device)

# Data Transformation
img_tf = transforms.Compose([
            transforms.ToTensor()
            ])

mask_tf = transforms.Compose([
            transforms.RandomResizedCrop(256),
            transforms.ToTensor()
            ])


print("Loading the Validation Dataset...")
dataset_val = InitDataset(config.data_root,
                          img_tf,
                          mask_tf,
                          data="val")


print("Loading the Training Dataset...")
dataset_train = InitDataset(config.data_root,
                            img_tf,
                            mask_tf,
                            data="train")

# Loss function
criterion = InpaintingLoss(VGG16FeatureExtractor(),
                           tv_loss=config.tv_loss).to(device)
# Optimizer: only parameters that require gradients are optimized.
lr = config.finetune_lr if config.finetune else config.initial_lr
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
if config.optim == "Adam":
    optimizer = torch.optim.Adam(trainable_params,
                                 lr=lr,
                                 weight_decay=config.weight_decay)
elif config.optim == "SGD":
    optimizer = torch.optim.SGD(trainable_params,
                                lr=lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
else:
    # Fail here with a clear message rather than with a NameError on
    # ``optimizer`` when the Trainer is built.
    raise ValueError("Unsupported optimizer {!r}; expected 'Adam' or 'SGD'"
                     .format(config.optim))

start_iter = 0
trainer = Trainer(start_iter, config, device, model, dataset_train,
                  dataset_val, criterion, optimizer)
trainer.iterate()

## Resources

- "Image Inpainting for Irregular Holes Using Partial Convolutions", https://arxiv.org/abs/1804.07723

## Citation

```
@inproceedings{liu2018partialpadding,
   author    = {Guilin Liu and Kevin J. Shih and Ting-Chun Wang and Fitsum A. Reda and Karan Sapra and Zhiding Yu and Andrew Tao and Bryan Catanzaro},
   title     = {Partial Convolution based Padding},
   booktitle = {arXiv preprint arXiv:1811.11718},   
   year      = {2018},
}
@inproceedings{liu2018partialinpainting,
   author    = {Guilin Liu and Fitsum A. Reda and Kevin J. Shih and Ting-Chun Wang and Andrew Tao and Bryan Catanzaro},
   title     = {Image Inpainting for Irregular Holes Using Partial Convolutions},
   booktitle = {The European Conference on Computer Vision (ECCV)},   
   year      = {2018},
}
```