In [1]:
import warnings
warnings.filterwarnings('ignore')
!pip install -qU torchsummary
from torchsummary import summary

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from torch.nn.utils import spectral_norm
import torch.autograd as autograd

import cv2
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split
from functools import partial

# Random seed used to initialise the torch and NumPy generators in the configuration cell.
seed = 42

In [2]:
# --- Reproducibility -------------------------------------------------------
# Seed NumPy and torch from the shared `seed` value.
np.random.seed(seed)
torch.manual_seed(seed)

# --- Data / training hyper-parameters --------------------------------------
BATCH_SIZE = 2
IMAGE_HEIGHT, IMAGE_WIDTH = 640, 360
FINE_SIZE = 256
epochs = 300

# --- Loss weights -----------------------------------------------------------
gp_lambda = 10              # multiplier applied to the gradient penalty term
content_loss_lambda = 100   # multiplier applied to the content loss

# --- Checkpoint location ----------------------------------------------------
PATH = r'./deblur_model.pth'

# --- Compute device ---------------------------------------------------------
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Currently using {device.upper()} device')

In [3]:
# Root folder of the dataset; it holds one sub-folder per image kind.
path = '../input/blur-dataset/'

# Collect the file paths of both folders in sorted order. Row i of the frame
# pairs the i-th blurred path with the i-th sharp path, so the pairing relies
# on both folders sorting identically.
blurred = sorted(str(p) for p in glob(path + 'motion_blurred' + '/*'))
sharp = sorted(str(p) for p in glob(path + 'sharp' + '/*'))

df = pd.DataFrame({'blur': blurred, 'sharp': sharp})
df.sample(5)

In [4]:
# Hold out 5 rows for testing, then 30 of the remaining rows for validation.
# Both splits shuffle with the same fixed random_state.
_split_options = dict(shuffle=True, random_state=123)
train, test = train_test_split(df, test_size=5, **_split_options)
train, valid = train_test_split(train, test_size=30, **_split_options)
print(f'Train size: {train.shape[0]}, valid size: {valid.shape[0]}, test size: {test.shape[0]}')

In [5]:
def _resize_step():
    """Build the bicubic resize step shared by the train and validation pipelines."""
    # NOTE: T.Resize reads the tuple as (height, width), so the resulting tensors
    # have IMAGE_WIDTH rows and IMAGE_HEIGHT columns (matching the shapes passed
    # to `summary` later in the notebook).
    return T.Resize((IMAGE_WIDTH, IMAGE_HEIGHT), Image.BICUBIC)


# Training pipeline: array -> PIL -> resize -> random horizontal flip -> tensor scaled to [-1, 1].
train_transforms = T.Compose([
    T.ToPILImage(),
    _resize_step(),
    T.RandomHorizontalFlip(p=0.2),
    T.ToTensor(),
    T.Normalize(0.5, 0.5),
])

# Validation pipeline: same as training but without the random flip.
valid_transforms = T.Compose([
    T.ToPILImage(),
    _resize_step(),
    T.ToTensor(),
    T.Normalize(0.5, 0.5),
])

# Inverse of Normalize(0.5, 0.5): first divide by 2 (std = 1/0.5), then add 0.5 (mean = -0.5).
invTrans = T.Compose([
    T.Normalize(mean=[0.] * 3, std=[1 / 0.5] * 3),
    T.Normalize(mean=[-0.5] * 3, std=[1.] * 3),
])

In [6]:
class BlurDataset(Dataset):
    """Dataset of (blurred, sharp) image pairs described by a DataFrame.

    Args:
        df: frame with a ``'blur'`` and a ``'sharp'`` column holding image paths.
        transforms: callable applied to every image inside :meth:`collate_fn`.

    ``__getitem__`` returns raw RGB arrays; the transforms, batching and the
    move to ``device`` all happen in :meth:`collate_fn`, which must therefore
    be passed to the ``DataLoader``.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_rgb(image_path):
        """Read an image from disk and convert it from BGR to RGB."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, ix):
        row = self.df.iloc[ix].squeeze()
        return self._read_rgb(row['blur']), self._read_rgb(row['sharp'])

    def _transform_pair(self, blur_img, sharp_img):
        """Apply ``self.transforms`` to both images with identical random decisions.

        Random transforms (e.g. a random horizontal flip) would otherwise be drawn
        independently for the two images and could mirror one but not the other.
        The RNG state is snapshotted before the first call and restored before
        the second, so both images see the same random numbers. Both Python's
        ``random`` and torch's global generator are covered because torchvision
        has used either depending on version.
        """
        import random

        py_state = random.getstate()
        torch_state = torch.get_rng_state()
        blur_tensor = self.transforms(blur_img)

        random.setstate(py_state)
        torch.set_rng_state(torch_state)
        sharp_tensor = self.transforms(sharp_img)
        return blur_tensor, sharp_tensor

    def collate_fn(self, batch):
        """Transform every pair of the batch and stack them into two tensors on ``device``.

        Returns:
            ``(blurs, sharps)``, each with the batch as the leading dimension.
        """
        pairs = [self._transform_pair(blur_img, sharp_img) for blur_img, sharp_img in batch]
        blurs = torch.stack([pair[0] for pair in pairs]).to(device)
        sharps = torch.stack([pair[1] for pair in pairs]).to(device)
        return blurs, sharps

In [7]:
# Wrap each split in a dataset with its matching transform pipeline.
train_dataset = BlurDataset(train, transforms=train_transforms)
valid_dataset = BlurDataset(valid, transforms=valid_transforms)

# Each dataset's own collate_fn is required: it applies the transforms, batches
# the images and moves them to `device`.
# The training loader shuffles so every epoch visits the pairs in a new order;
# the validation loader keeps a fixed order and yields one pair at a time.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=train_dataset.collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, collate_fn=valid_dataset.collate_fn)

# imgA = invTrans(imgA).squeeze().detach().cpu().numpy().transpose(1,2,0)

In [9]:
def get_norm_layer():
    """Return a factory for ``nn.InstanceNorm2d`` layers with running stats disabled.

    The result is a ``functools.partial``; callers use its ``.func`` attribute
    to check which normalisation class is in use.
    """
    norm_factory = partial(nn.InstanceNorm2d, track_running_stats=False)
    return norm_factory

def init_weights(m):
    """Initialise a module in place; intended for use with ``nn.Module.apply``.

    * ``Conv2d`` / ``ConvTranspose2d``: weights ~ N(0, 0.02), bias zeroed.
    * ``BatchNorm2d``: weight ~ N(1, 0.02), bias zeroed (skipped when non-affine).

    Convolutions wrapped by ``torch.nn.utils.spectral_norm`` keep their trainable
    parameter in ``weight_orig``; ``weight`` is recomputed from it, so writing to
    ``weight`` would be lost. The underlying parameter is initialised instead.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # Prefer the real parameter when spectral norm has been applied.
        weight = getattr(m, 'weight_orig', None)
        if weight is None:
            weight = m.weight
        nn.init.normal_(weight, 0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    if isinstance(m, nn.BatchNorm2d):
        # Non-affine batch norm has no weight/bias to initialise.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

class ResNetBlock(nn.Module):
    """Residual block: ``x + Dropout(ReLU(Norm(Conv3x3(ReflectPad(x)))))``.

    Args:
        dim: number of input and output channels.
        norm_layer: callable building a normalisation layer from a channel count.
        use_bias: whether the convolution has a bias term.
    """

    def __init__(self, dim, norm_layer, use_bias):
        super().__init__()
        # Reflection padding of 1 followed by an unpadded 3x3 convolution keeps
        # the spatial size unchanged, so the skip connection can be added directly.
        self.model = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        residual = self.model(x)
        return x + residual

class Generator(nn.Module):
    """Encoder / residual / decoder generator with a global skip connection.

    Layout: 7x7 stem conv -> two stride-2 downsampling convs -> ``n_blocks``
    ``ResNetBlock``s -> two stride-2 transposed convs -> 7x7 output conv + Tanh.
    The forward pass adds the network output to the input and clamps to [-1, 1].

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the produced image (must be addable to the input).
        ngf: channel count after the stem convolution.
        n_blocks: number of residual blocks at the bottleneck.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9):
        super().__init__()

        norm_layer = get_norm_layer()
        # A conv bias is only kept when the norm layer is not batch norm.
        use_bias = norm_layer.func != nn.BatchNorm2d
        n_downsampling = 2

        # Stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        # Encoder: every step halves the resolution and doubles the channels.
        channels = ngf
        for _ in range(n_downsampling):
            layers.extend([
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(channels * 2),
                nn.ReLU(True),
            ])
            channels *= 2

        # Bottleneck residual blocks at the widest channel count.
        layers.extend(ResNetBlock(channels, norm_layer, use_bias) for _ in range(n_blocks))

        # Decoder: mirror of the encoder.
        for _ in range(n_downsampling):
            layers.extend([
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2, padding=1,
                                   output_padding=1, bias=use_bias),
                norm_layer(channels // 2),
                nn.ReLU(True),
            ])
            channels //= 2

        # Output head
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # The network predicts a correction that is added to its own input.
        corrected = x + self.model(x)
        return torch.clamp(corrected, min=-1, max=1)
    
class Discriminator(nn.Module):
    """Convolutional critic that maps an image to a 1-channel score map.

    Layout: plain 4x4 stride-2 conv, then ``n_layers - 1`` spectral-normalised
    stride-2 convs, one spectral-normalised stride-1 conv, and a final plain
    conv to a single channel. Channel multipliers double per layer, capped at 8.

    Args:
        input_nc: channels of the input image.
        ndf: channel count after the first convolution.
        n_layers: controls the depth of the strided stack.
        use_sigmoid: append a ``Sigmoid`` to the output when True.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False):
        super().__init__()

        # The norm factory is only consulted to decide on the conv bias; the
        # normalisation layers themselves are disabled in this architecture.
        norm_layer = get_norm_layer()
        use_bias = norm_layer.func != nn.BatchNorm2d

        kernel_size, padding = 4, 1

        def sn_conv(in_channels, out_channels, stride):
            """Spectral-normalised conv with two power iterations."""
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                             padding=padding, bias=use_bias)
            return spectral_norm(conv, n_power_iterations=2)

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, True),
        ]

        mult = 1
        for n in range(1, n_layers):
            prev_mult, mult = mult, min(2 ** n, 8)
            layers.append(sn_conv(ndf * prev_mult, ndf * mult, stride=2))
            layers.append(nn.LeakyReLU(0.2, True))

        prev_mult, mult = mult, min(2 ** n_layers, 8)
        layers.append(sn_conv(ndf * prev_mult, ndf * mult, stride=1))
        layers.append(nn.LeakyReLU(0.2, True))

        layers.append(nn.Conv2d(ndf * mult, 1, kernel_size=kernel_size, stride=1, padding=padding))
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [10]:
# Build both networks, apply the custom weight initialisation and move them to `device`.
generator = Generator(3, 3, n_blocks=9).apply(init_weights).to(device)
discriminator = Discriminator(3).apply(init_weights).to(device)

# First 15 layers of a pretrained VGG-19 `features` stack, used as the feature
# extractor of the content loss.
CONV3_3_IN_VGG_19 = torchvision.models.vgg19(pretrained=True, progress=False).features[:15].to(device)
# The extractor is never optimised: put it in eval mode and freeze its weights so
# backward passes do not compute or accumulate gradients for them. Gradients
# still flow through it to the generator output.
CONV3_3_IN_VGG_19.eval()
for _vgg_param in CONV3_3_IN_VGG_19.parameters():
    _vgg_param.requires_grad_(False)

In [None]:
# Layer-by-layer summary of the generator for a single 3-channel input.
summary(generator, input_size=(3, IMAGE_WIDTH, IMAGE_HEIGHT))

In [None]:
# Layer-by-layer summary of the discriminator for a single 3-channel input.
summary(discriminator, input_size=(3, IMAGE_WIDTH, IMAGE_HEIGHT))

In [None]:
# Layer-by-layer summary of the VGG feature extractor used by the content loss.
summary(CONV3_3_IN_VGG_19, input_size=(3, IMAGE_WIDTH, IMAGE_HEIGHT))

In [11]:
def load_model(path, device=device):
    """Load a training checkpoint written by the training loop.

    Args:
        path: checkpoint file.
        device: ``'cuda'`` loads tensors as stored; anything else remaps them to CPU.

    Returns:
        ``(generator, discriminator, optimizerG, optimizerD, epoch)`` as stored
        under the keys ``'G'``, ``'D'``, ``'optimizerG'``, ``'optimizerD'``, ``'epoch'``.
    """
    # NOTE(security): the checkpoint stores whole pickled objects, so loading it
    # executes arbitrary pickle code -- only load checkpoints you created yourself.
    load_options = {} if device == 'cuda' else {'map_location': torch.device('cpu')}
    state = torch.load(path, **load_options)

    saved_epoch = state['epoch']
    net_g, net_d = state['G'], state['D']
    opt_g, opt_d = state['optimizerG'], state['optimizerD']
    return net_g, net_d, opt_g, opt_d, saved_epoch

def PSNR(deblurred, sharp):
    """Peak signal-to-noise ratio in dB between two tensors with a peak value of 1.

    Args:
        deblurred: predicted image tensor.
        sharp: reference image tensor with a broadcast-compatible shape.

    Returns:
        A Python ``float``; ``100.0`` when the two tensors are identical (MSE of 0).
    """
    mse = torch.mean((deblurred - sharp) ** 2)
    if mse == 0:
        return 100.0
    PIXEL_MAX = 1
    # Stay in torch (works for tensors on any device) and hand back a plain float
    # so both branches return the same type.
    return (10 * torch.log10(PIXEL_MAX ** 2 / mse)).item()

class WGANLoss(nn.Module):
    """Wasserstein GAN loss with gradient penalty.

    ``forward(mtype, **kwargs)`` dispatches on ``mtype``:

    * ``'G'`` -- needs ``deblurred_discriminator_out``; returns
      ``-mean(D(deblurred))``.
    * ``'D'`` -- needs ``gp_lambda``, ``interpolates``,
      ``interpolates_discriminator_out``, ``sharp_discriminator_out`` and
      ``deblurred_discriminator_out``; returns the tuple
      ``(mean(D(deblurred)) - mean(D(sharp)), gp_lambda * gradient_penalty)``.
      ``interpolates`` must require grad and be the input that produced
      ``interpolates_discriminator_out``.

    Raises:
        ValueError: for any other ``mtype``.
    """

    def forward(self, mtype, **kwargs):
        if mtype == 'G':
            deblurred_discriminator_out = kwargs['deblurred_discriminator_out']
            return -deblurred_discriminator_out.mean()

        if mtype == 'D':
            gp_lambda = kwargs['gp_lambda']
            interpolates = kwargs['interpolates']
            interpolates_discriminator_out = kwargs['interpolates_discriminator_out']
            sharp_discriminator_out = kwargs['sharp_discriminator_out']
            deblurred_discriminator_out = kwargs['deblurred_discriminator_out']

            wgan_loss = deblurred_discriminator_out.mean() - sharp_discriminator_out.mean()

            # Gradient of the critic output w.r.t. the interpolated inputs.
            # ones_like keeps grad_outputs on the same device/dtype as the output;
            # create_graph=True lets the penalty itself be back-propagated.
            gradients = autograd.grad(outputs=interpolates_discriminator_out, inputs=interpolates,
                                      grad_outputs=torch.ones_like(interpolates_discriminator_out),
                                      retain_graph=True,
                                      create_graph=True)[0]
            # Per-sample L2 norm of the gradient, penalised for deviating from 1.
            grad_norms = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
            gradient_penalty = ((grad_norms - 1) ** 2).mean()

            return wgan_loss, gp_lambda * gradient_penalty

        raise ValueError(f"mtype must be 'G' or 'D', got {mtype!r}")

class ContentLoss(nn.Module):
    """Mean-squared error between feature maps of the deblurred and sharp images.

    Args (forward):
        deblurred: generator output.
        sharp: reference image; its feature map is detached and acts as a constant target.
        model: feature extractor, defaulting to the notebook's VGG-19 slice.
    """

    def forward(self, deblurred, sharp, model=CONV3_3_IN_VGG_19):
        predicted_features = model.forward(deblurred)
        target_features = model.forward(sharp).detach()
        return nn.functional.mse_loss(predicted_features, target_features)
    
def simple_gan_loss(mtype, **kwargs):
    """Binary cross-entropy GAN loss on discriminator outputs.

    * ``mtype == 'G'``: BCE of ``deblurred_discriminator_out`` against all-ones.
    * ``mtype == 'D'``: average of BCE(``sharp_discriminator_out``, ones) and
      BCE(``deblurred_discriminator_out``, zeros).

    Any other ``mtype`` returns ``None``. Since plain BCE is used, the
    discriminator outputs are expected to already be probabilities.
    """
    bce = nn.functional.binary_cross_entropy

    if mtype == 'G':
        fake_scores = kwargs['deblurred_discriminator_out']
        return bce(fake_scores, torch.ones_like(fake_scores))

    if mtype == 'D':
        real_scores = kwargs['sharp_discriminator_out']
        fake_scores = kwargs['deblurred_discriminator_out']
        loss_on_real = bce(real_scores, torch.ones_like(real_scores))
        loss_on_fake = bce(fake_scores, torch.zeros_like(fake_scores))
        return (loss_on_real + loss_on_fake) / 2.0

    return None

In [15]:
# Loss modules
criterion_wgan, criterion_content = WGANLoss(), ContentLoss()

# Both networks use AdamW with identical settings.
_adamw_settings = dict(lr=0.00001, betas=(0.5, 0.999), amsgrad=True, weight_decay=1e-6)
optimizerG = torch.optim.AdamW(generator.parameters(), **_adamw_settings)
optimizerD = torch.optim.AdamW(discriminator.parameters(), **_adamw_settings)


def lr_lambda(epoch):
    """LR multiplier: 1 up to epoch 150, then linear decay reaching 0 at epoch 300."""
    if epoch > 150:
        return 1 - (epoch - 150) / 150
    return 1


schedulerG = torch.optim.lr_scheduler.LambdaLR(optimizerG, lr_lambda=lr_lambda)
schedulerD = torch.optim.lr_scheduler.LambdaLR(optimizerD, lr_lambda=lr_lambda)

In [16]:
def denormalize(image_tensor):
    """Map values from the [-1, 1] range back to [0, 1] via ``(x + 1) / 2``."""
    shifted = image_tensor + 1
    return shifted / 2.0

def train_one_batch(generator, discriminator, data, criterionW, criterionC, optimizerG, optimizerD, critic_updates=5):
    """Run ``critic_updates`` discriminator steps followed by one generator step.

    Args:
        generator, discriminator: the two networks (put into train mode here).
        data: ``(blur, sharp)`` batch tensors.
        criterionW: adversarial criterion called as ``criterionW('D' | 'G', **kwargs)``.
        criterionC: content criterion called as ``criterionC(deblur, sharp)``.
        optimizerG, optimizerD: optimizers of the generator / discriminator.
        critic_updates: number of discriminator updates per generator update.

    Returns:
        ``(mean discriminator loss, generator loss, PSNR of the batch)``.

    Uses the notebook globals ``gp_lambda`` and ``content_loss_lambda``.
    """
    generator.train()
    discriminator.train()

    blur, sharp = data

    # ---- Discriminator (critic) updates -----------------------------------
    discriminator_loss = 0
    for _ in range(critic_updates):
        # The generator is not trained here, so no graph is built through it.
        with torch.no_grad():
            deblur = generator(blur)

        optimizerD.zero_grad()
        d_sharp_out = discriminator(sharp)
        d_deblur_out = discriminator(deblur)

        # One interpolation coefficient per sample, broadcast over the remaining
        # dimensions. The interpolated batch must require grad for the penalty.
        alpha = torch.rand(sharp.size(0), *([1] * (sharp.dim() - 1)), device=sharp.device)
        interpolates = (alpha * sharp + (1 - alpha) * deblur).requires_grad_(True)
        interpolates_discriminator_out = discriminator(interpolates)

        wgan_loss_d, gp_d = criterionW(
            'D',
            gp_lambda=gp_lambda,
            interpolates=interpolates,
            interpolates_discriminator_out=interpolates_discriminator_out,
            sharp_discriminator_out=d_sharp_out,
            deblurred_discriminator_out=d_deblur_out,
        )
        discriminator_loss_per_update = wgan_loss_d + gp_d
        discriminator_loss_per_update.backward()
        # Clip BEFORE the step, otherwise the clipping cannot affect the update.
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 10.0)
        optimizerD.step()
        discriminator_loss += discriminator_loss_per_update.item()
    if critic_updates:
        discriminator_loss /= critic_updates

    # ---- Generator update --------------------------------------------------
    optimizerG.zero_grad()

    deblur = generator(blur)
    d_deblur_out = discriminator(deblur)
    wgan_loss_g = criterionW('G', deblurred_discriminator_out=d_deblur_out)
    content_loss_g = criterionC(deblur, sharp) * content_loss_lambda
    generator_loss = wgan_loss_g + content_loss_g
    generator_loss.backward()
    torch.nn.utils.clip_grad_norm_(generator.parameters(), 10.0)
    optimizerG.step()

    # ---- Metric ------------------------------------------------------------
    with torch.no_grad():
        denormalized_sharp = denormalize(sharp).cpu().detach()
        denormalized_deblurred = denormalize(deblur).cpu().detach()

    metric = PSNR(denormalized_deblurred, denormalized_sharp)

    return discriminator_loss, generator_loss.item(), metric

@torch.no_grad()
def validate(generator, discriminator, data, criterionW, criterionC):
    """Evaluate the generator on one batch without gradient tracking.

    Args:
        generator, discriminator: the two networks (put into eval mode here).
        data: ``(blur, sharp)`` batch tensors.
        criterionW: adversarial criterion, called in its ``'G'`` mode.
        criterionC: content criterion called as ``criterionC(deblur, sharp)``.

    Returns:
        ``(generator loss, PSNR)`` where the loss is the adversarial term plus
        the content term scaled by the global ``content_loss_lambda``.
    """
    for network in (generator, discriminator):
        network.eval()

    blur, sharp = data
    deblur = generator(blur)
    critic_scores = discriminator(deblur)

    adversarial_loss_g = criterionW('G', deblurred_discriminator_out=critic_scores)
    content_loss_g = criterionC(deblur, sharp) * content_loss_lambda
    loss_g = adversarial_loss_g + content_loss_g

    # PSNR is computed on CPU copies rescaled to [0, 1].
    sharp_01 = denormalize(sharp).cpu().detach()
    deblurred_01 = denormalize(deblur).cpu().detach()
    metric = PSNR(deblurred_01, sharp_01)

    return loss_g.item(), metric

@torch.no_grad()
def visual_validate(data, model):
    """Plot the blurred input, the sharp target and the model output side by side.

    Args:
        data: ``(blurred, target)`` tensors; after ``squeeze()`` each must be a
            single channel-first image.
        model: generator, switched to eval mode before inference.
    """
    img, tar = data
    model.eval()
    out = model(img)

    def to_displayable(tensor):
        """Rescale to [0, 1] and convert to a channel-last NumPy array."""
        return denormalize(tensor).squeeze().cpu().detach().numpy().transpose(1,2,0)

    panels = (
        (131, 'Blurred', to_displayable(img)),
        (132, 'Target', to_displayable(tar)),
        (133, 'Deblurred', to_displayable(out)),
    )

    plt.figure(figsize=(12,14))
    for position, title, image in panels:
        plt.subplot(position)
        plt.title(title)
        plt.imshow(image)
    plt.show()
    plt.pause(0.001)

In [17]:
# Per-epoch history of losses and metrics.
train_d_losses, train_g_losses, valid_g_losses = [], [], []
train_metric_total, valid_metric_total = [], []

# Resume from a checkpoint when one exists, otherwise start from scratch.
try:
    generator, discriminator, optimizerG, optimizerD, ep = load_model(PATH)
except FileNotFoundError:
    ep = 0
else:
    # The checkpoint stores the index of the last COMPLETED epoch, so training
    # continues with the following one.
    ep += 1
    # The schedulers created earlier are bound to the optimizers that were just
    # replaced; rebuild them on the loaded optimizers and fast-forward them to
    # the resume epoch so the learning-rate schedule continues where it stopped.
    schedulerG = torch.optim.lr_scheduler.LambdaLR(optimizerG, lr_lambda=lr_lambda)
    schedulerD = torch.optim.lr_scheduler.LambdaLR(optimizerD, lr_lambda=lr_lambda)
    for _ in range(ep):
        schedulerD.step()
        schedulerG.step()

for epoch in range(ep, epochs):
    print(f'Epoch {epoch + 1}/{epochs}')

    # ---- Training ----------------------------------------------------------
    train_epoch_d_loss, train_epoch_g_loss, train_epoch_metric = [], [], []
    for data in tqdm(train_dataloader, leave=False):
        # Anomaly detection is a debugging aid that slows every backward pass,
        # so it is not enabled for regular training.
        d_loss, g_loss, metric = train_one_batch(generator, discriminator, data, criterion_wgan, criterion_content,
                                                 optimizerG, optimizerD, critic_updates=5)
        train_epoch_d_loss.append(d_loss)
        train_epoch_g_loss.append(g_loss)
        train_epoch_metric.append(float(metric))
    epoch_d_loss = np.array(train_epoch_d_loss).mean()
    epoch_g_loss = np.array(train_epoch_g_loss).mean()
    train_metric = np.array(train_epoch_metric).mean()
    train_d_losses.append(epoch_d_loss)
    train_g_losses.append(epoch_g_loss)
    train_metric_total.append(train_metric)
    print(f'Train D loss: {epoch_d_loss:.4f}, train G loss: {epoch_g_loss:.4f}')
    print(f'Train metric: {train_metric:.4f}')

    # ---- Validation --------------------------------------------------------
    valid_epoch_g_loss, valid_epoch_metric = [], []
    for data in tqdm(valid_dataloader, leave=False):
        g_loss, metric = validate(generator, discriminator, data, criterion_wgan, criterion_content)
        valid_epoch_g_loss.append(g_loss)
        valid_epoch_metric.append(float(metric))
    epoch_g_loss = np.array(valid_epoch_g_loss).mean()
    valid_metric = np.array(valid_epoch_metric).mean()
    valid_g_losses.append(epoch_g_loss)
    valid_metric_total.append(valid_metric)
    print(f'Valid G loss: {epoch_g_loss:.4f}')
    print(f'Valid metric: {valid_metric:.4f}')
    print('-'*50)

    # ---- Learning-rate schedule -------------------------------------------
    schedulerD.step()
    schedulerG.step()

    # ---- Checkpoint + visual check every second epoch ----------------------
    if (epoch + 1) % 2 == 0:
        checkpoint = {
                      'epoch': epoch,
                      'G': generator,
                      'D': discriminator,
                      'optimizerG': optimizerG,
                      'optimizerD': optimizerD,
                      }
        torch.save(checkpoint, PATH)
        data = next(iter(valid_dataloader))
        visual_validate(data, generator)