In [None]:
# Mount Google Drive so the project modules and checkpoints stored under "My Drive" are reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import os
import sys
py_file_location = "/content/drive/My Drive/Image Super Resolution/SRGAN/"
# Make the project modules (dataset.py, model2.py) importable. The membership
# check keeps sys.path free of duplicate entries when this cell is re-run.
_project_dir = os.path.abspath(py_file_location)
if _project_dir not in sys.path:
  sys.path.append(_project_dir)
from dataset import DIV2KDataset
from model2 import SRResNet

In [None]:
import time
from enum import Enum
from copy import deepcopy
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transform
from torchsummary import summary
from torchvision.transforms import functional as F
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from PIL import Image
from tqdm import tqdm
import cv2
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt


# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# Let cuDNN benchmark convolution algorithms and cache the fastest one; this pays
# off when input shapes stay the same from one iteration to the next.
cudnn.benchmark = True

In [None]:
def _make_loader(dataset, batch_size, shuffle):
  """Wrap `dataset` in a DataLoader using the worker/pinning settings shared by every split."""
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                    num_workers=2, pin_memory=True)

# DIV2KDataset's integer argument selects the split; judging by the variable
# names, 0/1/2 are train/valid/test (TODO confirm against dataset.py).
train_dataset = DIV2KDataset(0)
train_dataloader = _make_loader(train_dataset, batch_size=8, shuffle=True)

valid_dataset = DIV2KDataset(1)
valid_dataloader = _make_loader(valid_dataset, batch_size=1, shuffle=False)

test_dataset = DIV2KDataset(2)
test_dataloader = _make_loader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# Build the super-resolution network; nn.Module.to moves its parameters onto
# `device` and returns the same module.
model = SRResNet()
model = model.to(device)

# Number of passes over the training set (also sets the StepLR step size).
epochs = 100

In [None]:
# Layer-by-layer summary for a 3x128x128 input. torchsummary defaults to
# device="cuda" and fails on CPU-only runtimes, so pass the active device type.
summary(model, (3, 128, 128), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]          15,616
             PReLU-2         [-1, 64, 128, 128]               1
ConvolutionalBlock-3         [-1, 64, 128, 128]               0
            Conv2d-4         [-1, 64, 128, 128]          36,928
       BatchNorm2d-5         [-1, 64, 128, 128]             128
             PReLU-6         [-1, 64, 128, 128]               1
ConvolutionalBlock-7         [-1, 64, 128, 128]               0
            Conv2d-8         [-1, 64, 128, 128]          36,928
       BatchNorm2d-9         [-1, 64, 128, 128]             128
ConvolutionalBlock-10         [-1, 64, 128, 128]               0
    ResidualBlock-11         [-1, 64, 128, 128]               0
           Conv2d-12         [-1, 64, 128, 128]          36,928
      BatchNorm2d-13         [-1, 64, 128, 128]             128
            PReLU-14         [-1, 64, 

In [None]:
def init_weights(model):
  """Xavier-uniform initialise the weight of a Linear or Conv2d module, in place.

  Intended for ``nn.Module.apply``, which calls it on every submodule. Any other
  module type is ignored, and biases keep their existing initialisation.
  """
  if not isinstance(model, (torch.nn.Linear, torch.nn.Conv2d)):
    return
  torch.nn.init.xavier_uniform_(model.weight)

In [None]:
# Re-initialise every Conv2d/Linear weight via init_weights. nn.Module.apply
# returns the module itself, so rebinding keeps the same object and avoids
# echoing the full module repr in the notebook.
model = model.apply(init_weights)

In [None]:
# Adam over the parameters that actually receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

In [None]:
# Pixel-wise mean squared error is the training objective.
criterion = nn.MSELoss(reduction="mean").to(device)
# A second, independent MSE instance; its name suggests it is meant for PSNR reporting.
psnr_criterion = nn.MSELoss(reduction="mean").to(device)

In [None]:
# Divide the learning rate by 10 every quarter of the run. max(1, ...) keeps
# StepLR valid for epochs < 4, where epochs // 4 == 0 would make StepLR divide
# by zero on its first step.
scheduler = lr_scheduler.StepLR(optimizer, step_size=max(1, epochs // 4), gamma=0.1)
# Dynamic loss scaling for mixed precision. Only meaningful on CUDA; disabling it
# explicitly elsewhere avoids GradScaler's "CUDA is not available" warning.
scaler = amp.GradScaler(enabled=(device.type == "cuda"))

In [None]:
from torchvision.transforms import transforms
def _to_uint8_image(chw):
  """Convert one (C, H, W) image tensor with values in [0, 1] to an (H, W, C) uint8 array.

  Values are clamped to [0, 1], scaled to [0, 255] and rounded to the nearest
  integer (plain truncation would bias every pixel down by up to one level).
  """
  return (chw.detach().float().clamp(0.0, 1.0).mul(255.0).round()
          .permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8))


def validate(model, loader, criterion,
             output_dir="/content/drive/My Drive/Image Super Resolution/SRGAN/output_srresnet/",
             save_images=True):
  """Evaluate `model` on `loader` and optionally save every super-resolved image.

  Args:
    model: network mapping a low-resolution batch to a high-resolution batch.
    loader: iterable of (lr_batch, hr_batch) tensor pairs. Any batch size is
      handled; each image in a batch is scored and saved individually.
    criterion: loss applied to (prediction, ground truth) per batch.
    output_dir: directory that receives `<n>.jpg` for the n-th image (1-based,
      in loader order). Created if it does not exist.
    save_images: set to False to skip writing images.

  Returns:
    (val_loss, psnr, ssim): the criterion averaged over batches, and PSNR / SSIM
    averaged over images. PSNR/SSIM are computed on the 3-channel uint8 images
    with data_range=255, not on a Y channel.

  Leaves `model` in eval mode; callers must switch back with `model.train()`.
  """
  if save_images:
    # cv2.imwrite returns False instead of raising when the directory is missing.
    os.makedirs(output_dir, exist_ok=True)

  num_batches = len(loader)
  val_loss = 0.0
  psnr_sum = 0.0
  ssim_sum = 0.0
  num_images = 0

  model.eval()
  with torch.no_grad():
    with tqdm(loader, unit='images') as validation:
      validation.set_description("Validation")
      for lr_batch, hr_batch in validation:
        hr_img, lr_img = hr_batch.to(device), lr_batch.to(device)

        # A single forward pass per batch feeds the loss, the metrics and the saved images.
        with amp.autocast():
          hr_pred = model(lr_img)
          loss = criterion(hr_pred, hr_img)
        val_loss += loss.item() / num_batches

        for pred_chw, gt_chw in zip(hr_pred, hr_img):
          pred_img = _to_uint8_image(pred_chw)
          gt_img = _to_uint8_image(gt_chw)
          num_images += 1

          pred_f = pred_img.astype(np.float32)
          gt_f = gt_img.astype(np.float32)
          psnr_sum += peak_signal_noise_ratio(gt_f, pred_f, data_range=255.)
          # channel_axis replaces the deprecated `multichannel=True` argument.
          ssim_sum += structural_similarity(gt_f, pred_f, data_range=255., channel_axis=-1)

          if save_images:
            # cv2 writes BGR. This assumes the dataset yields RGB tensors
            # (PIL/torchvision convention) -- TODO confirm in DIV2KDataset.
            bgr = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(output_dir, f"{num_images}.jpg"), bgr)

        validation.set_postfix(Val_Loss=val_loss,
                               Val_PSNR=psnr_sum / max(num_images, 1),
                               Val_SSIM=ssim_sum / max(num_images, 1))

  psnr = psnr_sum / max(num_images, 1)
  ssim = ssim_sum / max(num_images, 1)
  return val_loss, psnr, ssim

#validate(model,valid_dataloader, criterion)

In [None]:
resume_training = True
prev_epoch = 0
PATH = "/content/drive/My Drive/Image Super Resolution/SRGAN/srresnet.pth.tar"
#PATH = "/content/drive/My Drive/Image Super Resolution/SRGAN/srresnet-models/model3-norm/srresnet.pth.tar"

# Per-epoch history, restored from the checkpoint when resuming.
loss_list = list()   # mean training loss
psnr_list = list()   # mean training PSNR (dB) derived from per-batch MSE

val_ssim_list = list()
val_psnr_list = list()
val_loss_list = list()

if resume_training and os.path.exists(PATH):
  # NOTE(review): the checkpoint also pickles the whole nn.Module under "model".
  # On torch >= 2.6 torch.load defaults to weights_only=True and may refuse it;
  # pass weights_only=False only for checkpoints you produced yourself.
  checkpoint = torch.load(PATH, map_location=device)

  # Load the weights into the existing `model` rather than rebinding it to
  # checkpoint['model']. `optimizer` holds references to this instance's
  # parameters; swapping the object would leave Adam stepping a stale copy
  # while the resumed network never changes.
  model.load_state_dict(checkpoint['state_dict'])

  optimizer.load_state_dict(checkpoint['optimizer'])
  scheduler.load_state_dict(checkpoint['scheduler'])
  if 'scaler' in checkpoint:  # absent from checkpoints written by older runs
    scaler.load_state_dict(checkpoint['scaler'])

  prev_epoch = checkpoint['epoch']
  loss_list = list(checkpoint['train_loss_list'])
  psnr_list = list(checkpoint['psnr_list'])

  val_ssim_list = list(checkpoint['val_ssim_list'])
  val_psnr_list = list(checkpoint['val_psnr_list'])
  val_loss_list = list(checkpoint['val_loss_list'])

  print("Continue training from previous checkpoints ...")

num_train_batches = len(train_dataloader)

for epoch in range(prev_epoch, epochs):
  start = time.time()
  epoch_loss = 0.0
  epoch_psnr = 0.0
  model.train()
  with tqdm(train_dataloader, unit='batch') as tepoch:
    tepoch.set_description(f"Epoch {epoch+1:4d}/{epochs}")

    for lr_batch, hr_batch in tepoch:
      hr_img, lr_img = hr_batch.to(device), lr_batch.to(device)

      optimizer.zero_grad()
      # Only the forward pass and the loss run under autocast; the backward pass
      # and optimizer step belong outside it, as the PyTorch AMP docs recommend.
      with amp.autocast():
        hr_pred = model(lr_img)
        loss = criterion(hr_pred, hr_img)

      scaler.scale(loss).backward()
      scaler.step(optimizer)
      scaler.update()

      # Training PSNR for this batch, assuming pixel values in [0, 1] (the range
      # validate() clamps predictions to). MSE is clamped so a perfect batch
      # does not evaluate log10(1 / 0).
      with torch.no_grad():
        batch_mse = psnr_criterion(hr_pred.float(), hr_img.float()).clamp_min(1e-10)
        batch_psnr = (10. * torch.log10(1. / batch_mse)).item()

      epoch_loss += loss.item() / num_train_batches
      epoch_psnr += batch_psnr / num_train_batches

      tepoch.set_postfix(Loss = epoch_loss, PSNR = epoch_psnr)

  # Release the last training batch before validation so its memory can be reused.
  del hr_img, lr_img, hr_pred, loss
  torch.cuda.empty_cache()

  val_loss, val_psnr, val_ssim = validate(model, valid_dataloader, criterion)

  val_psnr_list.append(val_psnr)
  val_ssim_list.append(val_ssim)
  val_loss_list.append(val_loss)

  loss_list.append(epoch_loss)
  psnr_list.append(epoch_psnr)

  scheduler.step()

  # Save to a temporary file and swap it into place, so an interrupted write
  # (e.g. a Colab disconnect) cannot corrupt the previous good checkpoint.
  tmp_path = PATH + ".tmp"
  torch.save({"model": model,
              "epoch": epoch + 1,
              "psnr_list": psnr_list,
              "state_dict": model.state_dict(),
              "optimizer": optimizer.state_dict(),
              "scheduler": scheduler.state_dict(),
              "scaler": scaler.state_dict(),
              "train_loss_list": loss_list,
              "val_loss_list": val_loss_list,
              "val_ssim_list": val_ssim_list,
              "val_psnr_list": val_psnr_list}, tmp_path)
  os.replace(tmp_path, PATH)

  print(f"Epoch {epoch+1}/{epochs}: loss={epoch_loss:.6f} psnr={epoch_psnr:.2f} "
        f"val_loss={val_loss:.6f} val_psnr={val_psnr:.2f} val_ssim={val_ssim:.4f} "
        f"time={time.time() - start:.0f}s")
