source: https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/esrgan

In [2]:
import os,shutil

# Colab-only: mount Google Drive at /content/drive. The dataset, sample and
# checkpoint paths used below are relative ("./drive/MyDrive/..."), so they
# presumably rely on the working directory being /content - TODO confirm.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#Data

In [3]:
import glob
import random
import os
import numpy as np

import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

# Normalization parameters for pre-trained PyTorch models
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def denormalize(tensors):
    """Undo the per-channel ``Normalize(mean, std)`` applied by the dataset.

    Args:
        tensors: image batch whose channel axis is dim 1, e.g. ``(N, 3, H, W)``.

    Returns:
        A new tensor holding ``tensors * std + mean`` clamped to ``[0, 1]``.
        The input tensor is left untouched.
    """
    # Shape the statistics as (1, C, 1, ...) so they broadcast over every
    # non-channel axis instead of looping over channels and writing in place.
    stat_shape = [1, -1] + [1] * (tensors.dim() - 2)
    channel_std = tensors.new_tensor(std).view(stat_shape)
    channel_mean = tensors.new_tensor(mean).view(stat_shape)
    # Images are produced by ToTensor(), i.e. they live in [0, 1]; the former
    # upper bound of 255 could never clip an out-of-range prediction.
    return torch.clamp(tensors * channel_std + channel_mean, 0, 1)


class ImageDataset(Dataset):
    """Dataset yielding a low-/high-resolution pair for every image in a folder.

    Args:
        root: directory whose ``*.*`` entries are treated as image files.
        hr_shape: ``(height, width)`` of the high-resolution target; the
            low-resolution input is a quarter of that along each axis.
    """

    def __init__(self, root, hr_shape):
        hr_height, hr_width = hr_shape

        def build_transform(size):
            # Resize -> tensor in [0, 1] -> per-channel normalization.
            return transforms.Compose(
                [
                    transforms.Resize(size, Image.BICUBIC),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std),
                ]
            )

        # Use the requested width as well as the height; the previous code
        # passed hr_height for both axes and silently ignored hr_width.
        self.lr_transform = build_transform((hr_height // 4, hr_width // 4))
        self.hr_transform = build_transform((hr_height, hr_width))

        self.files = sorted(glob.glob(root + "/*.*"))

    def __getitem__(self, index):
        """Return ``{"lr": tensor, "hr": tensor}`` for the image at ``index``."""
        # Indices wrap around the file list.
        path = self.files[index % len(self.files)]
        # Close the file handle deterministically and force three channels so
        # grayscale / RGBA files do not break the 3-channel Normalize.
        with Image.open(path) as source:
            img = source.convert("RGB")
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.files)

#Architecture

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.models import vgg19
import math


class FeatureExtractor(nn.Module):
    """Truncated pre-trained VGG19 used to compare images in feature space.

    Only the first 35 children of ``vgg19().features`` are kept (the attribute
    name suggests this is the conv5_4 output - not verifiable from this file).
    """

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True)
        feature_layers = list(backbone.features.children())
        self.vgg19_54 = nn.Sequential(*feature_layers[:35])

    def forward(self, img):
        """Return the truncated-VGG19 feature maps of ``img``."""
        features = self.vgg19_54(img)
        return features


class DenseResidualBlock(nn.Module):
    """
    Densely connected stack of five 3x3 convolutions with a scaled residual,
    after "Residual Dense Network for Image Super-Resolution" (CVPR 18).

    Stage k receives the block input concatenated with the outputs of all
    previous stages; the last stage has no activation.
    """

    def __init__(self, filters, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale

        def make_stage(width, activate=True):
            # `width` counts how many `filters`-wide maps are concatenated on input.
            conv = nn.Conv2d(width * filters, filters, 3, 1, 1, bias=True)
            if activate:
                return nn.Sequential(conv, nn.LeakyReLU())
            return nn.Sequential(conv)

        self.b1 = make_stage(1)
        self.b2 = make_stage(2)
        self.b3 = make_stage(3)
        self.b4 = make_stage(4)
        self.b5 = make_stage(5, activate=False)
        # Plain list is fine: every stage is already registered via b1..b5.
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        """Return ``res_scale * last_stage_output + x``."""
        collected = [x]
        for stage in self.blocks:
            out = stage(torch.cat(collected, 1))
            collected.append(out)
        return out.mul(self.res_scale) + x


class ResidualInResidualDenseBlock(nn.Module):
    """Three chained DenseResidualBlocks wrapped in an outer scaled residual.

    Note that ``res_scale`` only scales the outer skip connection; the inner
    dense blocks are built with their own default scale.
    """

    def __init__(self, filters, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale
        inner = [DenseResidualBlock(filters) for _ in range(3)]
        self.dense_blocks = nn.Sequential(*inner)

    def forward(self, x):
        """Return ``res_scale * dense_blocks(x) + x``."""
        residual = self.dense_blocks(x)
        return residual.mul(self.res_scale) + x


class GeneratorRRDB(nn.Module):
    """Super-resolution generator built from residual-in-residual dense blocks.

    Args:
        channels: number of image channels on input and output.
        filters: width of the feature maps inside the network.
        num_res_blocks: number of ResidualInResidualDenseBlock modules.
        num_upsample: number of x2 pixel-shuffle stages (total scale 2**num_upsample).
    """

    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # Resolution-preserving 3x3 convolution used throughout the generator.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)

        # Shallow feature extraction.
        self.conv1 = conv3x3(channels, filters)
        # Trunk of RRDB blocks followed by a fusing convolution.
        trunk = [ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)]
        self.res_blocks = nn.Sequential(*trunk)
        self.conv2 = conv3x3(filters, filters)
        # Each stage expands channels x4 and pixel-shuffles them into a x2 larger map.
        stages = []
        for _ in range(num_upsample):
            stages.extend(
                (
                    conv3x3(filters, filters * 4),
                    nn.LeakyReLU(),
                    nn.PixelShuffle(upscale_factor=2),
                )
            )
        self.upsampling = nn.Sequential(*stages)
        # Reconstruction head mapping features back to image channels.
        self.conv3 = nn.Sequential(
            conv3x3(filters, filters),
            nn.LeakyReLU(),
            conv3x3(filters, channels),
        )

    def forward(self, x):
        """Map a low-resolution batch to its upscaled reconstruction."""
        shallow = self.conv1(x)
        deep = self.conv2(self.res_blocks(shallow))
        # Global skip connection around the whole RRDB trunk.
        fused = torch.add(shallow, deep)
        return self.conv3(self.upsampling(fused))


class Discriminator(nn.Module):
    """Convolutional patch discriminator producing one logit per image patch.

    Args:
        input_shape: ``(channels, height, width)`` of the images to classify.

    Attributes:
        input_shape: the shape passed to the constructor.
        output_shape: ``(1, patch_h, patch_w)`` shape of the logits produced
            for one image of ``input_shape``.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        stage_filters = [64, 128, 256, 512]

        def downsampled(size):
            # Every stage ends in a stride-2, kernel-3, padding-1 convolution,
            # which maps n -> ceil(n / 2). Flooring size / 16 (as before) only
            # agrees with that when the size is a multiple of 16, otherwise
            # the label tensors built from output_shape mismatch the logits.
            size = int(size)
            for _ in stage_filters:
                size = (size + 1) // 2
            return size

        self.output_shape = (1, downsampled(in_height), downsampled(in_width))

        def discriminator_block(in_filters, out_filters, first_block=False):
            # conv -> (BN) -> LeakyReLU -> stride-2 conv -> BN -> LeakyReLU;
            # the very first convolution is not followed by batch norm.
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate(stage_filters):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        # Final 1-channel convolution yields the per-patch logits.
        layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return raw (un-sigmoided) patch logits for ``img``."""
        return self.model(img)

#Train

In [None]:
import argparse
import os
import numpy as np
import math
import itertools
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# Output locations for sample image grids and model checkpoints.
os.makedirs("./drive/MyDrive/ColabNotebooks/ESRGAN-pytorch/images/training", exist_ok=True)
os.makedirs("./drive/MyDrive/ColabNotebooks/ESRGAN-pytorch/saved_models", exist_ok=True)

# Hyper-parameters. The parser is evaluated with an empty argument list
# (notebook usage), so the defaults below are the effective configuration.
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from") #0
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training") #200
parser.add_argument("--dataset_name", type=str, default="DIV2K", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient")
# b2 is Adam's second-moment decay; the help text used to repeat the b1 wording.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=4, help="number of cpu threads to use during batch generation") #8
parser.add_argument("--hr_height", type=int, default=256, help="high res. image height")
parser.add_argument("--hr_width", type=int, default=256, help="high res. image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving image samples")
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="batch interval between model checkpoints") #5000
parser.add_argument("--residual_blocks", type=int, default=23, help="number of residual blocks in the generator")
parser.add_argument("--warmup_batches", type=int, default=500, help="number of batches with pixel-wise loss only") #500
parser.add_argument("--lambda_adv", type=float, default=5e-3, help="adversarial loss weight") #5 * (10 ^ -3)
parser.add_argument("--lambda_pixel", type=float, default=1e-2, help="pixel-wise loss weight") # 1 * (10 ^ -2)
opt = parser.parse_args(args=[])
#print(opt)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hr_shape = (opt.hr_height, opt.hr_width)


# Initialize generator and discriminator
generator = GeneratorRRDB(opt.channels, filters=64, num_res_blocks=opt.residual_blocks).to(device)
discriminator = Discriminator(input_shape=(opt.channels, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# Set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
criterion_content = torch.nn.L1Loss().to(device)
criterion_pixel = torch.nn.L1Loss().to(device)

if opt.epoch != 0:
    # Resume from the checkpoints written for epoch `opt.epoch`.
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only runtime (and vice versa) instead of failing inside torch.load.
    generator.load_state_dict(
        torch.load(
            "./drive/MyDrive/ColabNotebooks/ESRGAN-pytorch/saved_models/generator_%d.pth" % opt.epoch,
            map_location=device,
        )
    )
    discriminator.load_state_dict(
        torch.load(
            "./drive/MyDrive/ColabNotebooks/ESRGAN-pytorch/saved_models/discriminator_%d.pth" % opt.epoch,
            map_location=device,
        )
    )

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Tensor constructor matching the active device; used by the training loop
# to cast batches and build the adversarial label tensors.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

dataloader = DataLoader(
    ImageDataset("./drive/MyDrive/datasets/DIV2K/DIV2K_train_HR" , hr_shape=hr_shape),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)

# ----------
#  Training
# ----------

# Main optimisation loop. Each batch first updates the generator and then the
# discriminator. While `batches_done < opt.warmup_batches` only the pixel loss
# is optimised and the rest of the iteration is skipped.
for epoch in range(opt.epoch, opt.n_epochs):
  for i, imgs in enumerate(dataloader):

      # Global batch counter across epochs; drives warm-up, sampling and checkpointing.
      batches_done = epoch * len(dataloader) + i

      # Configure model input
      imgs_lr = Variable(imgs["lr"].type(Tensor))
      imgs_hr = Variable(imgs["hr"].type(Tensor))

      # Adversarial ground truths
      # (all-ones / all-zeros targets shaped like the discriminator's declared output_shape)
      valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
      fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

      # ------------------
      #  Train Generators
      # ------------------

      optimizer_G.zero_grad()

      # Generate a high resolution image from low resolution input
      gen_hr = generator(imgs_lr)

      # Measure pixel-wise loss against ground truth
      loss_pixel = criterion_pixel(gen_hr, imgs_hr)

      if batches_done < opt.warmup_batches:
          # Warm-up (pixel-wise loss only)
          # The `continue` below also skips the discriminator update, the
          # sample image dump and the checkpoint check for these batches.
          loss_pixel.backward()
          optimizer_G.step()
          print(
                "[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), loss_pixel.item())
          )
          continue

      # Extract validity predictions from discriminator
      # (the real-image prediction is detached, so it only acts as a constant reference here)
      pred_real = discriminator(imgs_hr).detach()
      pred_fake = discriminator(gen_hr)

      # Adversarial loss (relativistic average GAN)
      loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

      # Content loss
      # (L1 distance between feature-extractor activations of generated and real images)
      gen_features = feature_extractor(gen_hr)
      real_features = feature_extractor(imgs_hr).detach()
      loss_content = criterion_content(gen_features, real_features)

      # Total generator loss
      loss_G = loss_content + opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel

      loss_G.backward()
      optimizer_G.step()

      # ---------------------
      #  Train Discriminator
      # ---------------------

      optimizer_D.zero_grad()

      # gen_hr is detached so this backward pass does not reach the generator.
      pred_real = discriminator(imgs_hr)
      pred_fake = discriminator(gen_hr.detach())

      # Adversarial loss for real and fake images (relativistic average GAN)
      loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
      loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

      # Total loss
      loss_D = (loss_real + loss_fake) / 2

      loss_D.backward()
      optimizer_D.step()

      # --------------
      #  Log Progress
      # --------------

      print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: %f, adv: %f, pixel: %f]"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_content.item(),
                loss_GAN.item(),
                loss_pixel.item(),
          )
       )


      if batches_done % opt.sample_interval == 0:
          # Save image grid with upsampled inputs and ESRGAN outputs
          # (the LR batch is enlarged x4 so it can be concatenated with gen_hr along the width)
          imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
          img_grid = denormalize(torch.cat((imgs_lr, gen_hr), -1))
          save_image(img_grid, "./drive/MyDrive/ColabNotebooks/ESRGAN-pytorch/images/training/%d.png" % batches_done, nrow=1, normalize=False)

      if batches_done % opt.checkpoint_interval == 0:
          # Save model checkpoints
          # NOTE: the trigger is a batch count but the file name uses the epoch,
          # so two checkpoints taken within the same epoch overwrite each other.
          torch.save(generator.state_dict(), "./drive/MyDrive/ColabNotebooks/ESRGAN-pytorch/saved_models/generator_%d.pth" % epoch)
          print("saved generator %d successfully!" %epoch)
          torch.save(discriminator.state_dict(), "./drive/MyDrive/ColabNotebooks/ESRGAN-pytorch/saved_models/discriminator_%d.pth" %epoch)
          print("saved discriminator %d successfully!" %epoch)

#Test

In [None]:
!pip install git+https://github.com/olivesgatech/dippykit.git

In [24]:
#Utils

from dippykit.metrics import PSNR as psnr
from dippykit.metrics import SSIM as ssim
import numpy as np
import os
import torch
from collections import OrderedDict

import time
class Timer:
    """Tiny wall-clock stopwatch: ``s()`` restarts it, ``t()`` reads it."""

    def __init__(self):
        # The stopwatch starts running as soon as it is created.
        now = time.time()
        self.v = now

    def s(self):
        """Restart the stopwatch from the current time."""
        now = time.time()
        self.v = now

    def t(self):
        """Return the seconds elapsed since construction or the last ``s()``."""
        now = time.time()
        return now - self.v


def time_text(t):
    """Format a duration in seconds as hours, minutes or seconds (one decimal)."""
    # Largest unit first; the first unit the duration reaches is used.
    for unit_seconds, suffix in ((3600, 'h'), (60, 'm')):
        if t >= unit_seconds:
            return '{:.1f}'.format(t / unit_seconds) + suffix
    return '{:.1f}s'.format(t)
def compute_psnr(im1, im2):
    """Return the PSNR between ``im1`` and ``im2`` as computed by dippykit."""
    return psnr(im1, im2)


def compute_ssim(im1, im2):
    """Return the SSIM score between ``im1`` and ``im2`` via dippykit.

    dippykit's ``SSIM`` is called with a Gaussian window (sigma 1.5) and
    automatic downsampling; only the first element of its result is returned.
    """
    result = ssim(
        im1,
        im2,
        use_gaussian_window=True,
        sigma=1.5,
        auto_downsample=True,
    )
    return result[0]


def shave(im, border):
    """Crop ``border`` pixels from every side of the first two axes of ``im``.

    Args:
        im: array indexed as ``(rows, cols, ...)``.
        border: non-negative number of pixels to remove on each side.

    Returns:
        The cropped view of ``im``; ``im`` itself when ``border`` is 0.

    Raises:
        ValueError: if ``border`` is negative.
    """
    if border < 0:
        raise ValueError("border must be non-negative, got {}".format(border))
    if border == 0:
        # im[0:-0] is im[0:0], i.e. an empty array, so zero needs its own path.
        return im
    return im[border:-border, border:-border, ...]


def modcrop(im, modulo):
    """Crop ``im`` from the bottom/right so both leading sizes divide by ``modulo``."""
    rows, cols = im.shape[0], im.shape[1]
    # Dropping the remainder leaves the largest multiple of `modulo` per axis.
    return im[:rows - rows % modulo, :cols - cols % modulo, ...]


def get_list(path, ext):
    """Return ``path``-joined names of the entries in ``path`` ending with ``ext``.

    The order is whatever ``os.listdir`` yields.
    """
    wanted = (name for name in os.listdir(path) if name.endswith(ext))
    return [os.path.join(path, name) for name in wanted]


def convert_shape(img):
    """Scale a channel-first array by 255 and return it channel-last as uint8."""
    scaled = np.round(img * 255.0)
    channel_last = np.transpose(scaled, (1, 2, 0))
    # Clip before the cast so out-of-range values saturate instead of wrapping.
    return np.clip(channel_last, 0, 255).astype(np.uint8)


def quantize(img):
    """Clip ``img`` to [0, 255], round to the nearest integer and cast to uint8."""
    bounded = np.clip(img, 0, 255)
    return np.round(bounded).astype(np.uint8)


def tensor2np(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert a channel-first tensor into a channel-last NumPy image.

    Args:
        tensor: tensor indexed as ``(C, H, W)``.
        out_type: dtype of the returned array. For ``np.uint8`` the values
            are scaled to [0, 255] and rounded; otherwise they stay in [0, 1].
        min_max: ``(low, high)`` range the tensor is clamped to before being
            rescaled to [0, 1].

    Returns:
        Array of shape ``(H, W, C)`` and dtype ``out_type``.
    """
    # Out-of-place clamp: .float() and .cpu() return the very same tensor when
    # it is already float32 on the CPU, so the former clamp_ silently modified
    # the caller's data.
    tensor = tensor.float().cpu().clamp(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0, 1]
    img_np = np.transpose(tensor.numpy(), (1, 2, 0))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()

    return img_np.astype(out_type)

def convert2np(tensor):
    """Turn an image tensor into a channel-last uint8 NumPy array.

    Values are multiplied by 255, clamped to [0, 255] and truncated to bytes;
    all size-1 axes are squeezed away before moving axis 0 to the end.
    """
    scaled = tensor.cpu().mul(255)
    as_bytes = scaled.clamp(0, 255).byte()
    channel_last = as_bytes.squeeze().permute(1, 2, 0)
    return channel_last.numpy()


def adjust_learning_rate(optimizer, epoch, step_size, lr_init, gamma):
    """Apply a step schedule: ``lr_init * gamma ** (epoch // step_size)``.

    The resulting rate is written into every entry of ``optimizer.param_groups``.
    """
    completed_steps = epoch // step_size
    new_lr = lr_init * gamma ** completed_steps
    for group in optimizer.param_groups:
        group['lr'] = new_lr

def load_state_dict(path):
    """Load a checkpoint and strip a leading ``module.`` from its keys.

    Args:
        path: file previously written with ``torch.save(state_dict, path)``.

    Returns:
        An ``OrderedDict`` with the same values and key order, where a leading
        ``"module."`` (the 7-character prefix the original code removed) has
        been dropped from every key that starts with it.
    """
    # Load onto the CPU so a checkpoint saved on a GPU can also be read on a
    # CPU-only machine; Module.load_state_dict copies the values to wherever
    # the model's parameters live.
    state_dict = torch.load(path, map_location='cpu')

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Only strip a genuine leading prefix. The former `'module' in key`
        # test chopped 7 characters off any key merely containing "module".
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value
    return new_state_dict

In [None]:
import argparse
import torch
import os
import numpy as np
import skimage.color as sc
import cv2

# Testing settings

def _build_test_parser():
    """Create the argument parser holding the evaluation configuration."""
    p = argparse.ArgumentParser(description='ESRGAN')
    # Data and output locations.
    p.add_argument(
        "--test_hr_folder",
        type=str,
        default='./drive/MyDrive/datasets/Urban100/',
        help='the folder of the target images',
    )
    p.add_argument(
        "--test_lr_folder",
        type=str,
        default='./drive/MyDrive/datasets/Urban100/LRbicx4/',
        help='the folder of the input images',
    )
    p.add_argument(
        "--output_folder",
        type=str,
        default='./drive/MyDrive/ColabNotebooks/ESRGAN-pytorch/results/Urban100/x4',
    )
    p.add_argument(
        "--checkpoint",
        type=str,
        default='./drive/MyDrive/ColabNotebooks/ESRGAN-pytorch/saved_models/generator_199.pth',
        help='checkpoint folder to use',
    )
    # Run-time switches; both flags default to True, so passing them changes nothing.
    p.add_argument('--cuda', action='store_true', default=True, help='use cuda')
    p.add_argument("--upscale_factor", type=int, default=4, help='upscaling factor')
    p.add_argument(
        "--is_y",
        action='store_true',
        default=True,
        help='evaluate on y channel, if False evaluate on RGB channels',
    )
    return p


parser = _build_test_parser()
# Notebook usage: ignore sys.argv and take every default.
opt = parser.parse_args(args=[])

#print(opt)

def forward_chop(model, x, shave=10, min_size=60000, scale=4):
    """Run ``model`` on four overlapping quadrants of ``x`` and stitch the result.

    The input is split into four corner patches that overlap by ``shave``
    pixels around the centre lines. Patches whose area is below ``min_size``
    are passed to ``model`` directly; larger ones are chopped recursively.

    Args:
        model: callable mapping ``(B, C, h, w)`` to ``(B, C, scale*h, scale*w)``.
        x: input batch of shape ``(B, C, H, W)``.
        shave: overlap in input pixels added to every half-size patch.
        min_size: patch area (in input pixels) below which the model is called.
        scale: spatial upscaling factor of ``model`` (previously fixed to 4).

    Returns:
        Tensor of shape ``(B, C, scale*H, scale*W)`` assembled from the
        non-overlapping part of each upscaled patch.
    """
    b, c, h, w = x.size()
    h_half, w_half = h // 2, w // 2
    # A patch can never be larger than the image itself. Without the clamp,
    # small inputs produced patches smaller than assumed and the stitching
    # indices below ran out of range.
    h_size, w_size = min(h_half + shave, h), min(w_half + shave, w)
    lr_list = [
        x[:, :, 0:h_size, 0:w_size],
        x[:, :, 0:h_size, (w - w_size):w],
        x[:, :, (h - h_size):h, 0:w_size],
        x[:, :, (h - h_size):h, (w - w_size):w],
    ]

    # Recursing only helps while the patches are strictly smaller than x;
    # otherwise process them directly so the recursion always terminates.
    cannot_shrink = h_size >= h and w_size >= w
    if w_size * h_size < min_size or cannot_shrink:
        sr_list = [model(patch) for patch in lr_list]
    else:
        sr_list = [
            forward_chop(model, patch, shave=shave, min_size=min_size, scale=scale)
            for patch in lr_list
        ]

    # Switch every coordinate to output (upscaled) resolution.
    h, w = scale * h, scale * w
    h_half, w_half = scale * h_half, scale * w_half
    h_size, w_size = scale * h_size, scale * w_size

    # Every element is overwritten below, so no initialisation is needed.
    output = x.new_empty((b, c, h, w))
    # Each quadrant of the output comes from the matching corner patch; for
    # the right/bottom patches the overlapping leading part is skipped.
    output[:, :, 0:h_half, 0:w_half] \
        = sr_list[0][:, :, 0:h_half, 0:w_half]
    output[:, :, 0:h_half, w_half:w] \
        = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
    output[:, :, h_half:h, 0:w_half] \
        = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
    output[:, :, h_half:h, w_half:w] \
        = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

    return output

# Evaluation script: super-resolve every image of the benchmark folder with the
# trained generator, save the result and report mean PSNR / SSIM / run time.
import time

# Use CUDA only when it was requested *and* is actually present, so the cell
# also runs on a CPU-only runtime instead of failing on the first CUDA call.
cuda = opt.cuda and torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')

filepath = opt.test_hr_folder

# The listed benchmark sets are read as PNG, every other folder as BMP.
# The folder name is the path component before the trailing slash.
if filepath.split('/')[-2] in ('Set5', 'Set14', 'BSDS100', 'Urban100'):
    ext = '.png'
else:
    ext = '.bmp'

filelist = get_list(filepath, ext=ext)
psnr_list = np.zeros(len(filelist))
ssim_list = np.zeros(len(filelist))
time_list = np.zeros(len(filelist))

model = GeneratorRRDB(channels=3, filters=64, num_res_blocks=23).to(device)
model_dict = load_state_dict(opt.checkpoint)
# NOTE(review): strict=False silently ignores missing/unexpected keys.
model.load_state_dict(model_dict, strict=False)#True)
model.eval()

if cuda:
    # CUDA events time the asynchronous GPU work; they only exist with CUDA.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

# Create the output directory once instead of checking it for every image.
os.makedirs(opt.output_folder, exist_ok=True)

for i, imname in enumerate(filelist):
    # Strip directory and extension with os.path so names containing extra
    # dots are not truncated at the first '.'.
    base_name = os.path.splitext(os.path.basename(imname))[0]
    lr_path = os.path.join(opt.test_lr_folder, base_name + ext)

    # cv2.imread returns None (rather than raising) for unreadable files.
    im_gt = cv2.imread(imname, cv2.IMREAD_COLOR)
    if im_gt is None:
        raise FileNotFoundError("cannot read ground-truth image: " + imname)
    im_l = cv2.imread(lr_path, cv2.IMREAD_COLOR)
    if im_l is None:
        raise FileNotFoundError("cannot read low-resolution image: " + lr_path)

    # IMREAD_COLOR always yields three channels, so the former grayscale
    # expansion branch could never run and has been dropped.
    im_gt = modcrop(im_gt[:, :, [2, 1, 0]], opt.upscale_factor)  # BGR to RGB
    im_l = im_l[:, :, [2, 1, 0]]  # BGR to RGB

    # NOTE(review): ImageDataset in this file normalizes training images with
    # the ImageNet mean/std, whereas the input here is raw [0, 1] pixels and
    # the output is not denormalized - confirm this matches the checkpoint.
    im_input = np.transpose(im_l / 255.0, (2, 0, 1))[np.newaxis, ...]
    im_input = torch.from_numpy(im_input).float().to(device)

    with torch.no_grad():
        if cuda:
            start.record()
            out = forward_chop(model, im_input) #model(im_input)
            end.record()
            torch.cuda.synchronize()
            time_list[i] = start.elapsed_time(end)  # milliseconds
        else:
            t0 = time.perf_counter()
            out = forward_chop(model, im_input)
            time_list[i] = (time.perf_counter() - t0) * 1000.0  # milliseconds

    out_img = tensor2np(out.detach()[0])

    # Ignore a border of `upscale_factor` pixels when computing the metrics.
    crop_size = opt.upscale_factor
    cropped_sr_img = shave(out_img, crop_size)
    cropped_gt_img = shave(im_gt, crop_size)

    if opt.is_y:
        # Compare on the first (Y) channel of the YCbCr representation.
        im_label = quantize(sc.rgb2ycbcr(cropped_gt_img)[:, :, 0])
        im_pre = quantize(sc.rgb2ycbcr(cropped_sr_img)[:, :, 0])
        if im_label.shape != im_pre.shape:
            # The metrics need equal shapes; cv2.resize takes (width, height).
            im_pre = cv2.resize(im_pre, (int(im_label.shape[1]), int(im_label.shape[0])))
    else:
        im_label = cropped_gt_img
        im_pre = cropped_sr_img

    psnr_list[i] = compute_psnr(im_pre, im_label)
    ssim_list[i] = compute_ssim(im_pre, im_label)

    output_path = os.path.join(opt.output_folder, base_name + 'x' + str(opt.upscale_factor) + '.png')
    cv2.imwrite(output_path, out_img[:, :, [2, 1, 0]])  # RGB back to BGR for OpenCV

print("Mean PSNR: {}, SSIM: {}, TIME: {} ms".format(np.mean(psnr_list), np.mean(ssim_list), np.mean(time_list)))