In [None]:
# Mount Google Drive at /content/drive; the cells below read the dataset
# archives from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#unziping train dataest
!mkdir data
!mkdir data/output
!mkdir data/model
!mkdir /content/data/data
!unzip /content/drive/MyDrive/DIV2K_train_HR.zip -d /content/data/data

In [None]:
#unzipping test dataset
!unzip /content/drive/MyDrive/set5.zip -d /content/data/

In [None]:
!unzip /content/drive/MyDrive/set14.zip -d /content/data/

In [6]:
import PIL.Image as pil_image

# Set5 test-set preparation: for every s<i>.png write a bicubic-downscaled
# copy named s<i>lr<scale>.png into the same directory.

image_link = "/content/data/set5"

scale = 4
for i in range(1, 6):
    image_file = image_link + "/s" + str(i) + ".png"

    # The context manager closes the source file once the resized copy exists.
    with pil_image.open(image_file) as image:
        # Integer division drops the remainder when a side is not a multiple
        # of `scale`.
        low_res = image.resize((image.width // scale, image.height // scale),
                               resample=pil_image.BICUBIC)

    low_res.save(image_link + "/s" + str(i) + "lr" + str(scale) + ".png")

In [7]:
import PIL.Image as pil_image

# Set14 test-set preparation: for every ss<i>.png write a bicubic-downscaled
# copy named ss<i>lr<scale>.png into the same directory.

image_link = "/content/data/set14"

scale = 4
for i in range(1, 15):
    image_file = image_link + "/ss" + str(i) + ".png"

    # The context manager closes the source file once the resized copy exists.
    with pil_image.open(image_file) as image:
        # Integer division drops the remainder when a side is not a multiple
        # of `scale`.
        low_res = image.resize((image.width // scale, image.height // scale),
                               resample=pil_image.BICUBIC)

    low_res.save(image_link + "/ss" + str(i) + "lr" + str(scale) + ".png")

In [None]:
import PIL.Image as pil_image
import numpy as np
from torchvision import transforms

# Build the in-memory training set: for every training image draw several
# random HR patches and create the matching bicubic-downscaled LR patch.
import torch  # local import: only needed here to seed the random transforms

NUM_TRAIN_IMAGES = 800                          # files 0001.png .. 0800.png
PATCHES_PER_IMAGE = 10
HR_PATCH_SIZE = 96
PATCH_SCALE = 4
LR_PATCH_SIZE = HR_PATCH_SIZE // PATCH_SCALE    # 24
PATCH_SEED = 123

# torchvision's random transforms draw from torch's global RNG, so seeding it
# makes the sampled patches reproducible across runs.
torch.manual_seed(PATCH_SEED)

#Create Empty array
hrarr = []
lrarr = []
# Path prefix including the first (always zero) digit of the 4-digit file name.
image_link = "/content/data/data/DIV2K_train_HR/0"

#Random transform ( Rotation & Crop & Flip)
# NOTE(review): the rotation uses an arbitrary angle before cropping, so crops
# taken near the image border may contain fill pixels -- confirm this is intended.
crop_transform = transforms.Compose([
    transforms.RandomRotation(degrees=(0, 360)),
    transforms.RandomCrop((HR_PATCH_SIZE, HR_PATCH_SIZE)),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
])

for image_number in range(1, NUM_TRAIN_IMAGES + 1):
    # zfill(3) supplies the remaining three digits: 1 -> "001", 800 -> "800".
    image_path = image_link + str(image_number).zfill(3) + ".png"

    # Close each source file as soon as its patches have been extracted.
    with pil_image.open(image_path) as hrimage:
        for _ in range(PATCHES_PER_IMAGE):
            #Transform & Crop Image
            hrcropped = crop_transform(hrimage)

            #Create Low-Resolution Image
            lrcropped = hrcropped.resize((LR_PATCH_SIZE, LR_PATCH_SIZE),
                                         resample=pil_image.BICUBIC)

            #Append patches to array
            hrarr.append(np.array(hrcropped))
            lrarr.append(np.array(lrcropped))

print(len(hrarr))

In [None]:
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
from PIL import Image
from numpy import asarray
import random
import PIL
import albumentations as at
from torchvision.models.feature_extraction import create_feature_extractor
import torchvision.models as models


#Return psnr value between torch.tensor images
def psnr_between_rgb(img1,img2):
    """Return the PSNR in dB between the luma (Y) channels of two RGB tensors.

    Args:
        img1, img2: RGB tensors with values in [0, 1] and identical spatial
            size, each shaped (3, H, W) or (1, 3, H, W).

    Returns:
        A 0-dim tensor holding 10 * log10(1 / MSE) of the luma difference;
        ``inf`` when the two images are identical.
    """
    # Drop the batch axis of each argument independently so a 3-D / 4-D mix
    # of inputs is handled as well.
    if img1.dim() == 4:
        img1 = img1.squeeze(0)
    if img2.dim() == 4:
        img2 = img2.squeeze(0)

    def luma(img):
        # ITU-R BT.601 luma for RGB in [0, 1], scaled so the peak value is 1.
        # The red coefficient is 65.738 (the previous 64.738 was a typo). The
        # constant offset cancels in the difference and is kept small to
        # preserve float32 precision.
        return 16. / 255. + (65.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.

    mse = torch.mean((luma(img1) - luma(img2)) ** 2)
    psnr = 10. * torch.log10(1. / mse)
    return psnr


#Discriminator model
class Discriminator(nn.Module):
    """Convolutional real/fake classifier for RGB patches.

    Eight 3x3 convolution stages (layer1..layer8) are followed by a two-layer
    fully connected head ending in a sigmoid (layer9). Four of the stages use
    stride 2, so a 96x96 input reaches the head as a 512 x 6 x 6 feature map,
    which is the size the first linear layer is built for.
    """

    @staticmethod
    def _conv_bn_lrelu(in_channels, out_channels, stride):
        """Build a bias-free conv -> batch-norm -> LeakyReLU stage with a Kaiming-normal conv weight."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                         stride=stride, padding=1, bias=False)
        stage = nn.Sequential(
            conv,
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
        )
        nn.init.kaiming_normal_(conv.weight)
        return stage

    def __init__(self):
        super().__init__()

        # First stage: the only convolution with a bias and without batch-norm.
        first_conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.layer1 = nn.Sequential(first_conv, nn.LeakyReLU(0.2, True))
        nn.init.kaiming_normal_(first_conv.weight)

        # (in_channels, out_channels, stride) for layer2 .. layer8.
        stage_specs = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        for number, (c_in, c_out, stride) in enumerate(stage_specs, start=2):
            setattr(self, "layer{}".format(number),
                    self._conv_bn_lrelu(c_in, c_out, stride))

        # Classification head on the flattened 512 x 6 x 6 feature map.
        hidden = nn.Linear(512 * 6 * 6, 1024)
        logit = nn.Linear(1024, 1)
        self.layer9 = nn.Sequential(
            hidden,
            nn.LeakyReLU(0.2, True),
            logit,
            nn.Sigmoid(),
        )
        nn.init.kaiming_normal_(hidden.weight)
        nn.init.kaiming_normal_(logit.weight)

    def forward(self,x):
        """Return one sigmoid score per input image, shape (batch, 1)."""
        features = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4,
                      self.layer5, self.layer6, self.layer7, self.layer8):
            features = stage(features)
        return self.layer9(torch.flatten(features, 1))


#VGG-19 model
class vgg(nn.Module):
    """Content loss measured in VGG-19 feature space.

    Both images are normalised with the mean/std constants below, passed
    through a frozen, pretrained VGG-19 up to node "features.35", and the two
    feature maps are compared with MSE.
    """

    def __init__(self):
        super().__init__()

        backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        extractor = create_feature_extractor(backbone, ["features.35"])
        extractor.eval()
        self.feature_extractor = extractor

        self.normalize = transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])

        # The extractor is only a fixed feature function: freeze its weights.
        for weight in self.feature_extractor.parameters():
            weight.requires_grad = False
        self.content_criterion = torch.nn.MSELoss()

    def forward(self, sr, hr):
        """Return the MSE between the "features.35" activations of `sr` and `hr`."""
        node = "features.35"
        sr_features = self.feature_extractor(self.normalize(sr))[node]
        hr_features = self.feature_extractor(self.normalize(hr))[node]
        return self.content_criterion(sr_features, hr_features)


# Residual block for Generator
class Residual(nn.Module):
    """64-channel residual unit: conv -> BN -> PReLU -> conv -> BN, added to the input."""

    def __init__(self):
        super().__init__()
        # Both convolutions share the same 3x3, size-preserving, bias-free setup.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(64, 64, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(64, 64, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(64)

    def forward(self, x):
        """Return the transformed branch plus the unchanged input (same shape as `x`)."""
        branch = self.bn1(self.conv1(x))
        branch = self.bn2(self.conv2(self.prelu(branch)))
        return branch + x


# Generator model
class Generator(nn.Module):
    """Residual super-resolution generator with 4x upscaling.

    A 9x9 input convolution feeds 16 residual blocks; their output passes
    through a conv + batch-norm stage and is added back to the input features.
    Two conv + PixelShuffle(2) stages then double the resolution twice, and a
    final 9x9 convolution maps back to 3 channels.
    """

    @staticmethod
    def _upsample_stage():
        """Build one 2x upsampling stage: conv to 256 channels, PixelShuffle(2), PReLU."""
        return nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

    def __init__(self):
        super().__init__()

        # Shallow feature extraction.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU(),
        )

        self.residuals = self.stack_residual(Residual, 16)

        # Post-residual stage whose output is summed with the block1 features.
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
        )

        self.block4 = self._upsample_stage()
        self.block5 = self._upsample_stage()

        self.block6 = nn.Conv2d(64, 3, kernel_size=9, padding=4)

        self.initialize_weights()

    def stack_residual(self,block,num_layer):
        """Return an nn.Sequential of `num_layer` fresh `block()` instances."""
        return nn.Sequential(*(block() for _ in range(num_layer)))

    def initialize_weights(self):
        """Kaiming-normal every conv weight, zero conv biases, set batch-norm scales to 1."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                continue
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self,x):
        """Map a low-resolution batch to its 4x upscaled RGB batch."""
        shallow = self.block1(x)
        deep = self.block3(self.residuals(shallow))
        upscaled = self.block5(self.block4(torch.add(deep, shallow)))
        return self.block6(upscaled)



####################################################################################
######## Mode
###      Train new SRResNet      = 1   (generator only, MSE loss, fresh weights)
###   Train existing SRResNet    = 2   (same training, starting from the saved weights)
###         Train SRGAN          = 3   (generator + discriminator, starting from the saved weights)
###        Download output       = 4   (run the saved generator on test_image_link and save the result)



# Selects which branch of the train / inference code below is executed.
mode = 1



##########  MODEL
# Checkpoint files. Modes 2, 3 and 4 load both before running; both are
# written again at the end of the cell.
model_parameters_path = "/content/data/model/model.pth"
# NOTE(review): there is no "." before "pth" -- possibly meant "d_model.pth".
# The name is used consistently for saving and loading, so it works as is;
# renaming it would orphan checkpoints already saved under this name.
d_model_parameters_path = "/content/data/model/d_modelpth"

######### PARAMETER
# Learning rate shared by all Adam optimizers created below.
learning_rate = 0.0001
# Number of passes over the training patches.
training_epochs = 1
# DataLoader batch size; also the row count of the real/fake label tensors in mode 3.
batchsize = 16
# Seed passed to torch.manual_seed.
userseed = 123

######### FILE ROUTES
# NOTE(review): sub_image_size, scale and image_link are not referenced by the
# code in the rest of this cell.
sub_image_size = 96
scale = 4
image_link = "/content/data/data/DIV2K_train_HR/0"
# Low-resolution input image for mode 4 and the path its output is written to.
test_image_link = "/content/data/set14/ss11lr4.png"
test_image_output_link = "/content/data/output/output11.png"

#####################################################################################


# Select the compute device and seed torch's RNG.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(userseed)

# Compare device.type: the selected device is 'cuda:0', which never compares
# equal to the bare string 'cuda', so the previous check could not be True.
if device.type == 'cuda':
    torch.cuda.manual_seed_all(userseed)

#models
model = Generator().to(device)
d_model = Discriminator().to(device)
vgg_model = vgg().to(device)

#loading model state
# map_location lets a checkpoint saved on GPU be restored on a CPU-only
# runtime (and vice versa).
if mode in (2, 3, 4):
    model.load_state_dict(torch.load(model_parameters_path, map_location=device))
    d_model.load_state_dict(torch.load(d_model_parameters_path, map_location=device))


# Pixel-wise loss for SRResNet training and binary cross-entropy for the
# adversarial terms.
criterion = nn.MSELoss()
d_criterion = nn.BCELoss()

# NOTE: the training loops step the generator through `optimizer` in every
# mode; `g_optimizer` is created but not used by the code in this cell.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
g_optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
d_optimizer = torch.optim.Adam(d_model.parameters(), lr = learning_rate)

class TrainDataset(Dataset):
    """Paired low-/high-resolution training patches.

    Args:
        lr_patches: optional sequence of H x W x 3 low-resolution patches.
            When omitted, the notebook-level list ``lrarr`` is used.
        hr_patches: optional sequence of H x W x 3 high-resolution patches.
            When omitted, the notebook-level list ``hrarr`` is used.

    Each item is a ``(lr, hr)`` pair of float32 arrays in channel-first
    (3, H, W) layout, obtained by dividing the stored values by 255.
    """

    def __init__(self, lr_patches=None, hr_patches=None):
        super(TrainDataset,self).__init__()
        self._lr_patches = lr_patches
        self._hr_patches = hr_patches

    def _lr_source(self):
        # The global is looked up on access, so `TrainDataset()` keeps
        # following whatever `lrarr` currently refers to.
        return lrarr if self._lr_patches is None else self._lr_patches

    def _hr_source(self):
        return hrarr if self._hr_patches is None else self._hr_patches

    @staticmethod
    def _to_chw_float(patch):
        """Convert one H x W x C patch to float32 C x H x W, divided by 255."""
        return np.asarray(patch).transpose((2,0,1)).astype(np.float32)/255.

    def __getitem__(self,idx):
        return self._to_chw_float(self._lr_source()[idx]), self._to_chw_float(self._hr_source()[idx])

    def __len__(self):
        return len(self._hr_source())

# Shuffled mini-batch loader over the in-memory patch dataset.
train_dataset = TrainDataset()

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batchsize,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    # Every batch must be full: the SRGAN loop builds its real/fake label
    # tensors with exactly `batchsize` rows.
    drop_last=True,
)

#Shared image loader for the evaluation and inference branches
def _load_rgb_tensor(path):
    """Read the image at `path` as a (1, 3, H, W) float32 tensor in [0, 1] on `device`."""
    with Image.open(path) as img:
        # convert("RGB") keeps grayscale / RGBA files usable: the generator
        # expects exactly three input channels.
        pixels = np.array(img.convert("RGB")).astype(np.float32)
    pixels = pixels.transpose((2,0,1))
    pixels /= 255.
    return torch.from_numpy(pixels).to(device).unsqueeze(0)


#train SRResnet
if mode ==1 or mode==2:

    num_batches = len(train_dataloader)

    for epoch in range(training_epochs):
        model.train()
        avg_cost = 0.0
        for inputs, labels in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            preds = model(inputs)

            loss = criterion(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # MSELoss already averages within a batch, so the epoch mean is
            # the sum over batches divided by the number of batches.
            # .item() keeps the running value a plain float instead of a
            # growing chain of autograd tensors.
            avg_cost += loss.item() / num_batches

        print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

        # Evaluate on the five Set5 images after every epoch.
        model.eval()
        sum_psnr = 0.0

        for i in range(1,6):
            setlrlr = _load_rgb_tensor("/content/data/set5/s" + str(i) + "lr4.png")
            sethrhr = _load_rgb_tensor("/content/data/set5/s"+str(i)+".png")

            with torch.no_grad():
                setpreds = model(setlrlr).clamp(0.0,1.0)

            sum_psnr += psnr_between_rgb(sethrhr,setpreds).item()

        print('PSNR: {:.2f}'.format(sum_psnr/5))

#Train SRGAN
elif mode ==3:
    #Labels
    # Constant for the whole run; every batch is full because the loader
    # drops the last incomplete batch.
    real_label = torch.full([batchsize,1], 1.0).to(device)
    fake_label = torch.full([batchsize,1], 0.0).to(device)
    num_batches = len(train_dataloader)

    for epoch in range(training_epochs):
        model.train()
        d_model.train()
        avg_d_cost = 0.0
        avg_g_cost = 0.0

        for inputs, labels in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            #### Generator
            # Freeze the discriminator so the adversarial loss only produces
            # gradients for the generator.
            for d_parameters in d_model.parameters():
                d_parameters.requires_grad = False

            optimizer.zero_grad()

            sr = model(inputs)
            ad_loss = d_criterion(d_model(sr), real_label)

            vgg_loss = vgg_model(sr,labels)
            perceptual_loss = 0.06*vgg_loss + 0.001*ad_loss
            perceptual_loss.backward()
            optimizer.step()

            for d_parameters in d_model.parameters():
                d_parameters.requires_grad = True

            ### Discriminator
            d_optimizer.zero_grad()

            d_real_output = d_model(labels)
            d_real_loss = d_criterion(d_real_output,real_label)

            # Fresh SR batch from the just-updated generator. detach() stops
            # d_loss.backward() from also traversing the generator, whose
            # gradients are not used in this step.
            sr = model(inputs).detach()
            d_fake_output = d_model(sr)
            d_fake_loss = d_criterion(d_fake_output, fake_label)

            d_loss = (d_real_loss + d_fake_loss) / 2.0
            d_loss.backward()
            d_optimizer.step()

            # Per-epoch means over batches, accumulated as plain floats.
            avg_d_cost += d_loss.item() / num_batches
            avg_g_cost += perceptual_loss.item() / num_batches

        print('[Epoch: {:>4}] d_cost = {:>.9} g_cost = {:>.9}'.format(epoch + 1, avg_d_cost, avg_g_cost))

        model.eval()

##generate test image
elif mode==4:

    model.eval()

    setlrlr = _load_rgb_tensor(test_image_link)
    with torch.no_grad():
        setpreds = model(setlrlr).clamp(0.0,1.0)

    #save part
    # (1, 3, H, W) in [0, 1]  ->  (H, W, 3) uint8
    preds2 = setpreds.mul(255.0).cpu().numpy()
    preds2 = preds2.squeeze(0)
    preds2 = preds2.transpose((1,2,0))
    preds2 = np.clip(preds2, 0.0,255.0).astype(np.uint8)
    # Use the PIL `Image` module imported in this cell instead of the
    # `pil_image` alias that only exists if an earlier cell was run.
    output = Image.fromarray(preds2)
    output.save(test_image_output_link)


# Report completion, then write the generator and discriminator weights to
# their checkpoint paths. This runs for every mode, including mode 4
# (inference only), where it re-writes the weights that were loaded earlier.
print("done")
torch.save(model.state_dict(), model_parameters_path)
torch.save(d_model.state_dict(), d_model_parameters_path)