In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.init as init
from math import log10
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image, ImageFilter
from os import listdir
from os.path import join
from tqdm import tqdm           #display loop evolution

In [2]:
#Data

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def load_img(filepath):
    img = Image.open(filepath).convert('RGB')
    return img

CROP_SIZE = 32

class DatasetFromFolder(Dataset):
    def __init__(self, image_dir, scale_factor, with_bicubic_upsampling = True):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

        crop_size = CROP_SIZE - (CROP_SIZE % scale_factor) # Valid crop size
        
        if with_bicubic_upsampling:
            self.input_transform = transforms.Compose([transforms.CenterCrop(crop_size), # cropping the image
                                        transforms.Resize(crop_size//scale_factor),  # subsampling the image (half size)
                                        transforms.ToTensor()])
        else:
            self.input_transform = transforms.Compose([transforms.CenterCrop(crop_size), # cropping the image
                                        transforms.Resize(crop_size//scale_factor),  # subsampling the image (half size)
                                        transforms.ToTensor()])
                
        self.target_transform = transforms.Compose([transforms.CenterCrop(crop_size), # since it's the target, we keep its original quality
                                       transforms.ToTensor()])

    def __getitem__(self, index):
        input = load_img(self.image_filenames[index])
        target = input.copy()
        
        input = input.filter(ImageFilter.GaussianBlur(1)) 
        input = self.input_transform(input)
        target = self.target_transform(target)

        return input, target

    def __len__(self):
        return len(self.image_filenames)

In [3]:
#Model

###Classes of Generator and Discriminator

class Generator(nn.Module):
    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super(Generator, self).__init__()
        self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False, use_act=True)
        self.residuals = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_blocks)])
        self.convblock = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_bn=True, use_act=False)
        self.upsamples = nn.Sequential(UpSampleBlock(num_channels, 2), UpSampleBlock(num_channels, 2))
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)
    
    def forward(self, x):
        initial = self.initial(x)
        x = self.residuals(initial)
        x = initial + self.convblock(x)
        x = self.upsamples(x)
        x = self.final(x)
        return torch.tanh(x)
    
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
        super(Discriminator, self).__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(ConvBlock(in_channels, feature, kernel_size=3, stride=1+idx%2, padding=1, discriminator=True,
                                    use_act=True, use_bn=False if idx==0 else True))
            in_channels = feature
        
        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(features[-1]*6*6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.blocks(x)
        x = self.classifier(x)
        return x #torch.sigmoid(x)  ##sigmoid in paper! ...



####Classes of model's blocks

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, discriminator=False, use_act=True, use_bn=True, **kwargs):
        super(ConvBlock, self).__init__()
        self.use_act = use_act
        self.use_bn = use_bn
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True) if discriminator else nn.PReLU(num_parameters=out_channels)
    
    def forward(self, x):
        x = self.cnn(x)
        x = self.bn(x) if self.use_bn else x
        x = self.act(x) if self.use_act else x
        return x

class UpSampleBlock(nn.Module):
    def __init__(self, in_channels, scale_factor):
        super(UpSampleBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels*scale_factor**2, kernel_size=3, stride=1, padding=1)
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_channels)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.ps(x)
        return self.act(x)
    
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.block1 = ConvBlock(in_channels, in_channels, kernel_size=3, stride=1, padding=1, use_act=True, use_bn=True)
        self.block2 = ConvBlock(in_channels, in_channels, kernel_size=3, stride=1, padding=1, use_act=False, use_bn=True)

    def forward(self, x):
        x1 = self.block1(x)
        x2 = self.block2(x1)
        return x1 + x2


In [4]:
#VGG-Loss
from turtle import forward
from torchvision.models import vgg19        #VGG19 to match the paper
cuda = True
device = torch.device("cuda:0" if (torch.cuda.is_available() and cuda) else "cpu")

class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = vgg19(weights='DEFAULT').features[:36].eval().to(device)   #36 to match phi5,4 as in the original paper 
        self.loss = nn.MSELoss()

        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)

In [5]:
#Main

# Parameters
BATCH_SIZE = 4
NUM_WORKERS = 0 # on Windows, set this variable to 0
LEARNING_RATE = 1e-4
scale_factor = 4
nb_epochs = 15
cuda = True
device = torch.device("cuda:0" if (torch.cuda.is_available() and cuda) else "cpu")
torch.manual_seed(0)
torch.cuda.manual_seed(0)


trainset = DatasetFromFolder("data/train", scale_factor=scale_factor)
testset = DatasetFromFolder("data/test", scale_factor=scale_factor)

trainloader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
testloader = DataLoader(dataset=testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

generator = Generator(in_channels=3).to(device)                 #Working on R,G,B channels
discriminator = Discriminator(in_channels=3).to(device)         #Working on R,G,B channels

optimizer_generator = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))

mse = nn.MSELoss()
bce = nn.BCEWithLogitsLoss()
vgg_loss = VGGLoss()


hist_loss_train = []
hist_loss_test = []
hist_psnr_train = []
hist_psnr_test = []
for epoch in range(nb_epochs):
    # Train
    avg_psnr = 0
    epoch_loss = 0
    for idx, (low_res, high_res) in enumerate(tqdm(trainloader)):
        high_res = high_res.to(device)
        low_res = low_res.to(device)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z))) (bce-logit loss)
        fake_image_generator = generator(low_res)
        output_real_discriminator = discriminator(high_res)
        output_fake_discriminator = discriminator(fake_image_generator.detach())

        discriminator_loss_real = bce(output_real_discriminator, torch.ones_like(output_real_discriminator) - 0.1*torch.rand_like(output_real_discriminator))
        discriminator_loss_fake = bce(output_fake_discriminator, torch.zeros_like(output_fake_discriminator))
        discriminator_loss = discriminator_loss_real + discriminator_loss_fake

        optimizer_discriminator.zero_grad()
        discriminator_loss.backward()
        optimizer_discriminator.step()


        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output_fake_discriminator = discriminator(fake_image_generator)

        adversarial_loss = 1e-3 * bce(output_fake_discriminator, torch.ones_like(output_fake_discriminator))     #in paper also explor l2_loss = mse(fake_image_generator, high_res)
        loss_for_vgg = 0.006 * vgg_loss(fake_image_generator, high_res)
        generator_loss = adversarial_loss + loss_for_vgg

        optimizer_generator.zero_grad()
        generator_loss.backward()
        optimizer_generator.step()

        ##### loss & psnr update
        epoch_loss += generator_loss.item()
        avg_psnr += 10 * log10(1 / generator_loss.item())

    ### loss & psnr train and test
    hist_loss_train.append(epoch_loss / len(trainloader))
    hist_psnr_train.append(avg_psnr / len(trainloader))

    avg_psnr = 0
    epoch_loss = 0
    with torch.no_grad():
        for batch in testloader:
            input, target = batch[0].to(device), batch[1].to(device)

            fake_image_generator = generator(input)
            output_real_discriminator = discriminator(target)
            output_fake_discriminator = discriminator(fake_image_generator.detach())

            discriminator_loss_real = bce(output_real_discriminator, torch.ones_like(output_real_discriminator) - 0.1*torch.rand_like(output_real_discriminator))
            discriminator_loss_fake = bce(output_fake_discriminator, torch.zeros_like(output_fake_discriminator))
            discriminator_loss = discriminator_loss_real + discriminator_loss_fake

            output_fake_discriminator = discriminator(fake_image_generator)

            adversarial_loss = 1e-3 * bce(output_fake_discriminator, torch.ones_like(output_fake_discriminator))     #in paper also explor l2_loss = mse(fake_image_generator, target)
            loss_for_vgg = 0.006 * vgg_loss(fake_image_generator, target)
            generator_loss = adversarial_loss + loss_for_vgg

            epoch_loss += generator_loss.item()
            avg_psnr += 10 * log10(1 / generator_loss.item())

    print(f"Average PSNR: {avg_psnr / len(testloader)} dB.")
    hist_loss_test.append(epoch_loss / len(testloader))
    hist_psnr_test.append(avg_psnr / len(testloader))



100%|██████████| 63/63 [01:49<00:00,  1.73s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 21.192325027306953 dB.


100%|██████████| 63/63 [01:43<00:00,  1.65s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 24.29467137150114 dB.


100%|██████████| 63/63 [01:37<00:00,  1.55s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 23.139043161349484 dB.


100%|██████████| 63/63 [01:26<00:00,  1.37s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 23.17599808513599 dB.


100%|██████████| 63/63 [01:34<00:00,  1.50s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 20.86695291307029 dB.


100%|██████████| 63/63 [01:25<00:00,  1.36s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 23.53980119071621 dB.


100%|██████████| 63/63 [01:34<00:00,  1.50s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 23.81282126493524 dB.


100%|██████████| 63/63 [01:21<00:00,  1.29s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 21.70059560689062 dB.


100%|██████████| 63/63 [01:16<00:00,  1.21s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 23.565672676147912 dB.


100%|██████████| 63/63 [01:30<00:00,  1.43s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 21.865001353147797 dB.


100%|██████████| 63/63 [01:42<00:00,  1.63s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 24.979677315668404 dB.


100%|██████████| 63/63 [01:19<00:00,  1.26s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 24.95749785022448 dB.


100%|██████████| 63/63 [01:19<00:00,  1.26s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 23.183417542697068 dB.


100%|██████████| 63/63 [01:32<00:00,  1.47s/it]
  0%|          | 0/63 [00:00<?, ?it/s]

Average PSNR: 23.596424163593085 dB.


100%|██████████| 63/63 [01:23<00:00,  1.33s/it]


Average PSNR: 21.942243409689805 dB.


In [60]:
##DO NOT RUN CRASH
_, ax = plt.subplots(1,2)
ax[0].plot(hist_loss_train, label='train loss', c='b')
ax[0].plot(hist_loss_test, label='test loss', c='r')
ax[0].legend()
ax[1].plot(hist_psnr_train, label='train psnr', c='b', linestyle='--')
ax[1].plot(hist_psnr_test, label='test psnr', c='r', linestyle='--')
ax[1].legend()
_.set_size_inches(12,5)

In [11]:
k = np.random.randint(0,13)
for idx, (test_features, test_labels) in enumerate(testloader):
    if idx != k: continue;

    crop_size = CROP_SIZE - (CROP_SIZE % scale_factor) # Valid crop size
    crop = transforms.CenterCrop(crop_size)

    LR_original = crop(test_labels[0]).squeeze()
    LR = test_features[0].squeeze()     #Low Resolution (bicubiced) image
    GT = test_labels[0].squeeze()      #Ground Truth
    HR = (generator(test_features.to(device)).cpu().squeeze()[0].detach().numpy()*255.0).clip(0,255)        #High Resolution SRCNN image

    _, ax = plt.subplots(1,4, gridspec_kw={'width_ratios': [1, scale_factor, scale_factor, scale_factor]})
    ax[0].imshow(LR_original, cmap="gray")
    ax[1].imshow(LR, cmap="gray")
    ax[2].imshow(HR, cmap="gray")   
    ax[3].imshow(GT, cmap="gray")
    ax[0].title.set_text('LR image')
    ax[1].title.set_text('LR-bicubiced image')
    ax[2].title.set_text('HR image (SRCNN)')
    ax[3].title.set_text('HR image (ground truth)')

    _.set_size_inches(13,13)
    for i in range(4): ax[i].set_axis_off()
    plt.show()
    #print(f"Label: {label}")

: 

: 