# SRGAN

In [1]:
import torch
from torch import nn

class ResidualBlock(nn.Module):
    """SRGAN residual block computing ``x + (Conv-BN-PReLU-Conv-BN)(x)``.

    Args:
        input_channel: number of channels of the input tensor.
        output_channel: number of channels produced by both convolutions.
        kernel_size, stride, padding: forwarded to both ``nn.Conv2d`` layers.

    The identity skip ``x0 + x1`` only works when the body preserves the input
    shape, i.e. ``input_channel == output_channel``, ``stride == 1`` and, for
    odd kernels, ``padding == kernel_size // 2``. The defaults (64, 64, 3, 1, 1)
    satisfy this.
    """
    def __init__(self, input_channel=64, output_channel=64, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Both convs honour the `padding` argument. Hard-coding padding=1
        # would break the residual add for any kernel other than 3x3.
        self.layer = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, kernel_size, stride, padding=padding),
            nn.BatchNorm2d(output_channel),
            nn.PReLU(),
            nn.Conv2d(output_channel, output_channel, kernel_size, stride, padding=padding),
            nn.BatchNorm2d(output_channel)
        )

    def forward(self, x0):
        x1 = self.layer(x0)
        return x0 + x1  # identity skip connection
    
class Generator(nn.Module):
    '''SRGAN generator that upscales a 3-channel input by a fixed factor of 4.

    Architecture:
      layer_1: 9x9 conv + PReLU
      layer_2: residual trunk of ``num_residual_blocks`` ResidualBlocks
      layer_3: 3x3 conv + BN, then a long skip connection back to layer_1's output
      layer_4: two sub-pixel convolution stages, each upsampling by 2
      layer_5: 9x9 conv down to 3 channels
    The output is mapped into [0, 1] with ``(tanh(x) + 1) / 2``.

    Args:
        num_residual_blocks: number of residual blocks in the trunk. This is
            B in the original paper, which uses B=16 (the default).
    '''
    def __init__(self, num_residual_blocks=16):
        super(Generator, self).__init__()
        self.layer_1 = nn.Sequential(
            nn.Conv2d(3, 64, 9, stride=1, padding=4),
            nn.PReLU()
        )
        # Same child indices (layer_2.0 ... layer_2.15 by default) as writing the
        # blocks out by hand, so existing state_dicts still load.
        self.layer_2 = nn.Sequential(
            *[ResidualBlock() for _ in range(num_residual_blocks)]
        )
        self.layer_3 = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, padding=1),
            nn.BatchNorm2d(64)
        )
        self.layer_4 = nn.Sequential(
            # Sub-pixel convolution. The conv emits C*r^2 = 64*2^2 = 256 channels,
            # and PixelShuffle(2) rearranges (N, C*r^2, H, W) -> (N, C, H*r, W*r).
            nn.Conv2d(64, 256, 3, 1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, 3, 1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )
        self.layer_5 = nn.Conv2d(64, 3, 9, 1, padding=4)

    def forward(self, x):
        x0 = self.layer_1(x)
        x = self.layer_2(x0)
        x = self.layer_3(x)
        x = x + x0  # long skip connection around the residual trunk
        x = self.layer_4(x)
        x = self.layer_5(x)
        return (torch.tanh(x) + 1) / 2  # map tanh's (-1, 1) range onto (0, 1)
    
class Discriminator(nn.Module):
    """SRGAN discriminator mapping 3-channel images to a real/fake probability.

    The network is built as follows:
      - A 3x3 conv (3 -> 64) + LeakyReLU(0.2).
      - Seven Conv-BN-LeakyReLU(0.2) stages. Strided convolutions are used to
        reduce the image resolution each time the number of features is doubled.
      - Global average pooling.
      - Two 1x1 convs acting as dense layers (512 -> 1024 -> 1).
      - A Sigmoid.

    The output has shape ``(N, 1, 1, 1)``. Adaptive pooling lets it accept any
    input resolution.
    """

    def __init__(self):
        super().__init__()
        modules = [nn.Conv2d(3, 64, 3, 1, padding=1), nn.LeakyReLU(0.2)]

        # (in_channels, out_channels, stride) of each Conv-BN-LeakyReLU stage.
        stage_specs = (
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        )
        for c_in, c_out, step in stage_specs:
            modules.extend([
                nn.Conv2d(c_in, c_out, 3, step, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])

        # Classifier head: (N, 512, H', W') -> (N, 512, 1, 1) -> (N, 1, 1, 1).
        modules.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
            nn.Sigmoid(),
        ])
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        probability = self.net(x)
        return probability

## 损失函数 (Loss functions)

In [2]:
import torch
from torch import nn
from torchvision import models

class GeneratorLoss(nn.Module):
    """Weighted perceptual loss for the SRGAN generator.

        loss = image_weight       * MSE(img_sr, img_hr)
             + content_weight     * MSE(vgg(img_sr), vgg(img_hr))
             + adversarial_weight * mean(1 - D_out)

    The content term compares feature maps from a frozen, pretrained VGG16
    (the first 31 modules of ``vgg16().features``).
    NOTE(review): the SRGAN paper uses VGG19, and the images are fed to VGG
    without ImageNet mean/std normalisation. Changing either would change the
    loss values.

    Args:
        image_weight: weight of the pixel-wise MSE term (not in the paper).
        content_weight: weight of the VGG feature-map MSE term.
        adversarial_weight: weight of the adversarial term.
    """
    def __init__(self, image_weight=1.0, content_weight=6e-3, adversarial_weight=1e-3):
        super(GeneratorLoss, self).__init__()
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # TODO: the paper's content loss is built on VGG19; this uses VGG16.
        vgg = models.vgg16(pretrained=True)
        vgg = nn.Sequential(*list(vgg.features)[:31]).eval()
        # VGG is a fixed loss network: freeze it so its weights are never trained.
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        self.mse_loss = nn.MSELoss().to(device)
        self.image_weight = image_weight
        self.content_weight = content_weight
        self.adversarial_weight = adversarial_weight

    def forward(self, D_out, img_sr, img_hr):
        """Return the scalar generator loss.

        Args:
            D_out: discriminator output D(G(I_LR)). It must NOT be detached,
                otherwise the adversarial term contributes no gradient to G.
            img_sr: generator output G(I_LR).
            img_hr: ground-truth high-resolution images.
        """
        # Adversarial term. The paper uses sum(-log D(G(I_LR))); here the log
        # is dropped and the batch mean of 1 - D is used instead.
        adversarial_loss = torch.mean(1 - D_out)
        # Content term: MSE between VGG feature maps of the SR and HR images.
        content_loss = self.mse_loss(self.vgg(img_sr), self.vgg(img_hr))
        # Pixel-wise MSE; this term is not part of the original paper's loss.
        image_loss = self.mse_loss(img_sr, img_hr)

        return (self.image_weight * image_loss
                + self.content_weight * content_loss
                + self.adversarial_weight * adversarial_loss)
    
# Sanity check: print the loss module's structure. Constructing GeneratorLoss
# builds the pretrained VGG16 feature extractor (pretrained=True).
print(GeneratorLoss())

GeneratorLoss(
  (vgg): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilat

In [3]:
# DIV2K data preprocessing (DIV2K数据预处理)
from os.path import join
from os import listdir
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize

def train_hr_transform(crop_size):
    """Return the HR pipeline.

    The pipeline takes a random ``crop_size`` x ``crop_size`` crop of a PIL
    image, then converts it to a float tensor.
    """
    pipeline_steps = [RandomCrop(crop_size), ToTensor()]
    return Compose(pipeline_steps)


def train_lr_transform(crop_size, scale):
    """Return the LR pipeline applied to an HR tensor crop.

    The pipeline does three things:
      1. Converts the tensor back to a PIL image.
      2. Bicubically resizes it so its shorter side is ``crop_size // scale``.
      3. Converts the result to a tensor again.
    """
    low_res_side = crop_size // scale
    to_pil = ToPILImage()
    shrink = Resize(low_res_side, interpolation=Image.Resampling.BICUBIC)
    to_tensor = ToTensor()
    return Compose([to_pil, shrink, to_tensor])

class TrainDatasetFromFolder(Dataset):
    """Training dataset yielding ``(lr_image, hr_image)`` tensor pairs.

    Each item is a random square crop of one HR image in ``dataset_dir``
    (``hr_image``), paired with its bicubic downscale by ``scale``
    (``lr_image``).

    Args:
        dataset_dir: folder containing the HR training images. Files without
            an image extension are skipped.
        crop_size: side length of the HR crop. It is rounded down to a
            multiple of ``scale``.
        scale: downscaling factor between the HR and LR images.
    """

    # Lower-case file extensions treated as images when scanning dataset_dir.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')

    def __init__(self, dataset_dir, crop_size, scale):
        super(TrainDatasetFromFolder, self).__init__()
        # Sorted for a reproducible file order. Stray non-image files
        # (e.g. .DS_Store) would otherwise crash Image.open.
        self.images_filenames = sorted(
            join(dataset_dir, name)
            for name in listdir(dataset_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # Round the crop down to a multiple of `scale`, so the HR crop is
        # exactly `scale` times the LR size. Otherwise upscaling the LR image
        # cannot reproduce the HR shape and the loss shapes mismatch.
        crop_size = crop_size - crop_size % scale
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, scale)

    def __getitem__(self, index):
        # convert('RGB') guarantees 3 channels, since grayscale, RGBA or
        # palette files would break the 3-channel networks. The `with`
        # statement closes the file handle.
        with Image.open(self.images_filenames[index]) as img:
            hr_image = self.hr_transform(img.convert('RGB'))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.images_filenames)
        

In [None]:
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import h5py
import numpy as np
from torch.utils.data import Dataset

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
netG = Generator().to(device)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator().to(device)
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
g_criterion = GeneratorLoss().to(device)
# Both networks use Adam with lr=1e-4 and betas=(0.9, 0.999).
optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.9, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.9, 0.999))

crop_size = 96  # HR patch size; LR patches are crop_size // scale
scale = 4       # must match the generator's fixed x4 upscaling
dataset_dir = '../dataSets/DIV2K_train_HR'
ds_train = TrainDatasetFromFolder(dataset_dir, crop_size, scale)
dl_train = DataLoader(ds_train, batch_size=64, shuffle=True, num_workers=4)

results = {'d_loss': [], 'g_loss': []}  # per-epoch mean losses
epochs = 100
epochs_bar = tqdm(range(epochs), total=epochs)
for epoch in epochs_bar:
    running_results = {'d_loss': [], 'g_loss': []}  # per-batch losses of this epoch
    netG.train()
    netD.train()

    for features, labels in dl_train:  # features: LR batch, labels: HR batch
        features = features.to(device)
        labels = labels.to(device)
        # One generator forward pass, shared by both updates below.
        I_SR = netG(features)

        # Discriminator step. It pushes D(I_HR) up and D(G(I_LR)) down, i.e.
        # the paper's max log D(I_HR) + log(1 - D(G(I_LR))) without the logs.
        # I_SR is detached so this step does not backpropagate into netG.
        loss_d = 1 - netD(labels).mean() + netD(I_SR.detach()).mean()
        optimizerD.zero_grad()
        loss_d.backward()
        optimizerD.step()

        # Generator step. D's output is deliberately NOT detached; detaching it
        # would cut the adversarial term off from netG entirely. The gradients
        # this also accumulates in netD are cleared by optimizerD.zero_grad()
        # before netD's next update.
        loss_g = g_criterion(netD(I_SR), I_SR, labels)
        optimizerG.zero_grad()
        loss_g.backward()
        optimizerG.step()

        running_results['d_loss'].append(loss_d.item())
        running_results['g_loss'].append(loss_g.item())

    d_loss = torch.tensor(running_results['d_loss']).mean().item()
    g_loss = torch.tensor(running_results['g_loss']).mean().item()
    epochs_bar.set_description(
        desc='d_loss: %.6f  g_loss: %.6f' % (d_loss, g_loss)
    )
    results['d_loss'].append(d_loss)
    results['g_loss'].append(g_loss)

plt.plot(range(epochs), results['d_loss'], label='d_loss')
plt.plot(range(epochs), results['g_loss'], label='g_loss')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
from PIL import Image
from torchvision.transforms import Compose, ToTensor, ToPILImage
scale = 4
# NOTE(review): these files are named "SRF_2", presumably a x2 benchmark pair,
# but netG upsamples x4. The SR output will therefore be twice the HR image's
# size. Confirm the intended test pair.
test_img = Image.open('img_005_SRF_2_LR.png')
origin = Image.open('img_005_SRF_2_HR.png')
# convert('RGB') guarantees the 3 input channels netG expects. ToTensor yields
# a [3, H, W] float tensor, and unsqueeze(0) adds the batch dimension.
lr_tensor = ToTensor()(test_img.convert('RGB')).unsqueeze(0).to(device)
netG.eval()  # BatchNorm uses its running statistics at inference time
with torch.no_grad():
    output = netG(lr_tensor)
# netG's output already lies in [0, 1], the range ToPILImage expects for float tensors.
output = ToPILImage()(output[0].cpu())
plt.figure(figsize=(12, 8))
plt.subplot(1, 3, 1)
plt.imshow(test_img)
plt.title('LR')
plt.subplot(1, 3, 2)
plt.imshow(output)
plt.title('SR')
plt.subplot(1, 3, 3)
plt.imshow(origin)
plt.title('HR')
plt.show()

In [None]:
from PIL import Image
from torchvision.transforms import Compose, ToTensor, ToPILImage
scale = 4
test_img = Image.open('butterfly.png')
# convert('RGB') guarantees the 3 input channels netG expects. ToTensor yields
# a [3, H, W] float tensor, and unsqueeze(0) adds the batch dimension.
lr_tensor = ToTensor()(test_img.convert('RGB')).unsqueeze(0).to(device)
netG.eval()  # BatchNorm uses its running statistics at inference time
with torch.no_grad():
    output = netG(lr_tensor)
# netG's output already lies in [0, 1], the range ToPILImage expects for float tensors.
output = ToPILImage()(output[0].cpu())
plt.figure(figsize=(12, 8))
plt.subplot(1, 2, 1)
plt.imshow(test_img)
plt.title('LR')
plt.subplot(1, 2, 2)
plt.imshow(output)
plt.title('SR')
output.save('./SR_image1.png')
plt.show()

In [7]:
# Persist only the generator's parameters (its state_dict), not the class itself.
# Reload with: netG = Generator(); netG.load_state_dict(torch.load('./generator.pkl'))
torch.save(netG.state_dict(), './generator.pkl')