In [None]:
# Mount Google Drive into this Colab runtime; the project modules and the
# checkpoint/output paths used below all live under "/content/drive/My Drive/...".
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import sys
py_file_location = "/content/drive/My Drive/Image Super Resolution/SRGAN/"
sys.path.append(os.path.abspath(py_file_location))
from dataset import DIV2KDataset
from model2 import Discriminator
from model2 import SRResNet
from loss import ContentLoss, TVLoss

In [None]:
import time
from enum import Enum
from copy import deepcopy
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transform
from torchsummary import summary
from torchvision.transforms import functional as F
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from PIL import Image
from tqdm import tqdm
import cv2
import torch.backends.cudnn as cudnn

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Let cuDNN benchmark and cache the fastest convolution algorithms
# (most effective when input shapes do not change between batches).
cudnn.benchmark = True

In [None]:
# DIV2KDataset split selector: 0 -> training, 1 -> validation, 2 -> test (per the variable names).
train_dataset = DIV2KDataset(0)
valid_dataset = DIV2KDataset(1)
test_dataset = DIV2KDataset(2)

# Training batches are shuffled; validation/test use one image per batch
# (validate()/test() squeeze dim 0 of every batch).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [None]:
# The SRGAN generator is initialised from a pre-trained SRResNet checkpoint;
# the discriminator starts from scratch.
srresnet_checkpoint = "/content/drive/My Drive/Image Super Resolution/SRGAN/srresnet.pth.tar"
generator = SRResNet().to(device)
# map_location remaps tensors saved on a GPU onto `device`, so this cell also
# works on the CPU fallback selected above.
checkpoint = torch.load(srresnet_checkpoint, map_location=device)
generator.load_state_dict(checkpoint['state_dict'])

discriminator = Discriminator().to(device)

In [None]:
# torchsummary runs a forward pass on synthetic input to record layer shapes.
# Do that in eval mode so the pretrained generator's BatchNorm running
# statistics are not updated by it, then restore the previous mode.
_generator_was_training = generator.training
generator.eval()
try:
  summary(generator, (3, 128, 128))
finally:
  generator.train(_generator_was_training)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]          15,616
             PReLU-2         [-1, 64, 128, 128]               1
ConvolutionalBlock-3         [-1, 64, 128, 128]               0
            Conv2d-4         [-1, 64, 128, 128]          36,928
       BatchNorm2d-5         [-1, 64, 128, 128]             128
             PReLU-6         [-1, 64, 128, 128]               1
ConvolutionalBlock-7         [-1, 64, 128, 128]               0
            Conv2d-8         [-1, 64, 128, 128]          36,928
       BatchNorm2d-9         [-1, 64, 128, 128]             128
ConvolutionalBlock-10         [-1, 64, 128, 128]               0
    ResidualBlock-11         [-1, 64, 128, 128]               0
           Conv2d-12         [-1, 64, 128, 128]          36,928
      BatchNorm2d-13         [-1, 64, 128, 128]             128
            PReLU-14         [-1, 64, 

In [None]:
# torchsummary runs a forward pass on synthetic input to record layer shapes.
# Do that in eval mode so the discriminator's BatchNorm running statistics are
# not updated by it, then restore the previous mode.
_discriminator_was_training = discriminator.training
discriminator.eval()
try:
  summary(discriminator, (3, 512, 512))
finally:
  discriminator.train(_discriminator_was_training)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 512]           1,792
         LeakyReLU-2         [-1, 64, 512, 512]               0
ConvolutionalBlock-3         [-1, 64, 512, 512]               0
            Conv2d-4         [-1, 64, 256, 256]          36,928
       BatchNorm2d-5         [-1, 64, 256, 256]             128
         LeakyReLU-6         [-1, 64, 256, 256]               0
ConvolutionalBlock-7         [-1, 64, 256, 256]               0
            Conv2d-8        [-1, 128, 256, 256]          73,856
       BatchNorm2d-9        [-1, 128, 256, 256]             256
        LeakyReLU-10        [-1, 128, 256, 256]               0
ConvolutionalBlock-11        [-1, 128, 256, 256]               0
           Conv2d-12        [-1, 128, 128, 128]         147,584
      BatchNorm2d-13        [-1, 128, 128, 128]             256
        LeakyReLU-14        [-1, 128, 

In [None]:
def init_weights(model):
  """Xavier-uniform initialise the weight of a Linear or Conv2d layer.

  Intended for `nn.Module.apply`, which calls it once per submodule. Any other
  module type, and every bias, is left untouched.
  """
  if not isinstance(model, (torch.nn.Linear, torch.nn.Conv2d)):
    return
  torch.nn.init.xavier_uniform_(model.weight)


# Re-initialise every Linear/Conv2d weight inside the discriminator with
# init_weights; `apply` visits each submodule recursively. As the cell's last
# expression it also echoes the discriminator's repr (shown below).
discriminator.apply(init_weights)

Discriminator(
  (conv_blocks): Sequential(
    (0): ConvolutionalBlock(
      (conv_block): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): ConvolutionalBlock(
      (conv_block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): ConvolutionalBlock(
      (conv_block): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (3): ConvolutionalBlock(
      (conv_block): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e

In [None]:
# Generator and discriminator each get their own Adam optimiser, both at lr 1e-4.
g_optimizer = optim.Adam(params=generator.parameters(), lr=1e-4)
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=1e-4)

In [None]:
# Import the TV-loss class under a private alias. The module-level name `TVLoss`
# is rebound to an *instance* below (the training loop calls `TVLoss(fake_hr)`).
# Without the alias, re-running this cell would call that instance instead of
# constructing a new loss, and fail.
from loss import TVLoss as _TVLossClass

content_criterion = ContentLoss().to(device)
# Discriminator outputs are treated as logits, hence BCE-with-logits.
adversarial_criterion = nn.BCEWithLogitsLoss().to(device)
criterion = nn.MSELoss().to(device)
TVLoss = _TVLossClass().to(device)

In [None]:
epochs = 50

# Both learning rates drop by 10x every `epochs // 4` scheduler steps
# (the training loop steps each scheduler once per epoch).
g_scheduler = lr_scheduler.StepLR(g_optimizer, step_size=epochs // 4, gamma=0.1)
d_scheduler = lr_scheduler.StepLR(d_optimizer, step_size=epochs // 4, gamma=0.1)

# Gradient scaler for mixed-precision (autocast) training.
scaler = amp.GradScaler()

In [None]:
from torchvision.transforms import transforms
def validate(model, loader, output_dir="/content/drive/My Drive/Image Super Resolution/SRGAN/output_srgan/"):
  """Score `model` on `loader` (mean PSNR / SSIM) and save every prediction as a JPEG.

  Args:
    model: super-resolution network mapping an LR batch to an HR batch.
    loader: yields (lr_batch, hr_batch) pairs. Assumes batch_size == 1 (dim 0 of
      every batch is squeezed away) and pixel values in [0, 1].
    output_dir: folder receiving `<index>.jpg` predictions (1-based index);
      created if it does not exist.

  Returns:
    (psnr, ssim): means over the loader, computed on 8-bit images with
    data_range=255 on all colour channels.

  The model's train/eval mode is restored on exit.
  """

  def to_uint8_hwc(tensor):
    # (1, C, H, W) in [0, 1] -> (H, W, C) uint8 (truncating, as astype does).
    return tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8")

  os.makedirs(output_dir, exist_ok=True)
  num_images = len(loader)
  psnr = 0
  ssim = 0

  was_training = model.training
  model.eval()
  try:
    with torch.no_grad(), tqdm(loader, unit='images') as validation:
      validation.set_description("Validation")
      for img_idx, (lr_batch, hr_batch) in enumerate(validation, start=1):
        hr_img, lr_img = hr_batch.to(device), lr_batch.to(device)
        # One forward pass serves both the metrics and the saved image.
        with amp.autocast():
          hr_pred = model(lr_img).clamp(0.0, 1.0)

        hr_pred_img = to_uint8_hwc(hr_pred)
        hr_gt_img = to_uint8_hwc(hr_img)
        pred_f32 = hr_pred_img.astype(np.float32)
        gt_f32 = hr_gt_img.astype(np.float32)

        psnr += peak_signal_noise_ratio(gt_f32, pred_f32, data_range=255.) / num_images
        # channel_axis replaces scikit-image's deprecated `multichannel=True`.
        ssim += structural_similarity(gt_f32, pred_f32, data_range=255., channel_axis=-1) / num_images
        validation.set_postfix(Val_PSNR = psnr, Val_SSIM = ssim)

        # OpenCV writes BGR. The tensors are treated as RGB elsewhere in this
        # notebook (the Y-channel coefficients), so swap before saving.
        # NOTE(review): confirm DIV2KDataset really yields RGB.
        path = os.path.join(output_dir, str(img_idx) + '.jpg')
        cv2.imwrite(path, cv2.cvtColor(hr_pred_img, cv2.COLOR_RGB2BGR))
  finally:
    model.train(was_training)

  return psnr, ssim

#validate(generator,valid_dataloader)


In [None]:
from torchvision.transforms import transforms
def test(model, loader):
  """Evaluate the generator `model` on `loader`: generator loss, PSNR and SSIM.

  Uses the module-level `discriminator`, `adversarial_criterion`,
  `content_criterion` and `device`. Assumes batch_size == 1 (dim 0 of every
  batch is squeezed away) and pixel values in [0, 1].

  Returns:
    (test_loss, psnr, ssim): means over the loader. test_loss is the generator
    objective without the TV term: content loss + 1e-3 * adversarial loss,
    where the adversarial loss is BCE(D(prediction), 1).

  Both networks are put in eval mode for the evaluation and have their
  previous mode restored afterwards.
  """
  num_batches = len(loader)
  psnr = 0
  ssim = 0
  test_loss = 0

  model_was_training = model.training
  disc_was_training = discriminator.training
  model.eval()
  # Eval mode keeps the discriminator's BatchNorm layers from absorbing
  # test-set statistics while it is only being used for scoring.
  discriminator.eval()
  try:
    with torch.no_grad(), tqdm(loader, unit='images') as validation:
      validation.set_description("Validation")
      for lr_batch, hr_batch in validation:
        hr_img, lr_img = hr_batch.to(device), lr_batch.to(device)
        with amp.autocast():
          hr_pred = model(lr_img)
          prob_fake = discriminator(hr_pred)
          # Generator's adversarial term: distance from D calling the prediction real (target 1).
          g_loss_ad = adversarial_criterion(prob_fake, torch.ones(prob_fake.size(), device=device))
          g_loss_content = content_criterion(hr_pred, hr_img)
          g_loss = g_loss_content + (0.001 * g_loss_ad)
        test_loss += g_loss.item() / num_batches

        # (1, C, H, W) in [0, 1] -> (H, W, C) 8-bit values held as float32.
        hr_pred_img = hr_pred.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8").astype(np.float32)
        hr_gt_img = hr_img.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8").astype(np.float32)

        psnr += peak_signal_noise_ratio(hr_gt_img, hr_pred_img, data_range=255.) / num_batches
        # channel_axis replaces scikit-image's deprecated `multichannel=True`.
        ssim += structural_similarity(hr_gt_img, hr_pred_img, data_range=255., channel_axis=-1) / num_batches

        validation.set_postfix(Val_PSNR = psnr, Val_SSIM = ssim)
  finally:
    model.train(model_was_training)
    discriminator.train(disc_was_training)

  return test_loss, psnr, ssim


# NOTE(review): test() no longer reads this; kept so any other cell referring to it still works.
output_dir = "/content/drive/My Drive/Image Super Resolution/SRGAN/srresnet-models/model3/test/"
#test(generator,test_dataloader)
# test(generator,valid_dataloader)

Validation: 100%|██████████| 200/200 [01:41<00:00,  1.96images/s, Val_PSNR=24.5, Val_SSIM=0.718]


(0.004507378927519311, 24.480883341851435, 0.7175038347075701)

In [None]:
# SRGAN adversarial training loop.
# Each iteration: one discriminator step (real HR -> 1, generated HR -> 0), then
# one generator step (content + 1e-3 * adversarial + 2e-8 * TV loss).
# After every epoch, run validation, step the LR schedulers, and write both
# checkpoints so training can resume from PATH_G / PATH_D.
resume_training = True
epochs = 50
PATH_G = "/content/drive/My Drive/Image Super Resolution/SRGAN/g_model.pth.tar"
PATH_D = "/content/drive/My Drive/Image Super Resolution/SRGAN/d_model.pth.tar"

# Per-epoch histories, persisted in the checkpoints.
d_loss_list = list()
g_loss_list = list()
psnr_list = list()      # NOTE(review): only ever receives 0 (epoch_psnr is never updated); kept for checkpoint compatibility.

val_ssim_list = list()
val_psnr_list = list()
val_loss_list = list()  # NOTE(review): never appended to here; validate() returns no loss.

prev_epoch = 0

if resume_training and os.path.exists(PATH_G) and os.path.exists(PATH_D):
  # map_location lets checkpoints written on a GPU runtime be resumed on `device`.
  checkpoint_G = torch.load(PATH_G, map_location=device)
  checkpoint_D = torch.load(PATH_D, map_location=device)

  generator.load_state_dict(checkpoint_G['state_dict'])
  g_optimizer.load_state_dict(checkpoint_G['optimizer'])
  g_scheduler.load_state_dict(checkpoint_G['scheduler'])

  discriminator.load_state_dict(checkpoint_D['state_dict'])
  d_optimizer.load_state_dict(checkpoint_D['optimizer'])
  d_scheduler.load_state_dict(checkpoint_D['scheduler'])

  prev_epoch = checkpoint_G['epoch']
  d_loss_list = list(checkpoint_D['train_loss_list'])
  g_loss_list = list(checkpoint_G['train_loss_list'])
  psnr_list = list(checkpoint_G['psnr_list'])

  val_ssim_list = list(checkpoint_G['val_ssim_list'])
  val_psnr_list = list(checkpoint_G['val_psnr_list'])
  val_loss_list = list(checkpoint_G['val_loss_list'])

  print("Continue training from previous checkpoints ...")


num_batches = len(train_dataloader)

for epoch in range(prev_epoch, epochs):
  epoch_d_loss = 0.0
  epoch_g_loss = 0.0
  epoch_psnr = 0

  discriminator.train()
  generator.train()
  with tqdm(train_dataloader, unit='batch') as tepoch:
    tepoch.set_description(f"Epoch {epoch+1:4d}/{epochs}")

    for lr_batch, hr_batch in tepoch:
      hr_img, lr_img = hr_batch.to(device), lr_batch.to(device)

      # ---- Discriminator step: real HR -> 1, generated HR -> 0 ----
      d_optimizer.zero_grad()
      for p in discriminator.parameters():
        p.requires_grad = True

      fake_hr = generator(lr_img)

      with amp.autocast():
        prob_real = discriminator(hr_img)
        # detach(): this step must not push gradients into the generator.
        prob_fake = discriminator(fake_hr.detach())
        d_loss_real = adversarial_criterion(prob_real, torch.ones(prob_real.size(), device=device))
        d_loss_fake = adversarial_criterion(prob_fake, torch.zeros(prob_fake.size(), device=device))
        d_loss = d_loss_real + d_loss_fake

      scaler.scale(d_loss).backward()
      scaler.step(d_optimizer)

      # ---- Generator step ----
      # Freeze D so the generator's backward pass leaves no gradients in it.
      for p in discriminator.parameters():
        p.requires_grad = False

      g_optimizer.zero_grad()
      with amp.autocast():
        prob_fake = discriminator(fake_hr)
        # Non-saturating GAN loss: the generator is rewarded when D labels its
        # output as real (target 1). Using D's own objective here would train G
        # to make its outputs look fake.
        g_loss_ad = adversarial_criterion(prob_fake, torch.ones(prob_fake.size(), device=device))
        g_loss_content = content_criterion(fake_hr, hr_img)
        g_loss_tv = TVLoss(fake_hr)
        g_loss = g_loss_content + (0.001 * g_loss_ad) + (2e-8 * g_loss_tv)

      scaler.scale(g_loss).backward()
      scaler.step(g_optimizer)
      # A GradScaler shared by several optimizers is updated once per iteration,
      # after all of them have stepped.
      scaler.update()

      # Accumulate Python floats so the running sums don't chain autograd history across batches.
      epoch_d_loss += d_loss.item() / num_batches
      epoch_g_loss += g_loss.item() / num_batches
      tepoch.set_postfix(D_Loss = epoch_d_loss, G_Loss = epoch_g_loss)

  val_psnr, val_ssim = validate(generator, valid_dataloader)
  print('\n')
  torch.cuda.empty_cache()

  val_psnr_list.append(val_psnr)
  val_ssim_list.append(val_ssim)

  d_loss_list.append(epoch_d_loss)
  g_loss_list.append(epoch_g_loss)
  psnr_list.append(epoch_psnr)

  d_scheduler.step()
  g_scheduler.step()

  # Fields common to both checkpoints; each file adds its own model/optimizer/scheduler state.
  shared_state = {"epoch": epoch + 1,
                  "psnr_list": psnr_list,
                  "val_loss_list": val_loss_list,
                  "val_ssim_list": val_ssim_list,
                  "val_psnr_list": val_psnr_list}

  torch.save({**shared_state,
              "state_dict": discriminator.state_dict(),
              "optimizer": d_optimizer.state_dict(),
              "scheduler": d_scheduler.state_dict(),
              "train_loss_list": d_loss_list}, PATH_D)

  torch.save({**shared_state,
              "state_dict": generator.state_dict(),
              "optimizer": g_optimizer.state_dict(),
              "scheduler": g_scheduler.state_dict(),
              "train_loss_list": g_loss_list}, PATH_G)

In [None]:
import matplotlib.pyplot as plt

def plot(PATH_G, PATH_D, generator, discriminator, skip_epochs=10):
  """Plot SRGAN training losses and validation PSNR from saved checkpoints.

  Args:
    PATH_G, PATH_D: generator / discriminator checkpoint files written by the
      training loop (dicts with 'train_loss_list', 'val_*_list' entries).
    generator, discriminator: unused; accepted so existing call sites keep working.
    skip_epochs: number of leading epochs omitted from the PSNR curve. It is
      ignored (everything is plotted) when there are not more epochs than this.

  The x axis shows 1-based epoch numbers: list index i holds epoch i + 1.
  """
  # Only the metric histories are read, so load onto the CPU regardless of where training ran.
  checkpoint_G = torch.load(PATH_G, map_location="cpu")
  checkpoint_D = torch.load(PATH_D, map_location="cpu")

  d_loss_list = checkpoint_D['train_loss_list']
  g_loss_list = checkpoint_G['train_loss_list']

  val_ssim_list = checkpoint_G['val_ssim_list']
  val_psnr_list = checkpoint_G['val_psnr_list']
  val_loss_list = checkpoint_G['val_loss_list']

  print(len(d_loss_list), len(g_loss_list), len(val_ssim_list), len(val_psnr_list), len(val_loss_list))
  print(val_loss_list)

  plt.xlabel("Epochs")
  plt.ylabel("Training Loss")
  plt.title("Training Loss")
  plt.plot(range(1, len(d_loss_list) + 1), d_loss_list, label = 'Discriminator')
  plt.plot(range(1, len(g_loss_list) + 1), g_loss_list, label = 'Generator')
  plt.legend()
  plt.show()
  print("\n")

  # Fall back to the full curve rather than drawing an empty plot on short runs.
  skip = skip_epochs if len(val_psnr_list) > skip_epochs else 0
  plt.xlabel("Epochs")
  plt.ylabel("PSNR")
  plt.title("Validation PSNR")
  plt.plot(range(skip + 1, len(val_psnr_list) + 1), list(val_psnr_list)[skip:])
  plt.show()
  print("\n")


def plot_srresnet(PATH, model):
  """Plot SRResNet training/validation loss and validation PSNR from a checkpoint.

  Args:
    PATH: checkpoint file holding 'train_loss_list', 'val_loss_list' and
      'val_psnr_list' entries.
    model: unused; accepted so existing call sites keep working.
  """
  # Only the metric histories are read, so load onto the CPU regardless of where
  # the checkpoint was written (a CUDA checkpoint otherwise fails on a CPU runtime).
  checkpoint = torch.load(PATH, map_location="cpu")

  loss_list = checkpoint['train_loss_list']
  val_psnr_list = checkpoint['val_psnr_list']
  val_loss_list = checkpoint['val_loss_list']
  print(loss_list)

  plt.xlabel("Epochs")
  plt.ylabel("Loss")
  plt.title("SRResNet Loss")
  plt.plot(loss_list, label = 'Training')
  plt.plot(val_loss_list, label = 'Validation')
  plt.legend()
  plt.show()
  print("\n")

  plt.xlabel("Epochs")
  plt.ylabel("PSNR")
  plt.title("Validation PSNR")
  plt.plot(val_psnr_list)
  plt.show()
  print("\n")

In [None]:
# PATH_G = "/content/drive/My Drive/Image Super Resolution/SRGAN/srgan-models/model3/g_model.pth.tar"
# PATH_D = "/content/drive/My Drive/Image Super Resolution/SRGAN/srgan-models/model3/d_model.pth.tar"

# plot(PATH_G, PATH_D, generator, discriminator)

In [None]:
# model = SRResNet().to(device)
# PATH = "/content/drive/My Drive/Image Super Resolution/SRGAN/srresnet.pth.tar"
# plot_srresnet(PATH, model)