In [1]:
# Reload edited project modules (dataloaders, models, loss, utils) automatically
# before executing code, so changes to the .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from dataloaders.mask_generator import MaskGenerator
from dataloaders.images_dataset import ImagesDataset
from torch.utils.data import DataLoader

from models.pconv_unet import PConvUNet
from models.vgg16_extractor import VGG16Extractor

from loss.loss_compute import LossCompute

from utils.preprocessing import Preprocessor

In [3]:
import numpy as np
from PIL import Image

In [4]:
from pytorch_lightning import Trainer

In [5]:
import random

import torch

# Spatial size passed to MaskGenerator and ImagesDataset.
HEIGHT, WIDTH = 256, 256
# Forwarded to MaskGenerator as its ``invert_mask`` option.
invert_mask = False

# Dataset locations. Kept as plain strings (mask_dir keeps its trailing slash)
# because they are handed unchanged to MaskGenerator / ImagesDataset.
DATASET_ROOT = "../Repos/image-inpainting/dataset"
mask_dir = f"{DATASET_ROOT}/irregular_mask/irregular_mask/disocclusion_img_mask/"
train_dir = f"{DATASET_ROOT}/train_0"
valid_dir = f"{DATASET_ROOT}/test"

NUM_WORKERS = 2
BS = 2
LR = 0.0002

# Training shuffles the data (and presumably draws random masks), so seed every
# RNG up front to make runs repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [6]:
import os
import torch
import pytorch_lightning as pl

class ImageInpaintingSystem(pl.LightningModule):
    """PyTorch Lightning module that trains a partial-convolution U-Net for inpainting.

    Batches are ``(masked_img, mask, image)`` triples produced by ``ImagesDataset``.
    The per-batch loss comes from ``LossCompute.loss_total`` built around a frozen
    ``VGG16Extractor``.

    Hyper-parameters and data locations are read from the notebook-level
    constants ``LR``, ``BS``, ``NUM_WORKERS``, ``HEIGHT``, ``WIDTH``,
    ``mask_dir``, ``train_dir``, ``valid_dir`` and ``invert_mask``.
    """

    def __init__(self, device=None):
        """Build the generator, the frozen VGG16 extractor and the preprocessor.

        Args:
            device: device string for the VGG16 extractor and the Preprocessor.
                Defaults to ``"cuda"`` when CUDA is available, else ``"cpu"``.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.pConvUNet = PConvUNet()

        # The extractor is used only by the loss computation; freeze it so it is
        # never trained.
        vgg16extractor = VGG16Extractor().to(device)
        for param in vgg16extractor.parameters():
            param.requires_grad = False
        self.lossCompute = LossCompute(vgg16extractor)

        self.preprocess = Preprocessor(device)

    def forward(self, masked_img_tensor, mask_tensor):
        """Run the partial-convolution U-Net on a masked image and its mask."""
        return self.pConvUNet(masked_img_tensor, mask_tensor)

    def _prepare_batch(self, batch):
        """Turn a raw ``(masked_img, mask, image)`` batch into float model inputs.

        Returns:
            ``(img_tensor, mask_tensor, masked_img_tensor)``: the normalized target
            image, the float mask, and the normalized masked input.
        """
        masked_img, mask, image = batch
        img_tensor = self.preprocess.normalize(image.float())
        # NOTE(review): transpose(1, 3) swaps axes 1 and 3. If masks arrive as
        # (B, H, W, C) this yields (B, C, W, H), i.e. H and W are swapped as well
        # (same shape only because HEIGHT == WIDTH). Confirm this matches the
        # layout Preprocessor.normalize produces for the images.
        mask_tensor = mask.float().transpose(1, 3)
        masked_img_tensor = self.preprocess.normalize(masked_img.float())
        return img_tensor, mask_tensor, masked_img_tensor

    def _compute_loss(self, img_tensor, mask_tensor, masked_img_tensor):
        """Run the generator and return ``(output, scalar_loss)``."""
        ls_fn = self.lossCompute.loss_total(mask_tensor)
        output = self.forward(masked_img_tensor, mask_tensor)
        # Reduce to a scalar in both training and validation so the two agree and
        # validation_end aggregates one value per batch.
        loss = ls_fn(img_tensor, output).mean()
        return output, loss

    def training_step(self, batch, batch_nb):
        """Compute the training loss for one batch and log it as ``train_loss``."""
        img_tensor, mask_tensor, masked_img_tensor = self._prepare_batch(batch)
        _, loss = self._compute_loss(img_tensor, mask_tensor, masked_img_tensor)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        """Compute validation loss and PSNR; log a preview image on the first batch."""
        img_tensor, mask_tensor, masked_img_tensor = self._prepare_batch(batch)
        output, loss = self._compute_loss(img_tensor, mask_tensor, masked_img_tensor)
        psnr = self.lossCompute.PSNR(img_tensor, output)

        # Logging every batch floods TensorBoard and forces a GPU->CPU copy per
        # step; one preview per validation run is enough.
        if batch_nb == 0:
            self._log_preview(masked_img_tensor, output)
        return {'val_loss': loss, 'psnr': psnr}

    def _log_preview(self, masked_img_tensor, output):
        """Log the first masked input stacked above its reconstruction to TensorBoard."""
        res = np.clip(self.preprocess.unnormalize(output).detach().cpu().numpy(), 0, 1)
        masked_input = np.clip(
            self.preprocess.unnormalize(masked_img_tensor).detach().cpu().numpy(), 0, 1)
        # Axis 0 is height under the 'HWC' format declared below, so this stacks
        # the two images vertically.
        combined_img = np.concatenate((masked_input[0], res[0]))
        self.logger.experiment.add_image(
            'images', combined_img,
            global_step=self.trainer.global_step, dataformats='HWC')

    def validation_end(self, outputs):
        """Average validation loss and PSNR over all validation batches.

        Values are flattened and concatenated rather than stacked, so a per-sample
        metric from a smaller final batch cannot break aggregation.
        """
        avg_loss = torch.cat([x['val_loss'].reshape(-1) for x in outputs]).mean()
        avg_psnr = torch.cat([x['psnr'].reshape(-1) for x in outputs]).mean()
        metrics = {'valid_psnr': avg_psnr, 'valid_loss': avg_loss}
        return {'log': metrics, 'valid_loss': avg_loss, 'valid_psnr': avg_psnr}

    def configure_optimizers(self):
        """Adam over the trainable parameters only (frozen weights are skipped)."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=LR)

    def _make_dataloader(self, image_dir, shuffle):
        """Build a DataLoader over ``image_dir`` with masks drawn from ``mask_dir``."""
        mask_generator = MaskGenerator(mask_dir, HEIGHT, WIDTH, invert_mask=invert_mask)
        dataset = ImagesDataset(image_dir, HEIGHT, WIDTH, mask_generator)
        return DataLoader(dataset, batch_size=BS, shuffle=shuffle, num_workers=NUM_WORKERS)

    @pl.data_loader
    def train_dataloader(self):
        """Shuffled loader over ``train_dir``."""
        return self._make_dataloader(train_dir, shuffle=True)

    @pl.data_loader
    def val_dataloader(self):
        """Deterministic-order loader over ``valid_dir``."""
        return self._make_dataloader(valid_dir, shuffle=False)

In [None]:
# Fraction of the training set Lightning iterates per epoch (`train_percent_check`).
# 0.001 looks like a quick debugging setting; raise towards 1.0 for a full run.
TRAIN_PERCENT_CHECK = 0.001

model = ImageInpaintingSystem()

# Request a GPU only when one is actually available.
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    train_percent_check=TRAIN_PERCENT_CHECK,
    use_amp=False,
)
trainer.fit(model)

gpu available: True, used: True
VISIBLE GPUS: 0
55116 masks found: ../Repos/image-inpainting/dataset/irregular_mask/irregular_mask/disocclusion_img_mask/
55116 masks found: ../Repos/image-inpainting/dataset/irregular_mask/irregular_mask/disocclusion_img_mask/
                             Name           Type Params
0                       pConvUNet      PConvUNet   32 M
1              pConvUNet.encoder1   PConvEncoder    9 K
2        pConvUNet.encoder1.pconv  PartialConv2d    9 K
3    pConvUNet.encoder1.batchnorm    BatchNorm2d  128  
4   pConvUNet.encoder1.activation           ReLU    0  
..                            ...            ...    ...
78       pConvUNet.decoder8.pconv  PartialConv2d    1 K
79   pConvUNet.decoder8.batchnorm    BatchNorm2d    6  
80  pConvUNet.decoder8.activation      LeakyReLU    0  
81            pConvUNet.convfinal         Conv2d   12  
82              pConvUNet.sigmoid        Sigmoid    0  

[83 rows x 3 columns]


  dilated_mask = torch.tensor(dilated_mask> 0, dtype=torch.float, requires_grad=False).to(self.device)
 73%|███████▎  | 58/79 [00:19<00:05,  4.03it/s, batch_nb=56, epoch=43, gpu=0, loss=4.109, v_nb=4]