In [None]:
import os
import cv2
import math
import torch,random
import torch.nn as nn
import numpy as np
from functools import partial
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch import optim
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import torchvision
#用于半精度计算
from torch.cuda.amp import autocast as autocast
#用于ada
from diff_augment import DiffAugment

In [None]:
# ---------------------------------------------------------------------------
# Default run configuration
# ---------------------------------------------------------------------------

# --- data ---
dataroot = 'GT'              # dataset root folder
label = 'pre'                # sub-folder of `dataroot` holding the training images
batch_size = 32
image_size = 256             # images are resized / center-cropped to image_size x image_size
nc = 1                       # number of image channels
workers = 0                  # DataLoader workers; non-zero values may fail on Windows

# --- model capacity ---
nz = 100                     # length of the latent input vector
ngf = 512                    # generator feature-capacity parameter
ndf = 512                    # discriminator feature-capacity parameter
capacity_G = 4               # tried 4/3/2/1: trades training time against stability; confirm by testing
capacity_D = 2               # tried 4/3/2/1
sa_map_G = image_size // 2   # spatial size at which the generator inserts self-attention
sa_map_D = image_size // 2   # spatial size at which the discriminator inserts self-attention
dpout = True                 # Dropout2d in discriminator blocks, for better generalisation

# --- optimisation ---
num_epochs = 501             # to be finalised after testing, to save time
lrG = 0.0005                 # generator Adam learning rate
lrD = 0.0001                 # discriminator Adam learning rate
beta1, beta2 = 0.5, 0.999    # Adam betas (originally 0.5 / 0.999)

# --- hardware ---
ngpu = 1                     # number of GPUs to use
device = torch.device("cuda:0" if torch.cuda.is_available() and ngpu > 0 else "cpu")

In [None]:
# Dataset: every file in <root_dir>/<label> is one training image.
class MyData(Dataset):
    """Image-folder dataset returning ``(img_tensor, label)`` pairs.

    ``img_tensor`` is a 1 x image_size x image_size float tensor normalised to
    [-1, 1] (matching the generator's tanh output); ``label`` is the folder name.
    """

    def __init__(self, root_dir, label):
        self.root_dir = root_dir
        self.label = label
        self.path = os.path.join(self.root_dir, self.label)
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.img_path = sorted(os.listdir(self.path))
        # Build the pipeline once instead of on every __getitem__ call.
        # ToTensor (HWC uint8 -> CHW float in [0, 1]) + Normalize(0.5, 0.5) maps to
        # [-1, 1]; Grayscale then collapses the 3 channels into 1.
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Grayscale(num_output_channels=1),
        ])

    def __getitem__(self, idx):
        """Return ``(img_tensor, label)`` for the idx-th file (enables ``dataset[i]``).

        Raises:
            OSError: if the file cannot be decoded as an image.
        """
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = cv2.imread(img_item_path)
        if img is None:
            # cv2.imread reports unreadable / non-image files by returning None.
            raise OSError('Could not read image file: {}'.format(img_item_path))
        # cv2 decodes to BGR, but transforms.Grayscale weights the channels as RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return self.trans(img), self.label

    def __len__(self):
        return len(self.img_path)
# Build the training dataset (items are (1xHxW tensor, label) pairs) and a
# shuffled mini-batch loader over it.
dataset = MyData(dataroot, label)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

In [None]:
# Log every batch of the training set to TensorBoard (run directory ./logs) as a
# sanity check of the input pipeline before training.
writer = SummaryWriter('logs')
for step, (data, target) in enumerate(dataloader):
    # MyData normalises images to [-1, 1]; add_images expects float images in [0, 1].
    writer.add_images(tag='GPR_data', img_tensor=(data + 1) / 2, global_step=step, dataformats='NCHW')

In [None]:
## Inspect the dataset with matplotlib (tensors must be converted to numpy first).
def plot_images(img, grid_size=3):
    """Show the first ``grid_size**2`` samples of a dataset in a square grid.

    Args:
        img: indexable dataset whose items are ``(tensor, label)`` pairs, the tensor
            holding a single-channel image (e.g. the ``MyData`` dataset).
        grid_size: number of rows and columns of the grid.
    """
    fig = plt.figure(figsize=(8, 8))
    # suptitle avoids creating an extra full-figure axes as plt.title did.
    fig.suptitle('Training Images')
    # Start at sample 0 and never index past the end of a small dataset.
    for i in range(min(grid_size * grid_size, len(img))):
        ax = fig.add_subplot(grid_size, grid_size, i + 1)
        ax.imshow(img[i][0].numpy().squeeze(), cmap='Greys_r')
        ax.axis('off')
    plt.show()
plot_images(dataset)  # draw a 3x3 grid of training samples

In [None]:
# Self-attention module (optionally spectrally normalised).
class Xiong_Self_Attn(nn.Module):
    """Self-attention layer whose cost grows linearly with H*W, for large feature maps.

    Keys are softmax-normalised over spatial positions and queries over channels, so a
    small [C_bar, 2*C_bar] context matrix is built instead of an N x N attention map
    (N = H*W). The result is projected back to the input width and added to the input
    through a learnable gate ``gamma`` that starts at 0.

    Args:
        inChannels: channel count C of the input feature map.
        k: reduction factor; keys/queries use C_bar = C // k channels, values 2 * C_bar.
        sn: wrap every 1x1 convolution in spectral normalisation.
    """
    def __init__(self, inChannels, k=4, sn = True):
        super(Xiong_Self_Attn, self).__init__()
        c_bar = inChannels // k
        self.sn = sn
        wrap = nn.utils.spectral_norm if sn else (lambda layer: layer)
        # Creation order fixes the parameter order and the state_dict keys.
        self.key = wrap(nn.Conv2d(inChannels, c_bar, 1))
        self.query = wrap(nn.Conv2d(inChannels, c_bar, 1))
        self.value = wrap(nn.Conv2d(inChannels, 2 * c_bar, 1))
        self.reprojection = wrap(nn.Conv2d(2 * c_bar, inChannels, 1))
        self.gamma = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def forward(self, x):
        """Return ``gamma * attention(x) + x``; same shape as ``x`` ([B, C, H, W])."""
        b, _, h, w = x.size()
        n = h * w
        keys = F.softmax(self.key(x).view(b, -1, n), dim=2)       # [B, C_bar, N], over positions
        queries = F.softmax(self.query(x).view(b, -1, n), dim=1)  # [B, C_bar, N], over channels
        values = self.value(x).view(b, -1, n)                     # [B, 2*C_bar, N]
        context = torch.bmm(keys, values.permute(0, 2, 1))        # [B, C_bar, 2*C_bar]
        attended = torch.bmm(context.permute(0, 2, 1), queries)   # [B, 2*C_bar, N]
        projected = self.reprojection(attended.view(b, -1, h, w))  # back to [B, C, H, W]
        return self.gamma * projected + x

In [None]:
# G and D are built adaptively so their depth/width follow a few parameters.

# Generator G, step 1: the up-sampling block.
class GBlock(nn.Module):
    """Up-sampling block: ConvTranspose2d(k=4, s=2, p=1) -> norm -> activation.

    Doubles the spatial size and maps ``in_chan`` to ``out_chan`` channels.

    Args:
        in_chan: input channel count.
        out_chan: output channel count.
        which_bn: normalisation layer class, instantiated as ``which_bn(out_chan)``.
        activation: activation module applied last (the given instance itself is used).
        sn: wrap the transposed convolution in spectral normalisation (default True).
    """
    def __init__(self, in_chan, out_chan, which_bn = nn.BatchNorm2d, activation = nn.ReLU(True), sn = True):
        super(GBlock, self).__init__()
        self.which_bn = which_bn
        self.activation = activation
        upconv = nn.ConvTranspose2d(in_channels=in_chan, out_channels=out_chan,
                                    kernel_size=4, stride=2, padding=1, bias=False)
        self.conv = nn.utils.spectral_norm(upconv) if sn else upconv
        self.bn = which_bn(out_chan)

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))

In [None]:
# Generator G, step 2: initial 1x1 -> 4x4 layer, GBlocks with optional self-attention, final layer.
class Generator(nn.Module):
    """Generator mapping latent tensors to images in [-1, 1].

    Input [B, latent_dim, 1, 1]; output [B, last_dim, image_size, image_size].
    ``fist`` produces 4x4, ``int(log2(image_size) - 3)`` GBlocks double it up to
    image_size/2, and ``last`` performs the final doubling followed by tanh.

    Args:
        image_size: output resolution (the layer arithmetic assumes a power of two).
        latent_dim: channels of the latent input.
        last_dim: channels of the generated image.
        network_capacity: GBlock widths are ``network_capacity * 2**(i + 4)`` (largest
            first), each capped at ``fmap_max``; the 4x4 layer is twice the first
            GBlock's output width.
        attn_maps: spatial size at which ``Xiong_Self_Attn`` is inserted (after the GBlock
            producing that size). No attention layer is built when
            ``int(log2(attn_maps) - 3)`` is not a valid GBlock index.
        fmap_max: upper bound on the GBlock channel counts.
        snn: spectral normalisation on every transposed convolution.
        which_bn: normalisation layer class used inside the network.
        activation: activation module shared by the first layer and all GBlocks.
        ngpu: stored only; not used inside this class.
    """
    def __init__(self, image_size = image_size, latent_dim = nz, last_dim = nc , network_capacity = 4, attn_maps = 256, fmap_max = 512*8,
                 snn = True,which_bn = nn.BatchNorm2d, activation = nn.ReLU(True), ngpu = ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.image_size = image_size
        self.latent_dim = latent_dim
        self.last_dim = last_dim
        self.snn = snn
        self.activation = activation
        self.which_bn = which_bn

        # Number of GBlocks: they cover 4x4 -> image_size/2; the first and last layers
        # handle 1 -> 4 and image_size/2 -> image_size.
        self.num_layers = int(math.log2(image_size) - 3)
        # GBlock output widths, largest first, each capped at fmap_max.
        filters = [network_capacity * (2 ** (i + 4)) for i in range(self.num_layers)][::-1]
        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        # The 4x4 layer is twice as wide as the first GBlock output.
        init_channels = filters[0] * 2
        filters = [init_channels, *filters]
        # (in, out) channel pairs, one per GBlock.
        in_out_pairs = zip(filters[:-1], filters[1:])

        self.attn_maps = attn_maps
        # Index of the GBlock after which self-attention runs.
        self.attn_layers = int(math.log2(attn_maps) - 3)

        # First layer: latent 1x1 -> 4x4. The attribute keeps the name `fist` because it
        # is part of the state_dict keys of saved checkpoints.
        fist = []
        if self.snn:
            fist.append(nn.utils.spectral_norm(nn.ConvTranspose2d(in_channels=latent_dim, out_channels=filters[0],
                                                                  kernel_size=4, stride=1, padding=0, bias=False)))
        else:
            fist.append(nn.ConvTranspose2d(in_channels=latent_dim, out_channels=filters[0],
                                           kernel_size=4, stride=1, padding=0, bias=False))
        fist.append(self.which_bn(filters[0]))
        fist.append(self.activation)
        self.fist = nn.Sequential(*fist)

        # GBlocks up to and including the attention index go to blocks1, the rest to blocks2.
        self.blocks1 = nn.ModuleList([])
        self.blocks2 = nn.ModuleList([])
        for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
            if self.attn_layers == ind:
                self.attn = Xiong_Self_Attn(out_chan, sn=self.snn)
            block = GBlock(in_chan, out_chan, sn=self.snn, which_bn=which_bn, activation=activation)
            if ind > self.attn_layers:
                self.blocks2.append(block)
            else:
                self.blocks1.append(block)

        # Last layer: image_size/2 -> image_size, then tanh to land in [-1, 1].
        last = []
        if self.snn:
            last.append(nn.utils.spectral_norm(nn.ConvTranspose2d(in_channels=filters[-1], out_channels=self.last_dim,
                                                                  kernel_size=4, stride=2, padding=1, bias=False)))
        else:
            last.append(nn.ConvTranspose2d(in_channels=filters[-1], out_channels=self.last_dim,
                                           kernel_size=4, stride=2, padding=1, bias=False))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

    def forward(self, x):
        """Map latents [B, latent_dim, 1, 1] to images [B, last_dim, image_size, image_size]."""
        x = self.fist(x)
        for block in self.blocks1:
            x = block(x)
        # Apply attention only if it was actually built; gating on attn_maps alone
        # raised AttributeError for sizes that map to no GBlock (e.g. attn_maps <= 4).
        attn = getattr(self, 'attn', None)
        if attn is not None:
            x = attn(x)
        for block in self.blocks2:
            x = block(x)
        return self.last(x)

In [None]:
# Discriminator D, step 1: the down-sampling block.
class DBlock(nn.Module):
    """Down-sampling block: Conv2d(k=4, s=2, p=1) -> norm -> activation [-> Dropout2d(0.2)].

    Halves the spatial size and maps ``in_chan`` to ``out_chan`` channels.

    Args:
        in_chan: input channel count.
        out_chan: output channel count.
        which_bn: normalisation layer class, instantiated as ``which_bn(out_chan)``.
        activation: activation module (the given instance itself is used).
        sn: wrap the convolution in spectral normalisation (default True).
        dpout: append channel-wise dropout with p=0.2.
    """
    def __init__(self, in_chan, out_chan, which_bn = nn.BatchNorm2d, activation = nn.LeakyReLU(0.2, inplace=True), sn = True, dpout = False ):
        super(DBlock, self).__init__()
        self.which_bn = which_bn
        self.activation = activation
        self.dpout = dpout
        if dpout:
            self.dpout_layer = nn.Dropout2d(p=0.2)
        downconv = nn.Conv2d(in_channels=in_chan, out_channels=out_chan,
                             kernel_size=4, stride=2, padding=1, bias=False)
        self.conv = nn.utils.spectral_norm(downconv) if sn else downconv
        self.bn = which_bn(out_chan)

    def forward(self, x):
        out = self.activation(self.bn(self.conv(x)))
        return self.dpout_layer(out) if self.dpout else out

In [None]:
# Discriminator D, step 2: first stride-2 conv, DBlocks with optional self-attention, 4x4 -> 1x1 scoring conv.
class Discriminator(nn.Module):
    """Discriminator producing one unbounded score per image.

    Input [B, first_dim, image_size, image_size]; output [B, 1, 1, 1] raw scores (no
    final activation). ``first`` halves the resolution, ``int(log2(image_size) - 3)``
    DBlocks halve it down to 4x4, and ``last`` reduces 4x4 to 1x1.

    Args:
        image_size: input resolution (the layer arithmetic assumes a power of two).
        first_dim: channels of the input image.
        network_capacity: widths are ``network_capacity * 2**(i + 4)`` (smallest first),
            capped at ``fmap_max``, followed by one entry twice the last capped width.
        attn_maps: spatial size at which ``Xiong_Self_Attn`` is inserted: image_size/2
            puts it right after ``first``; smaller sizes put it after the DBlock producing
            that size. No attention layer is built when the computed position lies
            outside the network.
        fmap_max: channel cap.
        which_bn: normalisation layer class used in the DBlocks.
        activation: activation module shared by ``first`` and all DBlocks.
        snn: spectral normalisation on every convolution.
        ngpu: stored only; not used inside this class.
        dpout: enable dropout inside every DBlock.
    """
    def __init__(self, image_size = image_size, first_dim = nc, network_capacity = 4, attn_maps = 256,  fmap_max = 512*8,
                 which_bn = nn.BatchNorm2d, activation = nn.LeakyReLU(0.2, inplace=True), snn = True,ngpu = ngpu,dpout = False):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.image_size = image_size
        # Number of DBlocks: they cover image_size/2 -> 4x4.
        self.num_layers = int(math.log2(image_size) - 3)
        self.first_dim = first_dim
        self.attn_maps = attn_maps
        self.snn = snn
        self.activation = activation
        self.which_bn = which_bn
        self.dpout = dpout

        # Widths, smallest first, capped at fmap_max, plus a final doubled width.
        filters = [network_capacity * (2 ** (i + 4)) for i in range(self.num_layers)]
        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        last_channels = filters[-1] * 2
        filters = [*filters, last_channels]
        # (in, out) channel pairs, one per DBlock.
        in_out_pairs = zip(filters[:-1], filters[1:])

        # First layer: image_size -> image_size/2 (mirror of G's last layer), no normalisation.
        first = []
        if self.snn:
            first.append(nn.utils.spectral_norm(nn.Conv2d(in_channels=first_dim, out_channels=filters[0],
                                                          kernel_size=4, stride=2, padding=1, bias=False)))
        else:
            first.append(nn.Conv2d(in_channels=first_dim, out_channels=filters[0],
                                   kernel_size=4, stride=2, padding=1, bias=False))
        first.append(self.activation)
        self.first = nn.Sequential(*first)

        # Attention position: -1 means right after `first`, k >= 0 after DBlock k.
        self.attn_layers = int(-(math.log2(attn_maps) - (self.num_layers + 1)))
        if self.attn_layers == -1:
            self.attn = Xiong_Self_Attn(filters[0], sn=self.snn)

        # DBlocks up to and including the attention index go to blocks1, the rest to blocks2.
        self.blocks1 = nn.ModuleList([])
        self.blocks2 = nn.ModuleList([])
        for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
            if self.attn_layers == ind:
                self.attn = Xiong_Self_Attn(out_chan, sn=snn)
            block = DBlock(in_chan, out_chan, sn=self.snn, which_bn=which_bn, activation=activation, dpout=self.dpout)
            if ind <= self.attn_layers:
                self.blocks1.append(block)
            else:
                self.blocks2.append(block)

        # Last layer: 4x4 -> 1x1 score, deliberately left without activation.
        last = []
        if self.snn:
            last.append(nn.utils.spectral_norm(nn.Conv2d(in_channels=filters[-1], out_channels=1, kernel_size=4,
                                                         stride=1, padding=0, bias=False)))
        else:
            last.append(nn.Conv2d(in_channels=filters[-1], out_channels=1, kernel_size=4,
                                  stride=1, padding=0, bias=False))
        self.last = nn.Sequential(*last)

    def forward(self, x):
        """Score images [B, first_dim, image_size, image_size] -> [B, 1, 1, 1]."""
        x = self.first(x)
        for block in self.blocks1:
            x = block(x)
        # Apply attention only if it was built; the old branch logic called self.attn
        # even when attn_maps mapped past the last DBlock (AttributeError).
        attn = getattr(self, 'attn', None)
        if attn is not None:
            x = attn(x)
        for block in self.blocks2:
            x = block(x)
        return self.last(x)

In [None]:
# Discriminator-input augmentation (used for the ADA-style training).
def random_hflip(tensor, prob):
    """Mirror the whole NCHW batch left-right (dim 3) with probability ``prob``.

    Returns the input object unchanged when no flip happens.
    """
    flip_now = random.random() <= prob
    if flip_now:
        return torch.flip(tensor, dims=(3,))
    return tensor

class AugWrapper(nn.Module):
    """Wrap a discriminator ``D`` so its inputs are randomly augmented before scoring.

    forward(images, prob, types, detach):
        images: batch passed to ``D``.
        prob: probability of each augmentation stage firing independently: a 50% batch-wide
            horizontal flip, then ``DiffAugment`` with ``types``.
        types: augmentation names forwarded to ``DiffAugment``.
        detach: cut the autograd graph of the (possibly augmented) images before ``D``.
    """
    def __init__(self, D):
        super().__init__()
        self.D = D

    def forward(self, images, prob = 0., types = [], detach = False):
        batch = images
        if random.random() < prob:
            batch = random_hflip(batch, prob=0.5)
        if random.random() < prob:
            # DiffAugment applies its own internal randomness to the listed ops.
            batch = DiffAugment(batch, types=types)
        return self.D(batch.detach() if detach else batch)

In [None]:
# DCGAN-style weight initialisation (note from the author: skipping it may work better).
def weights_init(m):
    """Initialise one module in place; meant for ``net.apply(weights_init)``.

    * Conv2d / ConvTranspose2d: weights ~ N(0, 0.02).
    * BatchNorm2d with affine parameters: weight ~ N(1, 0.02), bias = 0.
    Other modules (e.g. the InstanceNorm2d layers, which have no affine
    parameters by default) are left untouched.
    """
    if type(m) in (nn.Conv2d, nn.ConvTranspose2d):
        # With spectral_norm the trainable tensor is `weight_orig`; `weight` is
        # recomputed from it on each forward pass, so initialising `weight` only
        # works before the first forward. Target the real parameter instead.
        weight = getattr(m, 'weight_orig', m.weight)
        nn.init.normal_(weight.data, 0.0, 0.02)
    elif type(m) == nn.BatchNorm2d and m.affine:
        # affine=False BatchNorm has weight/bias set to None.
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Build G and D on the target device, wrap them for multi-GPU if requested, and initialise.
netG = Generator(image_size = image_size, latent_dim = nz, last_dim = nc , network_capacity = capacity_G, attn_maps = sa_map_G, fmap_max = 512*8,
                 snn = True ,which_bn = nn.InstanceNorm2d, activation = nn.ReLU(True), ngpu = ngpu).to(device)
netD = Discriminator(image_size = image_size, first_dim = nc, network_capacity = capacity_D, attn_maps = sa_map_D,  fmap_max = 512*8,
                     which_bn = nn.InstanceNorm2d, activation = nn.LeakyReLU(0.2, inplace=True) ,snn = True,ngpu = ngpu,dpout = dpout).to(device)

if device.type == 'cuda' and ngpu > 1:
    netG = nn.DataParallel(netG, list(range(ngpu)))
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Wrap D for augmentation only *after* DataParallel; previously the wrapper held the
# un-parallelised netD, so training never used the extra GPUs for D.
netD_aug = AugWrapper(netD).to(device)

netG.apply(weights_init)
netD.apply(weights_init)
print(netG)  # show both architectures
print(netD)

In [None]:
# Loss, optimisers, progress plotting and per-epoch bookkeeping.
criterion = nn.BCEWithLogitsLoss()  # used only by the BCE loss variant
optimizerG = optim.Adam(netG.parameters(), lr=lrG, betas=(beta1, beta2))
optimizerD = optim.Adam(netD.parameters(), lr=lrD, betas=(beta1, beta2))


def gen_img_plot(model, test_input):
    """Plot up to 16 images generated from ``test_input`` in a 4x4 grid.

    Args:
        model: generator mapping [B, nz, 1, 1] latents to [B, C, H, W] images in [-1, 1];
            only the first channel is shown.
        test_input: latent batch fed to ``model``.

    The model is run in eval mode without autograd (so e.g. spectral-norm power
    iterations are not advanced by plotting) and its previous mode is restored.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # [:, 0] keeps the batch axis even when B == 1 (squeeze() would drop it).
            prediction = model(test_input).float().cpu().numpy()[:, 0]
    finally:
        model.train(was_training)
    fig = plt.figure(figsize=(16, 16))  # could also be logged to TensorBoard
    for i in range(min(16, prediction.shape[0])):
        ax = fig.add_subplot(4, 4, i + 1)
        # Fixed [-1, 1] range so brightness is comparable between epochs.
        ax.imshow(prediction[i], cmap='Greys_r', vmin=-1, vmax=1)
        ax.axis('off')
    plt.show()


# Fixed latents [16, nz, 1, 1]: G accepts any batch size, so the same 16 samples are
# plotted every time to track progress.
test_input = torch.randn(16, nz, 1, 1, device=device)

G_loss = []     # mean generator loss of each epoch
D_loss = []     # mean discriminator loss of each epoch
# Gradient scaling is only meaningful for CUDA mixed precision; disable it explicitly on CPU.
scaler_D = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')
scaler_G = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')

In [None]:
# Load the TensorBoard notebook extension and show a dashboard for the ./logs run
# directory that the SummaryWriter writes to.
%load_ext tensorboard
%tensorboard --logdir=logs

In [None]:
# Training loop: one discriminator and one generator update per batch, with mixed
# precision and ADA-style adaptive augmentation of the discriminator input.

# Adversarial loss: 'hinge' (default), 'bce' (non-saturating BCE with logits) or
# 'softplus' (the same non-saturating loss written with softplus).
loss_type = 'hinge'
aug_types = ['translation', 'cutout']


def _d_loss(real_output, fake_output, kind):
    """Discriminator loss: push real scores up and fake scores down."""
    if kind == 'hinge':
        return F.relu(1. - real_output).mean() + F.relu(1. + fake_output).mean()
    if kind == 'bce':
        return (criterion(real_output, torch.ones_like(real_output))
                + criterion(fake_output, torch.zeros_like(fake_output)))
    if kind == 'softplus':
        return F.softplus(-real_output).mean() + F.softplus(fake_output).mean()
    raise ValueError('unknown loss_type: {}'.format(kind))


def _g_loss(fake_output, kind):
    """Generator loss: reward high discriminator scores on generated images."""
    if kind == 'hinge':
        return -torch.mean(fake_output)
    if kind == 'bce':
        return criterion(fake_output, torch.ones_like(fake_output))
    if kind == 'softplus':
        return F.softplus(-fake_output).mean()
    raise ValueError('unknown loss_type: {}'.format(kind))


count = len(dataloader)
if count == 0:
    raise RuntimeError('dataloader is empty; check dataroot/label')

# Augmentation probability. It must live outside the epoch loop: resetting it every
# epoch (as before) discarded the update below, so augmentation never activated.
p = 0.0
for epoch in range(num_epochs):
    d_epoch_loss = 0.0  # running sums of the per-batch losses
    g_epoch_loss = 0.0
    for batch_idx, (img, _) in enumerate(dataloader):
        img = img.to(device)
        random_noise = torch.randn(img.size(0), nz, 1, 1, device=device)

        # ---- discriminator step ----
        optimizerD.zero_grad()
        with autocast():
            real_output = netD_aug(img, prob=p, types=aug_types)
            gen_img = netG(random_noise)
            # Detach so the D update does not back-propagate into the generator.
            fake_output = netD_aug(gen_img.detach(), prob=p, types=aug_types, detach=True)
            d_loss = _d_loss(real_output, fake_output, loss_type)
        scaler_D.scale(d_loss).backward()
        scaler_D.step(optimizerD)
        scaler_D.update()

        # ---- generator step (reuses gen_img from the D step) ----
        optimizerG.zero_grad()
        with autocast():
            fake_output = netD_aug(gen_img, prob=p, types=aug_types)
            g_loss = _g_loss(fake_output, loss_type)
        scaler_G.scale(g_loss).backward()
        scaler_G.step(optimizerG)
        scaler_G.update()

        print("Epoch:", epoch, "Batch:", batch_idx, "/", count)

        # .item() stores plain floats (no device tensors / graph kept alive).
        d_epoch_loss += d_loss.item()
        g_epoch_loss += g_loss.item()

    d_epoch_loss /= count
    g_epoch_loss /= count
    # ADA-style control: a mean D loss in [0, 1) (presumably D separating real/fake too
    # easily, calibrated for the hinge loss) raises the augmentation probability to
    # 1 - loss, capped at 0.8. Otherwise p keeps its previous value.
    if 0 <= d_epoch_loss < 1.0:
        p = min(0.8, 1 - d_epoch_loss)

    G_loss.append(g_epoch_loss)  # previously the D loss was appended here by mistake
    D_loss.append(d_epoch_loss)
    print("Epoch:", epoch)
    writer.add_scalars('loss', {"Gloss": g_epoch_loss, 'Dloss': d_epoch_loss}, global_step=epoch)

    if epoch % 20 == 0:   # show progress images
        gen_img_plot(netG, test_input)
    if epoch % 50 == 0:   # checkpoint the generator weights (state_dict only)
        torch.save(netG.state_dict(), "netG42{}.pth".format(epoch))
writer.close()

# Save the per-epoch losses, one value per line (overwrites previous runs).
with open("G_loss.txt", "w") as f1:
    for value in G_loss:
        f1.write(str(value))
        f1.write("\n")
with open("D_loss.txt", "w") as f2:
    for value in D_loss:
        f2.write(str(value))
        f2.write("\n")

In [None]:
# The following code is used for the inference stage

In [None]:
# Inference: load one generator checkpoint and write 50 samples as JPEG files.

# Step 1: build G with the training hyper-parameters and load its weights.
netG = Generator(image_size = image_size, latent_dim = nz, last_dim = nc , network_capacity = capacity_G, attn_maps = sa_map_G, fmap_max = 512*8,
                 snn = True,which_bn = nn.InstanceNorm2d, activation = nn.ReLU(True), ngpu = ngpu).to(device)
name = '512netG42-M150.pth'
# torch.load unpickles the file: only load checkpoints from a trusted source.
netG.load_state_dict(torch.load(name, map_location=torch.device('cpu')))
out_dir = '{}_img'.format(name)
os.makedirs(out_dir, exist_ok=True)

# Step 2: one latent vector per output image.
test_input = torch.randn(50, nz, 1, 1, device=device)

# Step 3: generate and save each image.
netG.eval()
with torch.no_grad():
    for i in range(test_input.size(0)):
        output = netG(test_input[i].unsqueeze(0))
        output_img = (output + 1) / 2   # tanh range [-1, 1] -> [0, 1]
        # os.path.join instead of a hard-coded '\' so files land inside out_dir on every OS.
        torchvision.utils.save_image(output_img.squeeze(), os.path.join(out_dir, '{}_{}.jpg'.format(name, i)))

In [None]:
# Inference for every generator checkpoint in ./model: 50 JPEG samples per checkpoint.
netG = Generator(image_size = image_size, latent_dim = nz, last_dim = nc , network_capacity = capacity_G, attn_maps = sa_map_G, fmap_max = 512*8,
                 snn = True,which_bn = nn.InstanceNorm2d, activation = nn.ReLU(True), ngpu = ngpu).to(device)
netG.eval()

# Only checkpoint files; sorted for a reproducible processing order.
model_name = sorted(f for f in os.listdir('model') if f.endswith(('.pth', '.pt')))
for ckpt in model_name:
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    netG.load_state_dict(torch.load(os.path.join('model', ckpt), map_location=torch.device('cpu')))
    out_dir = '{}_img'.format(ckpt)
    os.makedirs(out_dir, exist_ok=True)

    # One latent vector per output image.
    test_input = torch.randn(50, nz, 1, 1, device=device)
    with torch.no_grad():
        for i in range(test_input.size(0)):
            output = netG(test_input[i].unsqueeze(0))
            output_img = (output + 1) / 2   # tanh range [-1, 1] -> [0, 1]
            # os.path.join instead of a hard-coded '\' so files land inside out_dir on every OS.
            torchvision.utils.save_image(output_img.squeeze(), os.path.join(out_dir, '{}_{}.jpg'.format(ckpt, i)))