In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dataloaders.mask_generator import MaskGenerator
from dataloaders.images_dataset import ImagesDataset
from torch.utils.data import DataLoader

from models.pconv_unet import PConvUNet
from models.vgg16_extractor import VGG16Extractor

from loss.loss_compute import LossCompute

from utils.preprocessing import Preprocessor

In [3]:
from pytorch_lightning import Trainer

In [4]:
# Training image size; passed to both MaskGenerator and ImagesDataset.
HEIGHT, WIDTH = 256, 256
# Forwarded to MaskGenerator(invert_mask=...).
invert_mask = False

# Both data directories live under one dataset root, so relocating the
# dataset only requires changing DATASET_ROOT.
DATASET_ROOT = "../Repos/image-inpainting/dataset"
mask_dir = DATASET_ROOT + "/irregular_mask/irregular_mask/disocclusion_img_mask/"
train_dir = DATASET_ROOT + "/train_0"

NUM_WORKERS = 2  # DataLoader worker processes
BS = 2           # DataLoader batch size
LR = 0.0002      # Adam learning rate

In [5]:
import os
import torch
import pytorch_lightning as pl

class ImageInpaintingSystem(pl.LightningModule):
    """LightningModule wrapping a partial-convolution U-Net for image inpainting.

    Owns the generator network (``PConvUNet``), a loss helper (``LossCompute``
    built around a ``VGG16Extractor``) and a ``Preprocessor`` used to
    normalize images. Hyperparameters and data locations are read from the
    notebook-level config (``LR``, ``BS``, ``NUM_WORKERS``, ``HEIGHT``,
    ``WIDTH``, ``mask_dir``, ``train_dir``, ``invert_mask``).
    """

    def __init__(self, device=None):
        """Build the network and its auxiliary loss/preprocessing helpers.

        Args:
            device: Device string for the VGG16 feature extractor and the
                preprocessor. Defaults to ``"cuda"`` when a GPU is available
                and ``"cpu"`` otherwise (previously ``"cuda"`` was hard-coded,
                which cannot work on CPU-only machines).
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.pConvUNet = PConvUNet()

        # The extractor is handed to LossCompute rather than stored directly on
        # this module; the model summary in this notebook lists only
        # pConvUNet.* modules, so Lightning does not appear to track it (its
        # weights stay out of self.parameters()/the optimizer). It is therefore
        # placed on the target device explicitly here.
        vgg16extractor = VGG16Extractor().to(device)
        self.lossCompute = LossCompute(vgg16extractor)

        self.preprocess = Preprocessor(device)

    def forward(self, masked_img_tensor, mask_tensor):
        """Run the partial-convolution U-Net on a masked image and its mask."""
        return self.pConvUNet(masked_img_tensor, mask_tensor)

    def training_step(self, batch, batch_nb):
        """Compute the inpainting loss for one batch.

        Args:
            batch: ``(masked_img, mask, image)`` as produced by the DataLoader
                over ``ImagesDataset``.
            batch_nb: Batch index (unused; required by the Lightning API).

        Returns:
            dict with ``'loss'`` (scalar used for backprop) and ``'log'``
            (TensorBoard scalars: ``train_loss`` and ``train_psnr``).
        """
        masked_img, mask, image = batch

        img_tensor = self.preprocess.normalize(image.float())
        masked_img_tensor = self.preprocess.normalize(masked_img.float())
        # NOTE(review): transpose(1, 3) swaps only axes 1 and 3. If the mask
        # arrives as (N, H, W, C) this yields (N, C, W, H), i.e. H and W are
        # swapped as well (permute(0, 3, 1, 2) would keep them). It is
        # invisible while HEIGHT == WIDTH; confirm it matches the layout that
        # Preprocessor.normalize produces for the images before changing it.
        mask_tensor = mask.float().transpose(1, 3)

        # loss_total is built per batch from the mask and returns a callable
        # taking (target, prediction).
        ls_fn = self.lossCompute.loss_total(mask_tensor)

        output = self(masked_img_tensor, mask_tensor)

        loss = ls_fn(img_tensor, output).mean()

        # PSNR is a monitoring metric only: keep it out of the autograd graph
        # and reduce to a scalar in case the helper returns per-image values.
        with torch.no_grad():
            psnr = torch.as_tensor(self.lossCompute.PSNR(img_tensor, output)).float().mean()

        tensorboard_logs = {'train_loss': loss, 'train_psnr': psnr}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Adam over all parameters registered on this module, with lr ``LR``."""
        return torch.optim.Adam(self.parameters(), lr=LR)

    @pl.data_loader
    def train_dataloader(self):
        """Build the shuffled training DataLoader from the notebook config.

        Masks come from ``MaskGenerator(mask_dir, HEIGHT, WIDTH, invert_mask)``
        and images from ``ImagesDataset(train_dir, HEIGHT, WIDTH, ...)``;
        batching uses ``BS`` and ``NUM_WORKERS``.
        """
        mask_generator = MaskGenerator(mask_dir, HEIGHT, WIDTH, invert_mask=invert_mask)
        dataset = ImagesDataset(train_dir, HEIGHT, WIDTH, mask_generator)
        dataloader = DataLoader(dataset, batch_size=BS, shuffle=True, num_workers=NUM_WORKERS)
        return dataloader


In [6]:
model = ImageInpaintingSystem()

# Request a GPU only when one is visible; None is Lightning's CPU default for
# `gpus`, so the cell no longer asks for hardware that isn't there.
trainer = Trainer(gpus=1 if torch.cuda.is_available() else None, use_amp=False)
trainer.fit(model)

gpu available: True, used: True
VISIBLE GPUS: 0
55116 masks found: ../Repos/image-inpainting/dataset/irregular_mask/irregular_mask/disocclusion_img_mask/
                             Name           Type Params
0                       pConvUNet      PConvUNet   32 M
1              pConvUNet.encoder1   PConvEncoder    9 K
2        pConvUNet.encoder1.pconv  PartialConv2d    9 K
3    pConvUNet.encoder1.batchnorm    BatchNorm2d  128  
4   pConvUNet.encoder1.activation           ReLU    0  
..                            ...            ...    ...
78       pConvUNet.decoder8.pconv  PartialConv2d    1 K
79   pConvUNet.decoder8.batchnorm    BatchNorm2d    6  
80  pConvUNet.decoder8.activation      LeakyReLU    0  
81            pConvUNet.convfinal         Conv2d   12  
82              pConvUNet.sigmoid        Sigmoid    0  

[83 rows x 3 columns]


  dilated_mask = torch.tensor(dilated_mask> 0, dtype=torch.float, requires_grad=False).to(self.device)
  0%|          | 344/78271 [01:29<5:26:17,  3.98it/s, batch_nb=342, epoch=0, gpu=0, loss=6.702, v_nb=22] 

KeyboardInterrupt: 

In [7]:
# Tell the user how to open the TensorBoard UI for runs launched from this
# working directory (one print call; sep='\n' puts each message on its own line).
print(
    'View tensorboard logs by running\ntensorboard --logdir %s' % os.getcwd(),
    'and going to http://localhost:6006 on your browser',
    sep='\n',
)

View tensorboard logs by running
tensorboard --logdir D:\Image-Inpainting
and going to http://localhost:6006 on your browser
