In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class DoubleConv(nn.Module):
  def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

  def forward(self, x):
       return self.double_conv(x)

In [4]:
class Down(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.avgpool_conv = nn.Sequential(
            nn.AvgPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.avgpool_conv(x)


In [5]:
class Up(nn.Module):
    

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


In [6]:
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x))

In [7]:
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.inc = DoubleConv(in_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256 )
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, out_channels)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        out = self.outc(x)
        return out


In [8]:
import os
sorted(os.listdir('./blur_dataset/motion_blurred'))
sorted(os.listdir('./blur_dataset/sharp'))

['0_IPHONE-SE_S.JPG',
 '100_NIKON-D3400-35MM_S.JPG',
 '101_NIKON-D3400-35MM_S.JPG',
 '102_NIKON-D3400-35MM_S.JPG',
 '103_HUAWEI-P20_S.jpg',
 '104_IPHONE-SE_S.jpg',
 '105_IPHONE-SE_S.jpg',
 '106_NIKON-D3400-35MM_S.JPG',
 '107_XIAOMI-MI8-SE_S.jpg',
 '108_XIAOMI-MI8-SE_S.jpg',
 '109_HONOR-7X_S.jpg',
 '10_ASUS-ZENFONE-LIVE-ZB501KL_S.jpg',
 '110_IPHONE-7_S.jpeg',
 '111_IPHONE-7_S.jpeg',
 '112_NIKON-D3400-35MM_S.JPG',
 '113_SAMSUNG-GALAXY-A5_S.jpg',
 '114_ASUS-ZE500KL_S.jpg',
 '115_NIKON-D3400-35MM_S.JPG',
 '116_BQ-5512L_S.jpg',
 '117_HONOR-7X_S.jpg',
 '118_HONOR-7X_S.jpg',
 '119_HONOR-7X_S.jpg',
 '11_XIAOMI-MI8-SE_S.jpg',
 '120_HONOR-7X_S.jpg',
 '121_HONOR-7X_S.jpg',
 '122_HONOR-7X_S.jpg',
 '123_NIKON-D3400-35MM_S.JPG',
 '124_HONOR-7X_S.jpg',
 '125_NIKON-D3400-35MM_S.JPG',
 '126_NIKON-D3400-18-55MM_S.JPG',
 '127_IPHONE-8_S.jpeg',
 '128_XIAOMI-MI8-SE_S.jpg',
 '129_NIKON-D3400-18-55MM_S.JPG',
 '12_SAMSUNG-GALAXY-J5_S.jpg',
 '130_NIKON-D3400-18-55MM_S.JPG',
 '131_NIKON-D3400-18-55MM_S.JPG',
 '

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
from PIL import Image

In [21]:
from torch.utils.data import Dataset,DataLoader
import torchvision.transforms as transforms

class dataset(Dataset):
    def __init__(self, bdir, sdir):
        self.bdir = bdir
        self.sdir = sdir
        self.blurimages = sorted([os.path.join(bdir, file_name) for file_name in sorted(os.listdir(bdir))])
        self.sharpimages = sorted([os.path.join(sdir, file_name) for file_name in sorted(os.listdir(sdir))])

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((256, 256))
            
        ])
    def __getitem__(self, index):
        
        blur_image=self.transform(Image.open(self.blurimages[index]))
        sharp_image=self.transform(Image.open(self.sharpimages[index]))
        return (blur_image,sharp_image)
    
    def __len__(self):
        return len(self.blurimages)


In [22]:
dataset = dataset(bdir='./blur_dataset/motion_blurred', sdir='./blur_dataset/sharp')

In [23]:
from torch.utils.data import random_split
train_len=int(0.7*len(dataset))
val_len=int(len(dataset))-train_len
train_dataset, val_dataset = random_split(dataset, [train_len, val_len])

train_dataloader = DataLoader(train_dataset, batch_size=5)
test_dataloader = DataLoader(val_dataset, batch_size=5)


In [24]:

import numpy as np
import matplotlib.pyplot as plt
Epochs=5
model = UNet()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
losses = []
for epoch in range(0, Epochs):
  running_loss = 0
  for i_batch, batch in enumerate(train_dataloader):
    model.train(True)
    optimizer.zero_grad()
    output = model(batch[0])
    loss = criterion(batch[1], output)
    running_loss += loss.item()
    loss.backward()
    optimizer.step()
    
  losses.append(running_loss)
  print(f"Epoch {epoch+1} : ", running_loss)
plt.plot(losses)
plt.title('Loss(Square of L2 norm distance) Decay')
plt.show()

Epoch 1 :  3.725837606936693
