In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
from torch.autograd import grad
import numpy as np
import os
import time

# Base feature-map width; G and D size each conv layer as a multiple of it (x1, x2, x4, x8).
basicChannels: int = 64
# Notebook-global Trainer handle; the launch cell only builds and starts a Trainer while this is None.
trainer = None
# 生成器
class G(nn.Module):
    def __init__(self):
        super().__init__()

        self.module = nn.Sequential(
            # size 1 x 1
            nn.ConvTranspose2d(100,basicChannels*8,4,1,0,bias=False),
            nn.ReLU(),
            # size 4 x 4
            nn.ConvTranspose2d(basicChannels*8,basicChannels*4,4,2,1,bias=False),
            nn.ReLU(),
            # size 8 x 8
            nn.ConvTranspose2d(basicChannels*4,basicChannels*2,4,2,1,bias=False),
            nn.ReLU(),
            # size 16 x 16
            nn.ConvTranspose2d(basicChannels*2,basicChannels,4,2,1,bias=False),
            nn.ReLU(),
            # size 32 x 32
            # make G more powerful
            nn.Conv2d(basicChannels,basicChannels,3,1,1,bias=False),
            nn.ReLU(),
            # size 32 x 32
            nn.ConvTranspose2d(basicChannels,3,4,2,1,bias=False),
            nn.Tanh()
            # size 64 x 64
        )

    def forward(self,input):
        return self.module(input)

# 判别器
class D(nn.Module):
    def __init__(self):
        super().__init__()

        self.module = nn.Sequential(
            # size 64 x 64
            nn.Conv2d(3,basicChannels,4,2,1,bias=False),
            nn.LeakyReLU(),
            # size 32 x 32
            nn.Conv2d(basicChannels,basicChannels*2,4,2,1,bias=False),
            nn.LeakyReLU(),
            # size 16 x 16
            nn.Conv2d(basicChannels*2,basicChannels*4,4,2,1,bias=False),
            nn.LeakyReLU(),
            # size 8 x 8
            nn.Conv2d(basicChannels*4,1,8,1,0,bias=False),
            # size (batch x 1) x 1 x 1
        )

    def forward(self,input):
        return self.module(input)
# 训练器
class Trainer:
    #定义模型变量
    g=None
    d=None
    # 收集损失
    dLoss = []
    gLoss = []
    # 开始时间
    startTime = time.time()
    #当前训练轮次
    thisEpoch=0
    
    # 若指定modelIndex，则从第modelIndex个模型初始化，否则进行重置初始化
    def __init__(self):
        # 初始化一些目录
        self.genImgsPath = "%s/generatorImgs"%self.ROOT
        self.modelPath = "%s/models"%self.ROOT
        if not os.path.isdir(self.genImgsPath):
            os.makedirs(self.genImgsPath)
        if not os.path.isdir(self.modelPath):
            os.makedirs(self.modelPath)
        # 生成网络的输入噪声
        self.GNoise = torch.Tensor(self.batch_size,100,1,1).cuda()
        # 训练判别器时，GD_alpha*reaImgs+(1-GD_alpha)*genImgs
        self.GD_alpha = torch.Tensor(self.batch_size,1,1,1).cuda()
        # 加载训练数据
        faces = np.load('%s/../Face[7000x3x64x64]_Float32_[-1|1].npy'%self.ROOT)
        faces = torch.from_numpy(faces).float().cuda()
        tensorDataset = torch.utils.data.TensorDataset(faces)
        self.dataLoader = torch.utils.data.DataLoader(tensorDataset,batch_size=self.batch_size,shuffle=True,drop_last=True)
        
    def __assertInit(self,modelIndex=None):
        #如果指定了模型编号，则从已有模型初始化，否则进行重置初始化
        if None==modelIndex and None==self.g and None==self.d:
            print("------------ Init ------------")
            print("init original model...")
            self.__initGAN()
            print("init original model completed.")
        elif type(modelIndex)==int:
            print("------------ Init ------------")
            print("init from %dth model..."%modelIndex)
            self.__initFromModel(modelIndex)
            print("init from %dth model completed."%modelIndex)
            
    
    # 初始化网络，定义优化器和损失函数
    def __initGAN(self):
        self.g=G().cuda()
        self.d=D().cuda()
        self.g_optim = optim.Adam(self.g.parameters())
        self.d_optim = optim.Adam(self.d.parameters())
        
    # 从已有模型开始训练
    def __initFromModel(self,modelIndex):
        self.g = torch.load("%s/g_epoch%d.pkl"%(self.modelPath,modelIndex))
        self.d = torch.load("%s/d_epoch%d.pkl"%(self.modelPath,modelIndex))
        self.g_optim = optim.Adam(self.g.parameters())
        self.d_optim = optim.Adam(self.d.parameters())
        
    def __trainDiscriminator(self,realImgs):
        for p in self.g.parameters():
            p.requires_grad=False
        for p in self.d.parameters():
            p.requires_grad=True
        # 生成器的输入噪声
        self.GNoise.normal_()  #高斯分布
        genImgs = self.g(self.GNoise)
        d_out_gen = self.d(genImgs)
        d_out_real = self.d(realImgs)
        loss = d_out_gen.mean() - d_out_real.mean()
        # 融合真假样本
        self.GD_alpha.uniform_()  # 均匀分布
        mergeImgs = self.GD_alpha*realImgs + (1-self.GD_alpha)*genImgs
        mergeImgs.requires_grad=True
        d_out_merge = self.d(mergeImgs)
        merge_grad = grad(outputs=d_out_merge.sum(),inputs=mergeImgs,create_graph=True)[0]
        grad_penalty = ((merge_grad.view(merge_grad.size(0), -1).norm(2, dim=1) - 1)**2).mean()
        # 最终的损失
        loss += self.GP_lambda * grad_penalty
        # 更新判别器的权重
        self.d_optim.zero_grad()
        loss.backward()
        self.d_optim.step()
        # 返回损失
        return loss.item()

    def __trainGenerator(self):
        for p in self.g.parameters():
            p.requires_grad=True
        for p in self.d.parameters():
            p.requires_grad=False
        # 生成图片
        self.GNoise.normal_()
        genImgs = self.g(self.GNoise)
        d_out = self.d(genImgs)
        # 最终损失
        loss = -d_out.mean()
        # 更新生成器权重
        self.g_optim.zero_grad()
        loss.backward()
        self.g_optim.step()
        # 返回损失和一张生成图
        return loss.item(),genImgs[0]

    def __printLossAndSaveImg(self,epoch,genImg):
        # 打印日志
        log = "epoch= %d ,times=%.2f ,dLoss = %.4f ,gLoss = %.4f"%(epoch,time.time() - self.startTime,sum(self.dLoss)/len(self.dLoss),sum(self.gLoss)/len(self.gLoss))
        print(log)
        self.dLoss = []
        self.gLoss = []
        # 将日志写入文件
        with open("%s/logs.txt"%self.ROOT,"a") as f:
            f.write(log+"\n")
        # 保存生成图片到磁盘
        utils.save_image(genImg,"%s/%d.jpg"%(self.genImgsPath,epoch),normalize=True,range=(-1,1))
        self.startTime = time.time()
    
    #开始训练
    def start(self,modelIndex=None,offset=None):
        #判断如何初始化模型
        self.__assertInit(modelIndex)
        print("------------ Training ------------")
        print("training process start...")
        #如果指定模型，则将偏移量与模型对齐
        if modelIndex!=None:
            offset = modelIndex
        if offset==None:
            offset = self.thisEpoch
        print("training offset %d"%offset)
        print("----------------------------------")
        #开始循环训练
        for epoch in range(self.epochs):
            # ======从中间轮开始训练=======
            self.thisEpoch=epoch=epoch+1+offset
            # ========== End ===========
            for _,realImgs in enumerate(self.dataLoader):
                # 训练判别器
                d_loss = self.__trainDiscriminator(realImgs[0])
                self.dLoss.append(d_loss)
                # 训练生成器
                g_loss,genImg = self.__trainGenerator()
                self.gLoss.append(g_loss)
            # 每训练5轮，输出损失情况，保存一份生成的图片
            if epoch%5 == 0:
                self.__printLossAndSaveImg(epoch,genImg)
            # 每训练50轮，保存一次模型 
            if epoch%50 == 0:
                torch.save(self.g,"%s/g_epoch%d.pkl"%(self.modelPath,epoch))
                torch.save(self.d,"%s/d_epoch%d.pkl"%(self.modelPath,epoch))
        print("training process completed.")
                
    #=========配置一些的参数============
    #根目录
    ROOT="./"
    # 训练轮数
    epochs = 30000
    # 批大小
    batch_size = 64
    # 对梯度惩罚的系数
    GP_lambda = 10 

In [None]:
if None==trainer:
    trainer = Trainer()
    trainer.start()
else:
    print("trainer is not None.")