In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.transforms import InterpolationMode
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import PIL
import ignite.metrics
import numpy as np





In [None]:
import pandas as pd

# list.txt is space separated. Rows whose field count does not match the header
# are skipped (presumably the free-text comment lines at the top of the file).
annotations_file = pd.read_csv(
    "../input/the-oxfordiiit-pet-dataset/annotations/annotations/list.txt",
    on_bad_lines='skip',
    sep=' ',
)

In [None]:
# Inspect the first rows of the raw table before any cleaning.
annotations_file.head(n=20)

In [None]:
# Drop the first three rows (index labels 0, 1, 2). These are presumably comment
# lines from list.txt that read_csv did not skip; verify with the preview above.
# errors="ignore" makes the cell idempotent: re-running it is a no-op instead of
# raising KeyError for labels that are already gone.
annotations_file = annotations_file.drop(labels=range(3), axis=0, errors="ignore")
annotations_file.head(20)

In [None]:
# Number of annotated images left after cleaning.
annotations_file.shape[0]

In [None]:
# First image stem; CustomDataset appends ".jpg" to this to locate the file.
annotations_file.iat[0, 0]


In [None]:
# Relative directory containing the Oxford-IIIT Pet .jpg images.
dir_path = "../input/the-oxfordiiit-pet-dataset/images/images/"


In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Dataset yielding the same image at two resolutions plus its class label.

    Each item is ``((high_res_image, low_res_image), label)``. Both images are
    produced from one source file by ``transform[0]`` and ``transform[1]``.

    Args:
        annotations_file: DataFrame whose column 0 holds the image file stem
            (name without ``.jpg``) and column 1 an integer-convertible label.
        img_dir: directory containing the ``.jpg`` files.
        transform: optional pair ``(high_res_transform, low_res_transform)``.
            When omitted or empty, both elements of the pair are the
            untransformed RGB PIL image.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        self.annotations = annotations_file
        self.img_dir = img_dir
        self.transforms = transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Load image ``index`` and return ``((high_res, low_res), label)``."""
        import os

        # Import the submodule explicitly: a bare ``import PIL`` does not
        # guarantee ``PIL.Image`` is loaded.
        from PIL import Image

        img_path = os.path.join(self.img_dir, f"{self.annotations.iloc[index, 0]}.jpg")
        # convert() returns a new, fully loaded image, so the source file can be
        # closed straight away instead of leaking one handle per sample.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        label = torch.tensor(int(self.annotations.iloc[index, 1]))

        if self.transforms:
            high_res_image = self.transforms[0](img)
            low_res_image = self.transforms[1](img)
        else:
            # Without transforms, return the raw image for both resolutions
            # instead of failing on unbound locals.
            high_res_image = low_res_image = img

        return (high_res_image, low_res_image), label

In [None]:
class res_block(nn.Module):
    """Conv-BN-PReLU-Conv-BN body of a residual block.

    The skip connection is not added here; the caller adds the input to the
    returned tensor.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel_size: square kernel size. Padding of ``kernel_size // 2``
            preserves spatial size for odd kernels at stride 1.
        stride: stride of the first convolution only.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        padding = kernel_size // 2
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.PReLU(out_channels),
            # The second conv consumes the first conv's output, so its input
            # width must be out_channels. It does not stride a second time.
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1,
                      padding=padding),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Apply the residual body (without the skip addition)."""
        return self.block(x)
                                   
        

In [None]:
class p_block(nn.Module):
    """Sub-pixel upsampling stage: Conv2d -> PixelShuffle -> PReLU.

    The convolution emits ``out_channels`` maps. PixelShuffle rearranges them
    into ``out_channels // scale**2`` channels at ``scale`` times the height
    and width, so the PReLU holds one slope per post-shuffle channel.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, scale):
        super().__init__()
        shuffled_channels = out_channels // (scale ** 2)
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
        )
        self.block = nn.Sequential(conv, nn.PixelShuffle(scale), nn.PReLU(shuffled_channels))

    def forward(self, x):
        """Upsample ``x`` by the configured pixel-shuffle factor."""
        upsampled = self.block(x)
        return upsampled

In [None]:
class SRResNet(pl.LightningModule):
    """SRResNet-style super-resolution network wrapped as a LightningModule.

    Pipeline:
      1. 9x9 conv + PReLU.
      2. ``depth`` residual blocks, each with its own skip.
      3. 3x3 conv + BatchNorm, with a long skip back to the first conv.
      4. Two pixel-shuffle upsampling stages.
      5. 9x9 conv projecting to ``img_channels``.

    Each upsampling stage multiplies height and width by ``shuffle_scale``, so
    the total upscale is ``shuffle_scale ** 2`` (4x for ``shuffle_scale=2``).

    Batches are ``((high_res_image, low_res_image), label)``. The label is
    ignored and the model learns to map the low-res image to the high-res one.

    Args:
        lr: Adam learning rate.
        img_channels: channels of the input and output images.
        img_size: input side length. It is stored as a hyperparameter only;
            the layers themselves do not depend on it.
        depth: number of residual blocks.
        shuffle_scale: upscale factor of each pixel-shuffle stage.
        loss_type: ``"MSE"`` (default), ``"MAE"`` or ``"PSNR"`` (negated PSNR).

    Raises:
        ValueError: if ``loss_type`` is not a supported name.
    """

    LOSS_TYPES = ("MSE", "MAE", "PSNR")

    def __init__(self, lr, img_channels, img_size, depth, shuffle_scale, loss_type='MSE'):
        super().__init__()
        # Fail fast: an unknown loss_type previously produced a None loss that
        # silently skipped every optimization step.
        if loss_type not in self.LOSS_TYPES:
            raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}, got {loss_type!r}")
        self.save_hyperparameters()

        self.conv1 = nn.Sequential(nn.Conv2d(img_channels, 64, kernel_size=9, stride=1, padding=4),
                                   nn.PReLU(64))

        self.residual_blocks = nn.ModuleList([res_block(64, 64, 3, 1) for _ in range(depth)])

        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                                   nn.BatchNorm2d(64))

        # Each stage convolves to 64 * s^2 channels, then PixelShuffle(s) folds
        # them back to 64 channels at s times the resolution.
        # shuffle_scale=2 gives exactly the former hard-coded 256-channel layers.
        upscale_channels = 64 * shuffle_scale ** 2
        self.pshuffle = nn.ModuleList([
            p_block(64, upscale_channels, 3, 1, shuffle_scale),
            p_block(64, upscale_channels, 3, 1, shuffle_scale),
        ])

        # Project back to image space with one output map per image channel.
        self.conv3 = nn.Conv2d(64, img_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Super-resolve ``x``.

        Input is ``(N, img_channels, H, W)``. Output is
        ``(N, img_channels, H * shuffle_scale**2, W * shuffle_scale**2)``.
        Every layer already produces that shape, so no reshape is needed
        (a ``view`` would scramble pixels if the sizes ever disagreed).
        """
        x = self.conv1(x)
        conv1_output = x
        for block in self.residual_blocks:
            x = block(x) + x

        x = self.conv2(x) + conv1_output

        for block in self.pshuffle:
            x = block(x)

        return self.conv3(x)

    def PSNR(self, img1, img2, max_val=1.0):
        """Return the negated PSNR (in dB) between ``img1`` and ``img2``.

        The sign is flipped so that minimizing this value maximizes PSNR.
        ``max_val`` is the peak pixel value. ``ToTensor`` images lie in [0, 1],
        hence the 1.0 default; a 255 peak would offset the value by a constant
        of about 48 dB without changing gradients.
        """
        mse = torch.mean((img1 - img2) ** 2)
        return -20 * torch.log10(max_val / torch.sqrt(mse))

    def loss_fn(self, img1, img2):
        """Reconstruction loss selected by ``hparams.loss_type``.

        Raises:
            ValueError: for an unsupported ``loss_type``.
        """
        loss_type = self.hparams.loss_type
        if loss_type == "PSNR":
            return self.PSNR(img1, img2)
        if loss_type == "MSE":
            return nn.functional.mse_loss(img1, img2)
        if loss_type == "MAE":
            return nn.functional.l1_loss(img1, img2)
        raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}, got {loss_type!r}")

    def _shared_step(self, batch):
        """Super-resolve the low-res half of ``batch`` and score it against the high-res half."""
        (high_res_image, low_res_image), _ = batch
        x_hat = self(low_res_image)
        return self.loss_fn(x_hat, high_res_image)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self._shared_step(batch))

    def test_step(self, batch, batch_idx):
        self.log("test_loss", self._shared_step(batch))

    def configure_optimizers(self):
        """Adam over all parameters at ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    

In [None]:

    # NOTE: this cell body is indented. IPython strips the common leading
    # indent, but the same text would raise IndentationError in a plain .py script.

    # High-res targets are 320x320 and low-res inputs 80x80, both resized from
    # the same source image, which gives a 4x super-resolution pair.
    resize_transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
    ])

    transform = transforms.Compose([
        transforms.Resize((80, 80)),
        transforms.ToTensor(),
    ])

    dataset = CustomDataset(annotations_file=annotations_file,
                            img_dir=dir_path,
                            transform=[resize_transform, transform])  # (high-res, low-res)

    # 70 / 15 / 15 train/val/test split with a fixed generator seed so the split
    # is reproducible across runs.
    dataset_len = len(dataset)
    train_set_len = int(0.7 * dataset_len)
    val_set_len = int(0.15 * dataset_len)
    test_set_len = dataset_len - train_set_len - val_set_len

    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset,
        [train_set_len, val_set_len, test_set_len],
        generator=torch.Generator().manual_seed(43),
    )

    # Reshuffle the training set every epoch. Val/test keep a fixed order so
    # their losses and the preview batches drawn from test_loader stay comparable.
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=16)
    test_loader = DataLoader(test_set, batch_size=16)


    



In [None]:
# Take the first three high-resolution test images as the reference grid for
# the comparison plots further down.
(x1, x2), _ = next(iter(test_loader))
x1 = x1[0:3]
high_res_imgs = torchvision.utils.make_grid(x1)

In [None]:
# Seed the python, numpy and torch RNGs so weight initialization and the
# per-epoch training shuffle are reproducible. 43 matches the split seed.
pl.seed_everything(43)

# Stop once val_loss has not improved for 3 consecutive validation runs.
early_stopping = pl.callbacks.EarlyStopping(monitor='val_loss',
                                            patience=3,
                                            mode='min')

# Keep the checkpoint with the lowest val_loss; it is reloaded for evaluation below.
checkpoint = pl.callbacks.ModelCheckpoint(dirpath='saved_ckpts',
                                          monitor='val_loss',
                                          mode='min')

model = SRResNet(lr=1e-4,
                 img_channels=3,
                 img_size=80,
                 depth=12,
                 shuffle_scale=2,
                 loss_type="MAE")

# One GPU, 16-bit mixed precision, at most 10 epochs.
trainer = pl.Trainer(gpus=1,
                     precision=16,
                     callbacks=[early_stopping, checkpoint],
                     max_epochs=10)

trainer.fit(model, train_loader, val_loader)

In [None]:
# Reload the weights with the lowest val_loss and switch to eval mode. In train
# mode, BatchNorm normalizes with batch statistics and updates its running
# averages on every forward pass, even under torch.no_grad().
model = SRResNet.load_from_checkpoint(checkpoint.best_model_path)
model.eval()

# Super-resolve the first three low-res test images. test_loader is not
# shuffled, so these match the high-res preview grid.
(x1, x2), _ = next(iter(test_loader))
x2 = x2[0:3]
with torch.no_grad():
    output_imgs = model(x2)
# The network output is unbounded; clamp to the [0, 1] range imshow expects for floats.
output_imgs = torchvision.utils.make_grid(output_imgs.clamp(0, 1))

trainer.test(model=model, dataloaders=test_loader)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# make_grid returns (C, H, W); imshow expects (H, W, C).
ax.imshow(high_res_imgs.permute(1, 2, 0))
ax.set_title("High-resolution test images (ground truth)")
ax.axis("off");


In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# make_grid returns (C, H, W); imshow expects (H, W, C). Clamp in case the grid
# holds raw, unbounded network outputs.
ax.imshow(output_imgs.clamp(0, 1).permute(1, 2, 0))
ax.set_title("SRResNet super-resolved test images")
ax.axis("off");