In [1]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
import cv2
import os
from PIL import Image

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision.transforms as T

from tqdm.notebook import tqdm

seed = 42

**Links to great advanced inpainting models**

* [lama](https://github.com/saic-mdal/lama)
* [csqiangwen](https://github.com/csqiangwen/DeepFillv2_Pytorch)

In [2]:
def seed_everything(seed=42):
    """Seed every RNG the notebook uses and put cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    # Only affects hash randomisation of Python processes started after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU, not just the current one; safe to call without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick convolution algorithms at
    # run time, which can differ between runs; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False


seed_everything()

In [3]:
# --- Optimisation hyperparameters ---
epochs = 40
batch_size = 16
lr = 8e-5
b1, b2 = 0.5, 0.999  # Adam beta coefficients

# --- Image / mask geometry (pixels) ---
image_size = 128
mask_size = 64

# --- Checkpoint file ---
path = r'painting_model.pth'

# Spatial size of the discriminator output for a mask_size x mask_size input:
# the discriminator halves the resolution three times (three stride-2 blocks).
patch_h = patch_w = int(mask_size / 2 ** 3)
patch = (1, patch_h, patch_w)

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Currently using "{device.upper()}" device.')

In [4]:
# Forward pipeline: resize to the working resolution, convert to a tensor and
# normalise each channel as (x - 0.5) / 0.5.
transforms = T.Compose([
    T.Resize(size=(image_size, image_size), interpolation=Image.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# Inverse pipeline: (x - (-1)) / 2 undoes the normalisation above, then the
# tensor is converted back to a PIL image for plotting.
inverse_transforms = T.Compose([
    T.Normalize(mean=-1, std=2),
    T.ToPILImage(),
])

In [5]:
# glob already returns strings; sorting makes the split independent of the
# directory listing order.
dataset_pattern = '../input/celebahq-resized-256x256/celeba_hq_256' + '/*.jpg'
image_paths = sorted(glob(dataset_pattern))

# Hold out 5000 images first, then carve 1000 of those out as the test set;
# the remainder of the hold-out becomes the validation set.
train, holdout = train_test_split(image_paths, test_size=5000, shuffle=True, random_state=seed)
valid, test = train_test_split(holdout, test_size=1000, shuffle=True, random_state=seed)

print(f'Train size: {len(train)}, validation size: {len(valid)}, test size: {len(test)}.')

In [6]:
class CelebaDataset(Dataset):
    """Image dataset that yields ``(image, masked_image, masked_part)`` triples.

    Each image is loaded from disk, passed through ``transforms`` and has a
    ``mask_size`` x ``mask_size`` square overwritten with 1. ``masked_part`` is
    the original content of that square (the inpainting target).

    Args:
        images_paths: Sequence of image file paths.
        transforms: Callable applied to the loaded PIL image; must return a
            ``[C, image_size, image_size]`` tensor.
        train: If True the square is placed at a random position, otherwise it
            is centred.
    """

    def __init__(self, images_paths, transforms=transforms, train=True):
        self.images_paths = images_paths
        self.transforms = transforms
        self.train = train

    def __len__(self):
        return len(self.images_paths)

    def apply_center_mask(self, image):
        """Mask the central square of ``image``.

        Returns:
            ``(masked_image, idx)`` where ``idx`` is the top/left offset of the
            masked square. The offset (not the cut-out) is returned because
            ``test_plot`` uses it to paste the generated patch back.
        """
        idx = (image_size - mask_size) // 2
        masked_image = image.clone()
        masked_image[:, idx:idx+mask_size, idx:idx+mask_size] = 1
        return masked_image, idx

    def apply_random_mask(self, image):
        """Mask a randomly placed square of ``image``.

        Returns:
            ``(masked_image, masked_part)``; ``masked_part`` is a view of the
            original pixels under the mask.
        """
        # randint's upper bound is exclusive: +1 makes the last valid offset
        # reachable and keeps the call valid when image_size == mask_size.
        y1, x1 = np.random.randint(0, image_size - mask_size + 1, 2)
        y2, x2 = y1 + mask_size, x1 + mask_size
        masked_part = image[:, y1:y2, x1:x2]
        masked_image = image.clone()
        masked_image[:, y1:y2, x1:x2] = 1
        return masked_image, masked_part

    def __getitem__(self, ix):
        path = self.images_paths[ix]
        # Context manager closes the file handle; convert('RGB') guarantees the
        # three channels that the 3-channel Normalize in `transforms` expects.
        with Image.open(path) as img:
            image = self.transforms(img.convert('RGB'))

        if self.train:
            masked_image, masked_part = self.apply_random_mask(image)
        else:
            # apply_center_mask returns the offset, so cut the target out here.
            masked_image, idx = self.apply_center_mask(image)
            masked_part = image[:, idx:idx+mask_size, idx:idx+mask_size]

        return image, masked_image, masked_part

    def collate_fn(self, batch):
        """Stack a list of samples into three batched tensors on ``device``."""
        images, masked_images, masked_parts = zip(*batch)
        # Stack on the host first so each group needs a single device transfer.
        images, masked_images, masked_parts = [
            torch.stack(group).to(device)
            for group in (images, masked_images, masked_parts)
        ]
        return images, masked_images, masked_parts

In [7]:
train_dataset = CelebaDataset(train)
# train=True here as well, so validation batches also use randomly placed masks.
valid_dataset = CelebaDataset(valid, train=True)

# drop_last=True discards the final incomplete batch, so every batch holds
# exactly batch_size samples.
loader_options = dict(batch_size=batch_size, drop_last=True)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    **loader_options,
)
valid_dataloader = DataLoader(
    valid_dataset,
    shuffle=False,
    collate_fn=valid_dataset.collate_fn,
    **loader_options,
)

In [8]:
def init_weights(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02) and their bias
    (when present) is zeroed. BatchNorm2d weights are drawn from N(1, 0.02)
    and their bias is zeroed. Any other module type is left untouched.
    """
    is_conv = isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
    is_batchnorm = isinstance(m, nn.BatchNorm2d)

    if is_conv:
        nn.init.normal_(m.weight, mean=0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    if is_batchnorm:
        nn.init.normal_(m.weight, mean=1, std=0.02)
        nn.init.zeros_(m.bias)
        
def set_params(model, unfreeze):
    """Freeze or unfreeze every parameter of ``model``.

    Args:
        model: ``nn.Module`` whose parameters are updated in place.
        unfreeze: True to enable gradient tracking, False to disable it.
    """
    # Module.requires_grad_ sets the flag on all parameters of the module.
    model.requires_grad_(unfreeze)

In [9]:
!pip install -qq torchsummary
from torchsummary import summary

In [10]:
class Generator(nn.Module):
    """Encoder-decoder that predicts the content of the masked square.

    Five stride-2 convolutions reduce the input resolution by 32, a 1x1
    convolution forms a 4000-channel bottleneck, and four stride-2 transposed
    convolutions bring the resolution back up by 16. The output therefore has
    half the input's height and width (128x128 in -> 64x64 out) and is squashed
    to [-1, 1] by the final Tanh.

    Args:
        channels: Number of channels of both the input image and the output patch.
    """

    def __init__(self, channels=3):
        super(Generator, self).__init__()

        def downsample(in_feat, out_feat, normalize=True):
            """Stride-2 conv (+ optional BatchNorm) + LeakyReLU; halves H and W."""
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                # The second positional argument of BatchNorm2d is `eps`, so the
                # former `BatchNorm2d(out_feat, 0.8)` added 0.8 to the variance
                # and all but disabled normalisation. Use the default eps.
                # NOTE(review): if 0.8 was meant as a Keras-style momentum, the
                # PyTorch equivalent would be momentum=0.2 - confirm intent.
                layers.append(nn.BatchNorm2d(out_feat))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        def upsample(in_feat, out_feat, normalize=True):
            """Stride-2 transposed conv (+ optional BatchNorm) + ReLU; doubles H and W."""
            layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                # Default eps here too; see the note in downsample().
                layers.append(nn.BatchNorm2d(out_feat))
            layers.append(nn.ReLU())
            return layers

        # Layer order is unchanged so existing state_dict keys still line up.
        self.model = nn.Sequential(
            *downsample(channels, 64, normalize=False),
            *downsample(64, 64),
            *downsample(64, 128),
            *downsample(128, 256),
            *downsample(256, 512),
            nn.Conv2d(512, 4000, 1),
            *upsample(4000, 512),
            *upsample(512, 256),
            *upsample(256, 128),
            *upsample(128, 64),
            nn.Conv2d(64, channels, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Map a masked image batch ``[B, C, H, W]`` to a patch batch ``[B, C, H/2, W/2]``."""
        return self.model(x)

In [11]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a single-channel score map.

    Three stride-2 blocks followed by one stride-1 block and a final 3x3
    convolution: a ``[B, C, H, W]`` input yields a ``[B, 1, H/8, W/8]`` output.

    Args:
        channels: Number of input image channels.
    """

    def __init__(self, channels=3):
        super().__init__()

        def make_block(n_in, n_out, stride, normalize, dropout, spectral):
            """Build conv -> [InstanceNorm] -> LeakyReLU -> [Dropout] as a list."""
            conv = nn.Conv2d(n_in, n_out, 3, stride, 1)
            if spectral:
                conv = nn.utils.spectral_norm(conv, n_power_iterations=2)
            block = [conv]
            if normalize:
                block.append(nn.InstanceNorm2d(n_out))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            if dropout:
                block.append(nn.Dropout(p=0.5))
            return block

        # One row per block: (out_channels, stride, normalize, dropout, spectral).
        block_specs = [
            (64, 2, False, 0, 0),
            (128, 2, True, 0, 0),
            (256, 2, True, 0, 0),
            (512, 1, True, 0, 0),
        ]

        modules = []
        n_in = channels
        for n_out, stride, normalize, dropout, spectral in block_specs:
            modules += make_block(n_in, n_out, stride, normalize, dropout, spectral)
            n_in = n_out

        # Final projection down to one score channel, resolution unchanged.
        modules.append(nn.Conv2d(n_in, 1, 3, 1, 1))

        self.model = nn.Sequential(*modules)

    def forward(self, img):
        return self.model(img)

In [12]:
# Build both networks, run init_weights over every submodule via Module.apply,
# and move them to the selected device.
generator = Generator().apply(init_weights).to(device)
discriminator = Discriminator().apply(init_weights).to(device)

In [13]:
# The generator consumes full images; use the configured size, not a literal.
summary(generator, (3, image_size, image_size))

In [14]:
# The discriminator only ever sees mask_size x mask_size patches.
summary(discriminator, (3, mask_size, mask_size))

Candidate losses: WGAN, perceptual, MSE, L1, SmoothL1. The cell below uses MSE for the adversarial term and L1 for the pixel-wise term.

In [15]:
# MSE on the discriminator's score map for the adversarial term, L1 between
# generated and real patches for the reconstruction term.
adversarial_loss = nn.MSELoss()
pixelwise_loss = nn.L1Loss()

# Both networks share the same Adam settings.
# Alternative noted by the author: AdamW weight_decay=1e-3, eps=1e-2
adam_options = dict(lr=lr, betas=(b1, b2))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_options)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_options)

In [16]:
def train_one_batch(batch, generator, discriminator, criterion_adv, criterion_pix, optimizer_G, optimizer_D):
    """Run one generator update followed by one discriminator update.

    Args:
        batch: ``(images, masked_images, masked_parts)`` tensors; only the last
            two are used here.
        generator: Maps ``masked_images`` to generated patches.
        discriminator: Scores real or generated patches.
        criterion_adv: Loss between discriminator scores and 1/0 targets.
        criterion_pix: Loss between generated and real patches.
        optimizer_G: Optimizer holding the generator parameters.
        optimizer_D: Optimizer holding the discriminator parameters.

    Returns:
        ``(loss_g, loss_d)`` as Python floats.
    """
    generator.train()
    discriminator.train()

    _, masked_images, masked_parts = batch

    # ---- Generator step: discriminator frozen so only G receives gradients ----
    set_params(discriminator, False)
    optimizer_G.zero_grad()
    gen_parts = generator(masked_images)

    # Targets are built from the discriminator output itself, so they always
    # match its batch size, spatial shape, dtype and device (a short final
    # batch no longer breaks the loss). Label smoothing could be applied here.
    scores_gen = discriminator(gen_parts)
    gan_loss = criterion_adv(scores_gen, torch.ones_like(scores_gen))
    pix_loss = criterion_pix(gen_parts, masked_parts)

    # Reconstruction dominates; the adversarial term is a small regulariser.
    loss_g = 0.001 * gan_loss + 0.999 * pix_loss
    loss_g.backward()
    optimizer_G.step()

    # ---- Discriminator step ----
    set_params(discriminator, True)
    optimizer_D.zero_grad()

    scores_real = discriminator(masked_parts)
    # detach(): the generator must not receive gradients from the D loss.
    scores_fake = discriminator(gen_parts.detach())
    real_loss = criterion_adv(scores_real, torch.ones_like(scores_real))
    fake_loss = criterion_adv(scores_fake, torch.zeros_like(scores_fake))

    loss_d = (real_loss + fake_loss) / 2
    loss_d.backward()
    optimizer_D.step()

    return loss_g.item(), loss_d.item()

@torch.no_grad()
def validate_one_batch(batch, generator, discriminator, criterion_adv, criterion_pix):
    """Compute generator and discriminator losses on one batch without updating anything.

    Both networks are switched to eval mode and the whole function runs under
    ``torch.no_grad()``.

    Args:
        batch: ``(images, masked_images, masked_parts)`` tensors; only the last
            two are used here.
        generator: Maps ``masked_images`` to generated patches.
        discriminator: Scores real or generated patches.
        criterion_adv: Loss between discriminator scores and 1/0 targets.
        criterion_pix: Loss between generated and real patches.

    Returns:
        ``(loss_g, loss_d)`` as Python floats.
    """
    generator.eval()
    discriminator.eval()

    _, masked_images, masked_parts = batch

    gen_parts = generator(masked_images)

    # Targets take their shape/device from the discriminator output, so any
    # batch size works (no dependence on the global batch_size / patch).
    scores_gen = discriminator(gen_parts)
    gan_loss = criterion_adv(scores_gen, torch.ones_like(scores_gen))
    pix_loss = criterion_pix(gen_parts, masked_parts)

    loss_g = 0.001 * gan_loss + 0.999 * pix_loss

    scores_real = discriminator(masked_parts)
    real_loss = criterion_adv(scores_real, torch.ones_like(scores_real))
    # scores_gen is reused for the fake term: nothing changes between the two
    # evaluations in eval mode under no_grad.
    fake_loss = criterion_adv(scores_gen, torch.zeros_like(scores_gen))

    loss_d = (real_loss + fake_loss) / 2

    return loss_g.item(), loss_d.item()

@torch.no_grad()
def test_plot(test, generator):
    """Inpaint one randomly chosen test image and plot original / masked / inpainted.

    Args:
        test: Sequence of image file paths to sample from.
        generator: Trained (or in-training) generator; it is switched to eval mode.
    """
    sample_path = test[np.random.randint(len(test))]

    # Close the file handle after loading and force three channels, which the
    # 3-channel Normalize in `transforms` expects.
    with Image.open(sample_path) as img:
        image = transforms(img.convert('RGB'))

    # apply_center_mask returns the top/left offset of the masked square.
    masked_image, offset = train_dataset.apply_center_mask(image)

    generator.eval()
    gen_part = generator(masked_image.unsqueeze(0).to(device)).squeeze(0).cpu()

    # Paste the generated patch back into the hole of the masked image.
    gen_image = masked_image.clone()
    gen_image[:, offset:offset+mask_size, offset:offset+mask_size] = gen_part

    panels = [
        ('Original Image', image),
        ('Masked Image', masked_image),
        ('Inpainted Image', gen_image),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    for ax, (title, tensor) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(inverse_transforms(tensor))

    fig.tight_layout()
    plt.show()
    # Kept from the original; presumably gives the frontend time to render -
    # TODO confirm it is still needed.
    plt.pause(0.01)

In [17]:
train_d_losses, valid_d_losses = [], []
train_g_losses, valid_g_losses = [], []


def _run_epoch(loader, step_fn, desc):
    """Feed every batch of ``loader`` to ``step_fn``; return mean (g_loss, d_loss)."""
    g_history, d_history = [], []
    progress = tqdm(loader, total=len(loader), desc=desc)
    for batch in progress:
        g_loss, d_loss = step_fn(batch)
        g_history.append(g_loss)
        d_history.append(d_loss)
        # Show the running epoch averages on the progress bar.
        progress.set_postfix(g_loss=np.mean(g_history), d_loss=np.mean(d_history))
    return np.mean(g_history), np.mean(d_history)


def _train_step(batch):
    """One optimisation step with the notebook-level models, losses and optimizers."""
    return train_one_batch(batch, generator, discriminator, adversarial_loss,
                           pixelwise_loss, optimizer_G, optimizer_D)


def _valid_step(batch):
    """One evaluation step with the notebook-level models and losses."""
    return validate_one_batch(batch, generator, discriminator, adversarial_loss, pixelwise_loss)


for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')

    g_mean, d_mean = _run_epoch(train_dataloader, _train_step, f'Train step {epoch+1}')
    train_g_losses.append(g_mean)
    train_d_losses.append(d_mean)

    g_mean, d_mean = _run_epoch(valid_dataloader, _valid_step, f'Validation step {epoch+1}')
    valid_g_losses.append(g_mean)
    valid_d_losses.append(d_mean)

    # Every second epoch, and always after the last one: preview a sample and
    # overwrite the checkpoint.
    is_last_epoch = (epoch + 1) == epochs
    if (epoch + 1) % 2 == 0 or is_last_epoch:
        test_plot(test, generator)
        # NOTE(review): this pickles the module objects themselves (not their
        # state_dict), so loading requires the class definitions to be importable.
        checkpoint = {'discriminator': discriminator, 'generator': generator}
        torch.save(checkpoint, path)
        
# One panel per network, each with its train and validation curve.
loss_panels = [
    ('Discriminator train/valid losses',
     [('train d_loss', train_d_losses), ('valid d_loss', valid_d_losses)]),
    ('Generator train/valid losses',
     [('train g_loss', train_g_losses), ('valid g_loss', valid_g_losses)]),
]

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, (title, curves) in zip(axes, loss_panels):
    ax.set_title(title)
    for label, history in curves:
        # Build the x axis from the recorded history rather than from `epochs`,
        # so the plot still works if training stopped early.
        ax.plot(np.arange(1, len(history) + 1), history, label=label)
    ax.grid()
    ax.legend()

fig.tight_layout()
plt.show()

### [Kernel outputs visualization.](https://www.kaggle.com/code/pankratozzi/pytorch-vaegan-face-inpainting?scriptVersionId=102242524)

* For a non-fixed masked area (when `mask_size` is not defined): add one more `upsample(64, 64)` stage to the generator so that its output grows from 64x64 to 128x128, i.e. [3, 128, 128] -> [3, 64->128, 64->128].
* OR try a U-Net architecture with residual connections: [3, 128, 128] -> [3, 128, 128].
* OR maybe CycleGAN or Pix2Pix GAN.
* OR, for example, this great paper: https://arxiv.org/pdf/1806.03589.pdf; [a fine, easy-to-understand implementation](https://github.com/csqiangwen/DeepFillv2_Pytorch).