In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.transforms import InterpolationMode
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import PIL

In [2]:
import pandas as pd
# Load the annotation list as a space-separated table; rows that pandas
# cannot parse are skipped rather than raising.
annotations_file = pd.read_csv(
    "../input/the-oxfordiiit-pet-dataset/annotations/annotations/list.txt",
    sep=' ',
    on_bad_lines='skip',
)

In [3]:
# Preview the first 20 rows of the freshly loaded table.
annotations_file.head(n=20)

In [4]:
# Remove the rows labelled 0, 1 and 2 (presumably leftover header/comment
# lines from list.txt -- TODO confirm against the preview of the raw table).
# errors='ignore' keeps this cell safe to re-run: once those labels are gone a
# second execution is a no-op instead of raising KeyError.
annotations_file = annotations_file.drop(labels=range(3), axis=0, errors='ignore')
annotations_file.head(20)

In [5]:
# Peek at the first remaining entry of column 0 (the value later used as the
# image file stem by CustomDataset).
annotations_file.iloc[0, 0]


In [6]:
# Directory that holds the .jpg files referenced by the annotation table.
dir_path = "../input/the-oxfordiiit-pet-dataset/images/images/"


In [7]:
class CustomDataset(torch.utils.data.Dataset):
    """Dataset yielding ``((high_res, low_res), label)`` for one image.

    Args:
        annotations_file: table-like object supporting ``len()`` and
            ``.iloc[row, col]``. Column 0 supplies the image file stem and
            column 1 is converted to an ``int`` label.
        img_dir: directory containing ``<stem>.jpg`` files.
        transform: optional pair ``[high_res_transform, low_res_transform]``.
            Element 0 produces the first image of the pair and element 1 the
            second. When omitted (or empty) the RGB PIL image is returned
            untransformed in both slots.
    """

    def __init__(self,annotations_file,img_dir,transform = None):
        self.annotations = annotations_file
        self.img_dir = img_dir
        self.transforms = transform

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.annotations)

    def __getitem__(self,index):
        """Load image ``index`` and return ``((high_res, low_res), label)``."""
        img_path = f"{self.img_dir}/{self.annotations.iloc[index,0]}.jpg"
        img = PIL.Image.open(img_path).convert('RGB')

        label = torch.tensor(int(self.annotations.iloc[index,1]))

        # Without transforms there is nothing to derive the pair from; hand
        # back the raw image twice instead of hitting unbound local variables.
        if not self.transforms:
            return (img,img),label

        high_res_image = self.transforms[0](img)
        low_res_image = self.transforms[1](img)
        return (high_res_image,low_res_image),label

In [8]:

# High-resolution target: resize straight to 320x320 with bilinear
# interpolation, then convert to a tensor.
resize_transform = transforms.Compose(
    [
        transforms.Resize((320, 320), interpolation=InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ]
)

# Degraded input: shrink to 80x80 and scale back up to 320x320 (both with the
# library-default interpolation) so it matches the target's size.
transform = transforms.Compose(
    [
        transforms.Resize((80, 80)),
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
    ]
)

# Transform order matters: index 0 -> high-res image, index 1 -> low-res image.
dataset = CustomDataset(
    annotations_file=annotations_file,
    img_dir=dir_path,
    transform=[resize_transform, transform],
)

In [9]:
# 90 % / 5 % / remainder split of the dataset for a quick visual sanity check.
dataset_len = len(dataset)
train_set_len = int(0.9*dataset_len)
val_set_len = int(0.05*dataset_len)
test_set_len  = dataset_len - train_set_len - val_set_len

# Seed the split so re-running the notebook selects the same samples (the
# preview cells index into train_set by a fixed position); this mirrors the
# seeded split used in get_dataloader.
train_set,val_set,test_set = torch.utils.data.random_split(
    dataset,
    [train_set_len,val_set_len,test_set_len],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_set,batch_size = 16)
val_loader = DataLoader(val_set,batch_size = 16)
test_loader = DataLoader(test_set,batch_size = 16)


In [10]:
# High-res member ([0][0]) of training sample 6102; ToTensor output is
# channel-first, so move channels last for matplotlib.
plt.imshow(train_set[6102][0][0].permute(1, 2, 0))

In [11]:
# Low-res member ([0][1]) of the same training sample, again converted to
# channel-last layout for matplotlib.
plt.imshow(train_set[6102][0][1].permute(1, 2, 0))

In [12]:
# Three TensorBoard loggers, each writing under its own SRCNN/ sub-directory.
tb_high_res, tb_low_res, tb_resolved = (
    pl.loggers.TensorBoardLogger(log_dir)
    for log_dir in ("SRCNN/high_res_imgs", "SRCNN/input_imgs", "SRCNN/output_imgs")
)
# Step counter, initialised to zero.
step = 0

In [13]:
class SRCNN(pl.LightningModule):
    """Three-layer convolutional network trained to map a low-res image to its
    high-res counterpart with a mean-squared-error objective.

    Every convolution uses stride 1 with padding equal to half the kernel
    size, so the output has the same spatial size as the input.

    Args:
        lr: learning rate handed to the Adam optimizer. It is stored through
            ``save_hyperparameters`` and read back as ``self.hparams.lr``.
    """

    def __init__(self,lr):
        super().__init__()
        self.save_hyperparameters()
        self.srcnn = nn.Sequential(
            nn.Conv2d(3,128,kernel_size = 9, stride = 1,padding = 4),nn.ReLU(),
            nn.Conv2d(128,64,kernel_size = 5, stride = 1,padding = 2),nn.ReLU(),
            nn.Conv2d(64,3,kernel_size = 5, stride = 1,padding = 2),nn.ReLU()
        )

    def forward(self,input_imgs):
        """Run the convolution stack on ``input_imgs`` and return the result."""
        return self.srcnn(input_imgs)

    def _reconstruction_loss(self,batch):
        """Return the mean-squared error between the network output for the
        low-res images of ``batch`` and the matching high-res images."""
        (high_res_image,low_res_image),_ = batch
        y_hat = self(low_res_image)
        return nn.functional.mse_loss(y_hat,high_res_image)

    def training_step(self,batch,batch_idx):
        """Compute, log (``train_loss``) and return the loss for one batch."""
        loss = self._reconstruction_loss(batch)
        self.log("train_loss",loss)
        return loss

    def validation_step(self,batch,batch_idx):
        """Log the batch loss as ``val_loss``."""
        self.log("val_loss",self._reconstruction_loss(batch))

    def test_step(self,batch,batch_idx):
        """Log the batch loss as ``test_loss``."""
        self.log("test_loss",self._reconstruction_loss(batch))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the stored lr."""
        optimizer = torch.optim.Adam(self.parameters(),lr = self.hparams.lr)
        return optimizer
    

In [14]:
def get_dataloader(interpolation_type = InterpolationMode.BILINEAR,
                   batch_size = 16,
                   train_frac = 0.7,
                   val_frac = 0.15,
                   seed = 42):
    """Build train/val/test loaders whose resizing uses ``interpolation_type``.

    The high-res image is resized directly to 320x320. The low-res image is
    first reduced to 80x80 and then scaled back up to 320x320. All resizes use
    the same interpolation mode. The dataset is built from the notebook-level
    ``annotations_file`` and ``dir_path``.

    Args:
        interpolation_type: torchvision ``InterpolationMode`` for every resize.
        batch_size: batch size shared by the three loaders.
        train_frac: fraction of samples assigned to the training split.
        val_frac: fraction assigned to the validation split; the test split
            receives whatever remains.
        seed: seed for the split generator, so every call with the same seed
            and dataset length yields the same partition.

    Returns:
        ``(train_loader, val_loader, test_loader)``. None of the loaders
        shuffles, so sample order is fixed by the seeded split.

    Raises:
        ValueError: if a fraction is negative or the two fractions exceed 1.
    """
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError(
            "train_frac and val_frac must be non-negative and sum to at most 1"
        )

    resize_transform = transforms.Compose([
        transforms.Resize((320,320),interpolation = interpolation_type),
        transforms.ToTensor(),
    ])

    transform = transforms.Compose([
        transforms.Resize((80,80),interpolation = interpolation_type),
        transforms.Resize((320,320),interpolation = interpolation_type),
        transforms.ToTensor(),
    ])

    # Transform order matters: index 0 -> high-res image, index 1 -> low-res.
    dataset = CustomDataset(annotations_file = annotations_file,
                            img_dir = dir_path,
                            transform = [resize_transform,transform])

    dataset_len = len(dataset)
    train_set_len = int(train_frac*dataset_len)
    val_set_len = int(val_frac*dataset_len)
    # The remainder goes to the test split so the three lengths always add up.
    test_set_len = dataset_len - train_set_len - val_set_len

    train_set,val_set,test_set = torch.utils.data.random_split(
        dataset,
        [train_set_len,val_set_len,test_set_len],
        generator=torch.Generator().manual_seed(seed),
    )

    train_loader = DataLoader(train_set,batch_size = batch_size)
    val_loader = DataLoader(val_set,batch_size = batch_size)
    test_loader = DataLoader(test_set,batch_size = batch_size)

    return train_loader,val_loader,test_loader
    



In [15]:
# Image grids collected for the final side-by-side comparison, keyed by name.
img_dict = dict()

In [16]:
train_loader,val_loader,test_loader = get_dataloader(interpolation_type = InterpolationMode.BICUBIC)
# Only the first test batch is needed: keep its first three high-res images as
# the reference grid. The slice is already a 4-D batch, which make_grid
# accepts directly, so the former discarded unsqueeze(0) call is dropped.
# Labels are bound to a named variable so IPython's `_` is left untouched.
for (high_res_batch,_low_res_batch),_labels in test_loader:
    img_dict['high_res_img'] = torchvision.utils.make_grid(high_res_batch[0:3])
    break

In [17]:
def get_test_loss(train_loader,val_loader,test_loader,interpolation_type):
    """Train an SRCNN on the given loaders and evaluate it on the test loader.

    Side effect: the network output for the first three low-res images of the
    first test batch is stored as an image grid in the notebook-level
    ``img_dict`` under the key ``f'{interpolation_type}'``.

    Args:
        train_loader: loader passed to ``trainer.fit`` for training.
        val_loader: loader passed to ``trainer.fit`` for validation; its
            ``val_loss`` drives early stopping and checkpoint selection.
        test_loader: loader used for the preview grid and ``trainer.test``.
        interpolation_type: value interpolated into the checkpoint directory,
            logger directories and ``img_dict`` key.

    Returns:
        The value returned by ``trainer.test`` for the best checkpoint.
    """
    early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_loss',
                                                patience = 3,
                                                mode = 'min')

    checkpoint = pl.callbacks.ModelCheckpoint(dirpath = f'saved_ckpts_{interpolation_type}',
                                              monitor = 'val_loss')

    tb_high_res = pl.loggers.TensorBoardLogger(f"SRCNN/high_res_imgs_{interpolation_type}")
    tb_resolved = pl.loggers.TensorBoardLogger(f"SRCNN/output_imgs_{interpolation_type}")

    model = SRCNN(8.7e-5)
    # NOTE(review): `gpus=` is the Lightning 1.x spelling; newer releases are
    # believed to expect accelerator="gpu", devices=1 -- confirm against the
    # installed version before changing it.
    trainer = pl.Trainer(gpus = 1,
                         precision = 16,
                         callbacks = [early_stopping,checkpoint],
                         max_epochs = 1,
                         logger = [tb_high_res,tb_resolved])

    trainer.fit(model,train_loader,val_loader)
    # Continue with the weights of the best checkpoint found by `checkpoint`.
    model = SRCNN.load_from_checkpoint(checkpoint.best_model_path)

    with torch.no_grad():
        # Only the first batch is needed for the preview grid. The slice is
        # already a 4-D batch, so no extra unsqueeze is required.
        for (_high_res,low_res),_labels in test_loader:
            output_imgs = model(low_res[0:3])
            img_dict[f'{interpolation_type}'] = torchvision.utils.make_grid(output_imgs)
            break

    # `model` already carries the best-checkpoint weights, so no ckpt_path is
    # passed (the original call had a dangling `ckpt_path = ` syntax error).
    test_dict = trainer.test(model = model,dataloaders = test_loader)
    return test_dict

In [None]:
train_loader,val_loader,test_loader = get_dataloader(interpolation_type = InterpolationMode.BILINEAR)
# Show the first three low-res inputs of the first test batch. The slice is
# already a 4-D batch, so make_grid takes it as-is and the former discarded
# unsqueeze(0) call is dropped. Labels are bound to a named variable so
# IPython's `_` is left untouched.
for (_high_res_batch,low_res_batch),_labels in test_loader:
    grid = torchvision.utils.make_grid(low_res_batch[0:3])
    plt.figure(figsize = (20,20))
    # make_grid output is channel-first; matplotlib wants channels last.
    plt.imshow(grid.permute(1,2,0))
    break

In [18]:
def compare():
    """Train and test one model per interpolation mode.

    For BILINEAR, NEAREST and BICUBIC (in that order) the loaders are built
    with ``get_dataloader`` and handed to ``get_test_loss``.

    Returns:
        dict mapping ``f'{interpolation}'`` to the ``'test_loss'`` entry of
        the first result returned by ``get_test_loss``.
    """
    interpolations = [InterpolationMode.BILINEAR,
                      InterpolationMode.NEAREST,
                      InterpolationMode.BICUBIC]
    dict_test_loss = {}

    for interpolation in interpolations:
        train_loader,val_loader,test_loader = get_dataloader(interpolation_type = interpolation)
        test_results = get_test_loss(train_loader,val_loader,test_loader,interpolation)
        dict_test_loss[f'{interpolation}'] = test_results[0]['test_loss']

    return dict_test_loss
        


In [19]:
# Run the full comparison; the result maps each interpolation mode's name to
# its test loss.
loss_dict = compare()

In [21]:
# Reference grid of high-res images (channels moved last for matplotlib).
plt.figure(figsize=(10, 10))
plt.imshow(img_dict["high_res_img"].permute(1, 2, 0))


In [22]:
# Network output stored for the bilinear run (channels moved last).
plt.figure(figsize=(10, 10))
plt.imshow(img_dict["InterpolationMode.BILINEAR"].permute(1, 2, 0))

In [23]:
# Network output stored for the nearest-neighbour run (channels moved last).
plt.figure(figsize=(10, 10))
plt.imshow(img_dict["InterpolationMode.NEAREST"].permute(1, 2, 0))

In [24]:
# Network output stored for the bicubic run (channels moved last).
plt.figure(figsize=(10, 10))
plt.imshow(img_dict["InterpolationMode.BICUBIC"].permute(1, 2, 0))


In [25]:
# Print every interpolation mode next to its test loss.
for mode_name, mode_loss in loss_dict.items():
    print(mode_name, mode_loss)

In [26]:
# Bar chart of the test loss per interpolation mode. The values are the
# `test_loss` (mean-squared error) entries gathered by compare(); the labels
# make the figure readable on its own.
plt.bar(list(loss_dict.keys()), list(loss_dict.values()))
plt.xlabel("Interpolation mode")
plt.ylabel("Test loss (MSE)")
plt.title("SRCNN test loss by interpolation mode")
plt.show()