In [1]:
# https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/esrgan

import os
import glob
import random
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## dataset

In [2]:
# pre-trained pytorch model의 normalization parameters

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def denormalize(tensors):
  """ Denormalizes image tensors using mean and std """
  for c in range(3):
    tensors[:, c].mul_(std[c]).add_(mean[c])
  return torch.clamp(tensors, 0, 255)


class ImageDataset(torch.utils.data.Dataset):
  def __init__(self, root, hr_shape):
    hr_height, hr_width = hr_shape
    # Transforms for low resolution images and high resolution images
    self.lr_transform = transforms.Compose(
        [transforms.Resize((hr_height // 4, hr_height // 4),
                           transforms.InterpolationMode("bicubic")),
         transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    
    self.hr_transform = transforms.Compose(
        [transforms.Resize((hr_height, hr_height),
                           transforms.InterpolationMode("bicubic")),
         transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    self.files = sorted(glob.glob(root + "/*.jpg"))

  def __getitem__(self, index):
      img = Image.open(self.files[index % len(self.files)])
      img_lr = self.lr_transform(img)
      img_hr = self.hr_transform(img)

      return {"lr": img_lr, "hr": img_hr}

  def __len__(self):
      return len(self.files)

## model

In [3]:
# Feature Extractor

class Feature_Extractor(nn.Module):
  def __init__(self):
    super(Feature_Extractor, self).__init__()
    vgg_model = vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
    self.vgg19_54 = nn.Sequential(*list(vgg_model.features.children())[:35])

  def forward(self, img):
    return self.vgg19_54(img)

In [4]:
# Dense Residual Block & RRDB

class Dense_Residual_Block(nn.Module):
  def __init__(self, filters, res_scale=0.2):
    super(Dense_Residual_Block, self).__init__()
    self.res_scale = res_scale

    def block(in_features, non_linearity=True):
      layers = [nn.Conv2d(in_features, filters,3,1,1,bias=True)]
      if non_linearity:
        layers += [nn.LeakyReLU()]
      return nn.Sequential(*layers)

    self.b1 = block(1*filters)
    self.b2 = block(2*filters)
    self.b3 = block(3*filters)
    self.b4 = block(4*filters)
    self.b5 = block(5*filters)
    self.blocks = [self.b1,self.b2,self.b3,self.b4,self.b5]
  
  def forward(self, x):
    inputs = x
    for b in self.blocks:
      out = b(inputs)
      inputs = torch.cat([inputs, out], dim=1)
    return out.mul(self.res_scale) + x

class RRDB(nn.Module):
  def __init__(self, filters, res_scale=0.2):
    super(RRDB, self).__init__()
    self.res_scale = res_scale
    self.dense_blocks = nn.Sequential(Dense_Residual_Block(filters),
                                    Dense_Residual_Block(filters),
                                    Dense_Residual_Block(filters))
  def forward(self, x):
    return self.dense_blocks(x).mul(self.res_scale)

In [5]:
# Generator & Discriminator

class GeneratorRRDB(nn.Module):
  def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
    super(GeneratorRRDB, self).__init__()

    # First layer
    self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
    # Residual blocks
    self.res_blocks = nn.Sequential(*[RRDB(filters) for _ in range(num_res_blocks)])
    # Second conv layer post residual blocks
    self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
    # Upsampling layers
    upsample_layers = []
    for _ in range(num_upsample):
        upsample_layers += [
            nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.PixelShuffle(upscale_factor=2)
        ]
    self.upsampling = nn.Sequential(*upsample_layers)
    # Final output block
    self.conv3 = nn.Sequential(
        nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
        nn.LeakyReLU(),
        nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1)
    )

  def forward(self, x):
      out1 = self.conv1(x)
      out = self.res_blocks(out1)
      out2 = self.conv2(out)
      out = torch.add(out1, out2)
      out = self.upsampling(out)
      out = self.conv3(out)
      return out


class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

## training

In [6]:
# parameters

epoch = 0 #epoch to start training from
n_epochs = 200 #number of epochs of training
dataset_name = "img_align_celeba" #name of the dataset
batch_size = 64 #size of the batches
lr = 0.0002 #adam: learning rate
b1 = 0.9 #adam: decay of first order momentum of gradient
b2 = 0.999 #adam: decay of first order momentum of gradient
decay_epoch = 100 #epoch from which to start lr decay
n_cpu = 2 #number of cpu threads to use during batch generation
hr_height = 256 #high res. image height
hr_width = 256 #high res. image width
channels = 3 #number of image channels
sample_interval = 100 #interval between saving image samples
checkpoint_interval = 5000 #batch interval between model checkpoints
residual_blocks = 23 #number of residual blocks in the generator
warmup_batches = 500 #number of batches with pixel-wise loss only
lambda_adv = 5e-3 #adversarial loss weight
lambda_pixel = 1e-2 #pixel-wise loss weight

In [7]:
# prepare training

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hr_shape = (hr_height, hr_width)

# Initialize generator and discriminator
generator = GeneratorRRDB(channels, filters=64, num_res_blocks=residual_blocks).to(device)
discriminator = Discriminator(input_shape=(channels, *hr_shape)).to(device)
feature_extractor = Feature_Extractor().to(device)

# Set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = nn.BCEWithLogitsLoss().to(device)
criterion_content = nn.L1Loss().to(device)
criterion_pixel = nn.L1Loss().to(device)

# if epoch != 0:
#     # Load pretrained models
#     generator.load_state_dict(torch.load("saved_models/generator_%d.pth" % epoch))
#     discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth" % epoch))

# optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

In [8]:
root = '/content/drive/MyDrive/Colab Notebooks/Datasets/CelebA/img_align_celeba'

dataloader = DataLoader(
    ImageDataset(root, hr_shape=hr_shape),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu)

ValueError: ignored

In [None]:
# imgs = next(iter(dataloader))
# imgs_lr = Variable(imgs["lr"].type(Tensor))
# gen_hr = generator(imgs_lr)

# save_path = '/content/drive/MyDrive/Colab Notebooks/코드 이론/model_save/ESRGAN'
# imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
# img_grid = denormalize(torch.cat((imgs_lr, gen_hr), -1))

# save_image(img_grid, save_path+"/images/test.png", nrow=1, normalize=False)

In [None]:
# training

save_path = '/content/drive/MyDrive/Colab Notebooks/코드 이론/model_save/ESRGAN'

for epoch in range(epoch, n_epochs):
  for i, imgs in enumerate(dataloader):

    batches_done = epoch * len(dataloader) + i

    # model input
    imgs_lr = Variable(imgs["lr"].type(Tensor))
    imgs_hr = Variable(imgs["hr"].type(Tensor))

    # Adversarial ground truths
    valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
    fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

    # ------------------
    #  Train Generators
    # ------------------

    optimizer_G.zero_grad()

    # Generate a high resolution image from low resolution input
    gen_hr = generator(imgs_lr)

    # Measure pixel-wise loss against ground truth
    loss_pixel = criterion_pixel(gen_hr, imgs_hr)

    if batches_done < warmup_batches:
        # Warm-up (pixel-wise loss only)
        loss_pixel.backward()
        optimizer_G.step()
        print(
            "[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]"
            % (epoch, n_epochs, i, len(dataloader), loss_pixel.item())
        )
        continue

    # Extract validity predictions from discriminator
    pred_real = discriminator(imgs_hr).detach()
    pred_fake = discriminator(gen_hr)

    # Adversarial loss (relativistic average GAN)
    loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

    # Content loss
    gen_features = feature_extractor(gen_hr)
    real_features = feature_extractor(imgs_hr).detach()
    loss_content = criterion_content(gen_features, real_features)

    # Total generator loss
    loss_G = loss_content + lambda_adv * loss_GAN + lambda_pixel * loss_pixel

    loss_G.backward()
    optimizer_G.step()

    # ---------------------
    #  Train Discriminator
    # ---------------------

    optimizer_D.zero_grad()

    pred_real = discriminator(imgs_hr)
    pred_fake = discriminator(gen_hr.detach())

    # Adversarial loss for real and fake images (relativistic average GAN)
    loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
    loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

    # Total loss
    loss_D = (loss_real + loss_fake) / 2

    loss_D.backward()
    optimizer_D.step()

    # --------------
    #  Log Progress
    # --------------
    if i%200==0:
      print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: %f, adv: %f, pixel: %f]"
        % (
            epoch,
            n_epochs,
            i,
            len(dataloader),
            loss_D.item(),
            loss_G.item(),
            loss_content.item(),
            loss_GAN.item(),
            loss_pixel.item(),
        )
    )
    
    if batches_done % sample_interval == 0:
        # Save image grid with upsampled inputs and ESRGAN outputs
        imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
        img_grid = denormalize(torch.cat((imgs_lr, gen_hr,imgs_hr), -1))
        save_image(img_grid, save_path+"/images/%d.png" % batches_done, nrow=1, normalize=False)

    if batches_done % checkpoint_interval == 0:
        # Save model checkpoints
        torch.save(generator.state_dict(), save_path+"/generator_%d.pth" % epoch)
        torch.save(discriminator.state_dict(), save_path+"/discriminator_%d.pth" %epoch)

## Note

https://arxiv.org/pdf/1902.06068.pdf  

생각해볼 것들
1. performance evaluation
 - PSNR, SSIM, LPIPS 등이 있는데 다른 measurement를 만들자
2. network design & learning strategies
 - 다른 모델들 mix or objective 추가
3. upsampling methods
 - cnn 기반의 sota들은 보통 upsampling으로 bicubic, deconv, sub-pixel 를 이용 다른걸 찾거나 만들어보기
4. unsupervised SR
 - 이건 뭘해야할지 잘 모르겠음... ill-posed 때문에 supervised를 하는데 비지도로 바꾸려면 새로운 방식이 필요할듯
