<a href="https://colab.research.google.com/github/zhangzhihengcn/SRGAN-Pytorch/blob/main/SRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
import torchvision.transforms as transforms
import torch
from torch.utils.data import Dataset
import numpy as np
import os
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models.vgg import vgg16
from tqdm import tqdm
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# 图像处理操作，包括随机裁剪，转换张量
transform = transforms.Compose([transforms.RandomCrop(96),
                                transforms.ToTensor()])

path = './AnimeTest/'


class PreprocessDataset(Dataset):
    """预处理数据集类"""

    def __init__(self, imgPath=path, transforms=transform, ex=10):
        """初始化预处理数据集类"""
        self.transforms = transform

        for _, _, files in os.walk(imgPath):
            self.imgs = [imgPath + file for file in files] * ex

        np.random.shuffle(self.imgs)  # 随机打乱

    def __len__(self):
        """获取数据长度"""
        return len(self.imgs)

    def __getitem__(self, index):
        """获取数据"""
        tempImg = self.imgs[index]
        tempImg = Image.open(tempImg)

        sourceImg = self.transforms(tempImg)  # 对原始图像进行处理
        cropImg = torch.nn.MaxPool2d(4, stride=4)(sourceImg)
        return cropImg, sourceImg


path = './AnimeTest/'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH = 32
EPOCHS = 100

# 构建数据集
processDataset = PreprocessDataset(imgPath=path)
trainData = DataLoader(processDataset, batch_size=BATCH)

# 构造迭代器并取出其中一个样本
dataiter = iter(trainData)
testImgs, _ = dataiter.next()

testImgs = testImgs.to(device)  # testImgs的用处是为了可视化生成对抗的结果


class ResBlock(nn.Module):
    """残差模块"""

    def __init__(self, inChannals, outChannals):
        """初始化残差模块"""
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(inChannals, outChannals, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outChannals)
        self.conv2 = nn.Conv2d(outChannals, outChannals, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(outChannals)
        self.conv3 = nn.Conv2d(outChannals, outChannals, kernel_size=1, bias=False)
        self.relu = nn.PReLU()

    def forward(self, x):
        """前向传播过程"""
        resudial = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(x)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(x)
        out += resudial
        out = self.relu(out)
        return out


class Generator(nn.Module):
    """生成模型(4x)"""

    def __init__(self):
        """初始化模型配置"""
        super(Generator, self).__init__()
        # 卷积模块1
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4, padding_mode='reflect', stride=1)
        self.relu = nn.PReLU()
        # 残差模块
        self.resBlock = self._makeLayer_(ResBlock, 64, 64, 5)
        # 卷积模块2
        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.PReLU()

        # 子像素卷积
        self.convPos1 = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=2, padding_mode='reflect')
        self.pixelShuffler1 = nn.PixelShuffle(2)
        self.reluPos1 = nn.PReLU()

        self.convPos2 = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, padding_mode='reflect')
        self.pixelShuffler2 = nn.PixelShuffle(2)
        self.reluPos2 = nn.PReLU()

        self.finConv = nn.Conv2d(64, 3, kernel_size=9, stride=1)

    def _makeLayer_(self, block, inChannals, outChannals, blocks):
        """构建残差层"""
        layers = []
        layers.append(block(inChannals, outChannals))

        for i in range(1, blocks):
            layers.append(block(outChannals, outChannals))

        return nn.Sequential(*layers)

    def forward(self, x):
        """前向传播过程"""
        x = self.conv1(x)
        x = self.relu(x)
        residual = x
        out = self.resBlock(x)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.convPos1(out)
        out = self.pixelShuffler1(out)
        out = self.reluPos1(out)
        out = self.convPos2(out)
        out = self.pixelShuffler2(out)
        out = self.reluPos2(out)
        out = self.finConv(out)

        return out


class ConvBlock(nn.Module):
    """残差模块"""

    def __init__(self, inChannals, outChannals, stride=1):
        """初始化残差模块"""
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(inChannals, outChannals, kernel_size=3, stride=stride, padding=1, padding_mode='reflect',
                              bias=False)
        self.bn = nn.BatchNorm2d(outChannals)
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        """前向传播过程"""
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        return out


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, padding_mode='reflect')
        self.relu1 = nn.LeakyReLU()

        self.convBlock1 = ConvBlock(64, 64, stride=2)
        self.convBlock2 = ConvBlock(64, 128, stride=1)
        self.convBlock3 = ConvBlock(128, 128, stride=2)
        self.convBlock4 = ConvBlock(128, 256, stride=1)
        self.convBlock5 = ConvBlock(256, 256, stride=2)
        self.convBlock6 = ConvBlock(256, 512, stride=1)
        self.convBlock7 = ConvBlock(512, 512, stride=2)

        self.avePool = nn.AdaptiveAvgPool2d(1)
        self.conv2 = nn.Conv2d(512, 1024, kernel_size=1)
        self.relu2 = nn.LeakyReLU()
        self.conv3 = nn.Conv2d(1024, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)

        x = self.convBlock1(x)
        x = self.convBlock2(x)
        x = self.convBlock3(x)
        x = self.convBlock4(x)
        x = self.convBlock5(x)
        x = self.convBlock6(x)
        x = self.convBlock7(x)

        x = self.avePool(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.sigmoid(x)

        return x


# 构造模型
netD = Discriminator()
netG = Generator()
netD.to(device)
netG.to(device)

# 构造迭代器
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

# 构造损失函数
lossF = nn.MSELoss().to(device)

# 构造VGG损失中的网络模型
vgg = vgg16(pretrained=True).to(device)
lossNetwork = nn.Sequential(*list(vgg.features)[:31]).eval()
for param in lossNetwork.parameters():
    param.requires_grad = False  # 让VGG停止学习

for epoch in range(EPOCHS):
    netD.train()
    netG.train()
    processBar = tqdm(enumerate(trainData, 1))

    for i, (cropImg, sourceImg) in processBar:
        cropImg, sourceImg = cropImg.to(device), sourceImg.to(device)

        fakeImg = netG(cropImg).to(device)

        # 迭代辨别器网络
        netD.zero_grad()
        realOut = netD(sourceImg).mean()
        fakeOut = netD(fakeImg).mean()
        dLoss = 1 - realOut + fakeOut
        dLoss.backward(retain_graph=True)

        # 迭代生成器网络
        netG.zero_grad()
        gLossSR = lossF(fakeImg, sourceImg)
        gLossGAN = 0.001 * torch.mean(1 - fakeOut)
        gLossVGG = 0.006 * lossF(lossNetwork(fakeImg), lossNetwork(sourceImg))
        gLoss = gLossSR + gLossGAN + gLossVGG
        gLoss.backward()

        optimizerD.step()
        optimizerG.step()

        # 数据可视化
        processBar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
            epoch, EPOCHS, dLoss.item(), gLoss.item(), realOut.item(), fakeOut.item()))

    # 将文件输出到目录中
    with torch.no_grad():
        fig = plt.figure(figsize=(10, 10))
        plt.axis("off")
        fakeImgs = netG(testImgs).detach().cpu()
        plt.imshow(np.transpose(vutils.make_grid(fakeImgs, padding=2, normalize=True), (1, 2, 0)), animated=True)
        plt.savefig('./Img/Result_epoch % 05d.jpg' % epoch, bbox_inches='tight', pad_inches=0)
        print('[INFO] Image saved successfully!')

    # 保存模型路径文件
    torch.save(netG.state_dict(), 'model/netG_epoch_%d_%d.pth' % (4, epoch))
    torch.save(netD.state_dict(), 'model/netD_epoch_%d_%d.pth' % (4, epoch))


AttributeError: ignored