In [1]:
from google.colab import drive
# Mount Google Drive inside the Colab VM so the dataset archive and the
# extracted images are reachable under this path.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
!unzip /content/drive/MyDrive/gopro_deblur.zip -d /content/drive/MyDrive/gopro_deblur


In [14]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import pathlib
from pathlib import Path


In [None]:
import glob
# Collect the blurred and sharp image paths.
# glob.glob returns entries in arbitrary (filesystem-dependent) order, and the
# dataset pairs images purely by list index, so both listings are sorted to
# keep blur_list[i] and sharp_list[i] aligned.
# NOTE(review): this assumes a blurred frame and its sharp counterpart share
# the same file name in the two directories - confirm against the archive.
blur_paths = sorted(glob.glob('/content/drive/MyDrive/gopro_deblur/gopro_deblur/blur/images/*.png'))
sharp_paths = sorted(glob.glob('/content/drive/MyDrive/gopro_deblur/gopro_deblur/sharp/images/*.png'))
blur_list = list(blur_paths)
sharp_list = list(sharp_paths)

# Index-based pairing is only meaningful when both directories have the same
# number of images.
assert len(blur_list) == len(sharp_list), (
    f"blur/sharp image count mismatch: {len(blur_list)} vs {len(sharp_list)}"
)

In [16]:
from pathlib import Path
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class GoproDataset(Dataset):
    """Paired dataset of blurred images and their sharp counterparts.

    Item ``idx`` is built from ``blurred_image_paths[idx]`` and
    ``sharp_image_paths[idx]``; the two sequences are therefore expected to be
    index-aligned.

    Args:
        blurred_image_paths: Sequence of paths to the blurred images.
        sharp_image_paths: Sequence of paths to the sharp images, in the same
            order as ``blurred_image_paths``.
        transform: Optional callable applied independently to both images of
            a pair.

    Raises:
        ValueError: If the two path sequences differ in length.
    """

    def __init__(self, blurred_image_paths, sharp_image_paths, transform = None):
        # Fail early instead of silently dropping images or raising an
        # IndexError in the middle of an epoch.
        if len(blurred_image_paths) != len(sharp_image_paths):
            raise ValueError(
                "blurred_image_paths and sharp_image_paths must have the same length "
                f"(got {len(blurred_image_paths)} and {len(sharp_image_paths)})"
            )
        self.blurred_image_paths = blurred_image_paths
        self.sharp_image_paths = sharp_image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of (blurred, sharp) pairs."""
        return len(self.blurred_image_paths)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a fully loaded 3-channel RGB PIL image."""
        # The context manager closes the underlying file; convert() returns a
        # new in-memory image, so it stays valid after the file is closed.
        # Forcing RGB keeps RGBA / palette / grayscale PNGs compatible with a
        # 3-channel normalisation.
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self,idx):
        """Return the ``(blurred, sharp)`` pair at position ``idx``."""
        blurred_image = self._load_rgb(self.blurred_image_paths[idx])
        sharp_image = self._load_rgb(self.sharp_image_paths[idx])

        if self.transform:
            blurred_image = self.transform(blurred_image)
            sharp_image = self.transform(sharp_image)
        return blurred_image, sharp_image

In [17]:
# Preprocessing applied to both the blurred and the sharp image of a pair.
#
# torchvision's Resize takes its size as (height, width). The previous value
# (640, 360) therefore produced portrait 640x360 tensors; for landscape source
# frames that transposes the aspect ratio. The sizes are now given explicitly
# as height x width.
# NOTE(review): assumes the GoPro frames are landscape (1280x720) - confirm.
# Both sides stay divisible by 8, which the three 2x poolings of the generator
# need, and the discriminator's flattened feature count is the same for either
# orientation because its convolutions treat both axes identically.
IMAGE_HEIGHT = 360
IMAGE_WIDTH = 640

data_transform = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.ToTensor(),
    # Maps the [0, 1] range produced by ToTensor to [-1, 1].
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [26]:
from torch.utils.data import random_split

# Build the paired dataset and split it 80/20 into train and test subsets.
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42

dataset = GoproDataset(blur_list,sharp_list, data_transform)

train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size

# A dedicated, seeded generator makes the split identical on every run, so the
# test subset never mixes with images used for training in an earlier run.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# One image pair per batch; only the training loader is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [19]:
class ConvBlock(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 preserves the spatial size, and the
    convolutions carry no bias.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
            stage_in = out_channels

        # Attribute name (including its spelling) is part of the state_dict keys.
        self.Convbock = nn.Sequential(*stages)

    def forward(self,x):
        """Apply both conv stages to ``x`` and return the result."""
        features = self.Convbock(x)
        return features

class Contract(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``ConvBlock``."""

    def __init__(self,in_channels,out_channels):
        super().__init__()

        downsample = nn.MaxPool2d(kernel_size=2)
        convolve = ConvBlock(in_channels, out_channels)
        self.contract = nn.Sequential(downsample, convolve)

    def forward(self,x):
        """Pool ``x`` by a factor of two, then run the conv block on it."""
        contracted = self.contract(x)
        return contracted

class Expand(nn.Module):
    """Decoder stage: upsample, concatenate the skip connection, convolve.

    ``up`` is a stride-2 transposed convolution that halves the channel count
    and doubles the spatial size. Its output is concatenated with the skip
    tensor along the channel axis, so the skip tensor is expected to carry
    ``in_channels // 2`` channels, and the result goes through a ``ConvBlock``
    mapping ``in_channels`` to ``out_channels``.
    """

    def __init__(self,in_channels,out_channels):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, merge it with the skip tensor ``x2`` and convolve.

        Args:
            x1: Feature map coming from the deeper level, shaped (N, C, H, W).
            x2: Skip-connection feature map from the matching encoder level.

        Returns:
            The output of the conv block applied to the concatenation.
        """
        x1 = self.up(x1)

        # Max pooling floors odd sizes, so after upsampling x1 can be one
        # pixel smaller than the skip tensor. Zero-pad it to the skip size so
        # the concatenation works for any input resolution; when the sizes
        # already match nothing is padded and the result is unchanged.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h or diff_w:
            x1 = F.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x1,x2], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Output head: 1x1 convolution followed by Tanh.

    The Tanh bounds every output value to the [-1, 1] interval.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        squash = nn.Tanh()
        self.out = nn.Sequential(projection, squash)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels and apply Tanh."""
        projected = self.out(x)
        return projected

In [20]:
class Unet_generator(nn.Module):
    """U-Net style generator mapping a 3-channel image to a 3-channel image.

    The encoder widens the features 3 -> 64 -> 128 -> 256 -> 512 over three
    ``Contract`` stages; three ``Expand`` stages mirror it back to 64 channels
    using the encoder outputs as skip connections, and ``OutConv`` produces the
    3-channel output. A fourth level (512 <-> 1024) existed in the original
    notebook only as commented-out code and is not part of the model.
    """

    # Off by default; switched on per instance by use_checkpointing().
    _checkpointing = False

    def __init__(self):
        super(Unet_generator,self).__init__()

        self.inc = ConvBlock(3,64)
        self.down1 = Contract(64,128)
        self.down2 = Contract(128,256)
        self.down3 = Contract(256,512)

        self.up2 = Expand(512, 256)
        self.up3 = Expand(256, 128)
        self.up4 = Expand(128, 64)
        self.outc = OutConv(64, 3)

    def _run(self, module, *inputs):
        """Call ``module`` on ``inputs``, through activation checkpointing if enabled."""
        # Checkpointing only pays off when a backward pass will follow, so it
        # is skipped under torch.no_grad().
        if self._checkpointing and torch.is_grad_enabled():
            from torch.utils.checkpoint import checkpoint

            # The non-reentrant variant is required here: the network input
            # does not require grad, and the reentrant variant would then
            # produce no gradients for the first stage's parameters.
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)

    def forward(self, x):
        """Run the encoder/decoder on ``x`` and return the generated image."""
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x = self._run(self.up2, x4, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        out = self._run(self.outc, x)
        return out

    def use_checkpointing(self):
        """Enable activation checkpointing for every stage of the network.

        The previous implementation called the ``torch.utils.checkpoint``
        module as if it were a function, which raises ``TypeError``. Instead
        of replacing the registered submodules, this now sets a flag that
        makes ``forward`` route each stage through
        ``torch.utils.checkpoint.checkpoint``. Each stage is then recomputed
        during the backward pass, which trades compute for memory.
        """
        self._checkpointing = True

In [21]:
class Discriminator(nn.Module):
    """Convolutional discriminator with a fully connected scoring head.

    A stack of 4x4 convolutions (stride 2 for the first ``n_layers + 1`` of
    them, stride 1 for the last two) reduces the input to a single-channel
    map. The map is flattened and passed through two linear layers; the
    final sigmoid yields one value in [0, 1] per sample.

    Args:
        input_shape: ``(channels, height, width)`` of the images fed to the
            network; used to size the first convolution and the first linear
            layer.
        ndf: Number of filters of the first convolution; deeper layers use
            multiples of it, capped at ``8 * ndf``.
        n_layers: Number of intermediate stride-2 conv/batch-norm/LeakyReLU
            groups.
        use_sigmoid: If true, a Sigmoid is appended to the convolutional
            stack. ``forward`` applies its own sigmoid after the linear head
            regardless of this flag.
    """

    def __init__(self, input_shape, ndf, n_layers=3, use_sigmoid=False):
        super(Discriminator, self).__init__()

        layers = [nn.Conv2d(in_channels=input_shape[0], out_channels=ndf, kernel_size=4, stride=2, padding=1)]
        layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Channel multiplier doubles per group (1, 2, 4, ...) and is capped at 8.
        nf_mult, nf_mult_prev = 1, 1
        for n in range(n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
            layers.append(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(ndf * nf_mult))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
        layers.append(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=1, padding=1))
        layers.append(nn.BatchNorm2d(ndf * nf_mult))
        layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Collapse to a single-channel map.
        layers.append(nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

        # Number of values in the conv output for one sample of input_shape.
        self.feature_size = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(self.feature_size, 1024)
        self.fc2 = nn.Linear(1024, 1)

    def _get_conv_output(self, shape):
        """Return the flattened size of the conv stack's output for one sample.

        The probe runs in eval mode and without autograd so that it neither
        updates the BatchNorm running statistics with meaningless data nor
        builds a graph; the previous training mode is restored afterwards.
        """
        probe_device = next(self.model.parameters()).device
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                features = self.model(torch.zeros(1, *shape, device=probe_device))
        finally:
            self.model.train(was_training)
        return features.numel()

    def forward(self, x):
        """Score a batch of images; returns a tensor of shape ``(N, 1)`` in [0, 1]."""
        x = self.model(x)
        x = x.view(x.size(0), -1)
        x = torch.tanh(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

In [22]:
# Instantiate the generator and the discriminator. The discriminator is told
# the (channels, height, width) of its input so it can size its linear head,
# and the filter count of its first convolution.
netG = Unet_generator()
netD = Discriminator(input_shape=(3, 640, 360), ndf=64)

In [23]:
import torch.optim as optim

# Losses: binary cross-entropy for the adversarial term, L1 for the
# pixel-wise reconstruction term.
criterion = nn.BCELoss()
criterion_L1 = nn.L1Loss()

# Both networks use Adam with the same hyper-parameters.
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(netD.parameters(), **adam_settings)
optimizer_G = optim.Adam(netG.parameters(), **adam_settings)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

netG = netG.to(device)
netD = netD.to(device)

In [24]:
def train(dataloader, num_epochs):
    """Adversarially train ``netG`` and ``netD`` on ``dataloader``.

    Uses the module-level ``netG``, ``netD``, ``optimizer_G``, ``optimizer_D``,
    ``criterion``, ``criterion_L1`` and ``device``. For every batch the
    discriminator is updated first (real images vs. detached generator
    output), then the generator is updated with an adversarial term plus an
    L1 term weighted by 100.

    Args:
        dataloader: Iterable yielding ``(blurred_images, sharp_images)``
            batches; must support ``len()``.
        num_epochs: Number of passes over ``dataloader``.
    """
    # Make sure BatchNorm layers are in training mode even if the models were
    # switched to eval mode earlier in the session.
    netG.train()
    netD.train()

    for epoch in range(num_epochs):
        # Iterate the loader that was passed in; the previous version looped
        # over the global train_dataloader and ignored this argument.
        for i, (blurred_images, sharp_images) in enumerate(dataloader):
            blurred_images = blurred_images.to(device)
            sharp_images = sharp_images.to(device)

            batch_size = blurred_images.size(0)
            real_labels = torch.ones(batch_size, dtype=torch.float, device=device)
            fake_labels = torch.zeros(batch_size, dtype=torch.float, device=device)

            #### Train Discriminator ####
            netD.zero_grad()

            output_real = netD(sharp_images).view(-1)
            loss_real = criterion(output_real, real_labels)

            fake_images = netG(blurred_images)
            # detach(): the discriminator update must not backpropagate into
            # the generator.
            output_fake = netD(fake_images.detach()).view(-1)
            loss_fake = criterion(output_fake, fake_labels)

            # Total Discriminator loss
            loss_D = (loss_real + loss_fake) / 2
            loss_D.backward()
            optimizer_D.step()

            #### Train Generator ####
            netG.zero_grad()

            # GAN loss (adversarial loss): the generator wants its output to
            # be classified as real by the freshly updated discriminator.
            output_fake_for_G = netD(fake_images).view(-1)
            loss_G_GAN = criterion(output_fake_for_G, real_labels)

            # L1 Loss for pixel-wise image similarity
            loss_G_L1 = criterion_L1(fake_images, sharp_images) * 100  # Weighted L1 loss

            # Total Generator loss
            loss_G = loss_G_GAN + loss_G_L1
            loss_G.backward()
            optimizer_G.step()

            # The former per-step `fake_images.detach()` (result discarded, so
            # a no-op) and `torch.cuda.empty_cache()` (frees no live tensors
            # but slows every step) were removed.

            if i % 10 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] | Batch [{i}/{len(dataloader)}] | Loss D: {loss_D.item()} | Loss G: {loss_G.item()}")


In [None]:
# Train both networks for a single pass over the training loader.
train(train_dataloader, num_epochs=1)

In [34]:
# Run the generator on the first batch of the test loader only; the loop is
# left right after that batch.
# NOTE(review): netG is not switched to eval mode here, so unless .eval() was
# called elsewhere its BatchNorm layers use the statistics of this batch -
# confirm that this is intended.
for test_batch in test_dataloader:

    blurred_images = test_batch[0].to(device)
    sharp_images = test_batch[1].to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output_images = netG(blurred_images)

    break


In [35]:
def denormalize(tensor):
    """Undo a mean=0.5 / std=0.5 normalisation, mapping [-1, 1] back to [0, 1]."""
    std = 0.5
    mean = 0.5
    return tensor * std + mean

In [None]:
from PIL import Image
from IPython.display import display

# Take the first sample of each batch, move it to the CPU, map it back to the
# [0, 1] range and convert it to a PIL image.
to_pil = transforms.ToPILImage()

blurred_img = to_pil(denormalize(blurred_images[0].cpu()))
output_img = to_pil(denormalize(output_images[0].cpu()))
sharp_img = to_pil(denormalize(sharp_images[0].cpu()))

# Display the images: blurred input, generator output, sharp ground truth.
display(blurred_img, output_img, sharp_img)