In [38]:
from torch import nn
import torch
import torch.nn.functional as F

'''
VAE模型

'''

class VAE(nn.Module):
    """Convolutional variational auto-encoder for square images.

    The encoder is a stack of stride-2 convolutions (one per entry of
    ``hiddens``) that halves the spatial size at every stage.  Two linear
    heads map the flattened feature map to the mean and log-variance of the
    latent Gaussian.  The decoder projects a latent vector back to the last
    encoder feature-map shape and mirrors the encoder with stride-2
    transposed convolutions.

    Args:
        hiddens: channel width of each encoder stage, in order.  Must be
            non-empty.
        z_dim: dimensionality of the latent vector.
        image_size: height/width of the (square) input images.  Must be
            divisible by ``2 ** len(hiddens)`` so that the decoder output
            has the same size as the input.
        ch: number of channels of the input and of the reconstruction.

    Raises:
        ValueError: if ``hiddens`` is empty or ``image_size`` is not
            divisible by ``2 ** len(hiddens)``.
    """

    def __init__(self, hiddens=(16, 32, 128, 256), z_dim=128, image_size=128, ch=3):
        # Initialise nn.Module state before registering sub-modules.
        super(VAE, self).__init__()

        hiddens = list(hiddens)
        if not hiddens:
            raise ValueError("hiddens must contain at least one channel width")
        downscale = 2 ** len(hiddens)
        if image_size % downscale != 0:
            # A stride-2 / kernel-3 / padding-1 conv maps an odd size n to
            # (n + 1) // 2, which would disagree with the `//= 2` bookkeeping
            # below and with the exact doubling done by the decoder.
            raise ValueError(
                f"image_size ({image_size}) must be divisible by "
                f"2 ** len(hiddens) = {downscale}"
            )

        # Encoder: [bs, ch, H, W] -> [bs, hiddens[-1], H / 2^n, W / 2^n]
        prev_ch = ch
        cur_image_size = image_size
        modules = []
        for cur_ch in hiddens:
            modules.append(
                nn.Sequential(
                    # stride=2 halves the spatial size at every stage
                    nn.Conv2d(prev_ch, cur_ch, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(cur_ch),
                    nn.ReLU()
                )
            )
            prev_ch = cur_ch
            cur_image_size //= 2
        self.encoder = nn.Sequential(*modules)

        flat_dim = prev_ch * cur_image_size * cur_image_size
        self.mean_linear = nn.Linear(flat_dim, z_dim)
        self.var_linear = nn.Linear(flat_dim, z_dim)

        # Decoder: latent vector -> feature map -> image
        self.decoder_projection = nn.Linear(z_dim, flat_dim)
        self.decoder_in_chw = (prev_ch, cur_image_size, cur_image_size)
        modules = []
        # Walk the widths in reverse, skipping the deepest one: the projected
        # feature map already has hiddens[-1] channels.
        for cur_ch in hiddens[::-1][1:]:
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(prev_ch, cur_ch, kernel_size=3, stride=2,
                                       padding=1, output_padding=1),
                    nn.BatchNorm2d(cur_ch),
                    nn.ReLU()
                )
            )
            prev_ch = cur_ch
        # With the defaults: (bs, 256, 8, 8) has become (bs, 16, 64, 64) here.
        modules.append(
            nn.Sequential(
                # One more transposed conv restores the full resolution ...
                nn.ConvTranspose2d(prev_ch, prev_ch, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.BatchNorm2d(prev_ch),
                nn.ReLU(),
                # ... and a plain conv maps back to the input channel count
                # (previously hard-coded to 3, which broke ch != 3).
                nn.Conv2d(prev_ch, ch, kernel_size=3, stride=1, padding=1),
                # NOTE(review): ReLU leaves the output unbounded above while
                # ToTensor() images lie in [0, 1]; a sigmoid may fit better.
                # Kept as-is so existing checkpoints behave the same.
                nn.ReLU()
            )
        )

        self.decoder = nn.Sequential(*modules)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Called automatically by ``model(inputs)``.

        :param x: batch of input images
        :return: ``(x_hat, mu, log_var)`` - the reconstruction and the
                 parameters of the approximate posterior
        """
        # encoder
        mu, log_var = self.encode(x)

        # reparameterization trick
        sampled_z = self.reparameterization(mu, log_var)
        sampled_z = self.decoder_projection(sampled_z)
        # reshape the flat projection into a (C, H, W) feature map
        sampled_z = torch.reshape(sampled_z, (-1, *self.decoder_in_chw))
        # decoder
        x_hat = self.decode(sampled_z)
        return x_hat, mu, log_var

    def encode(self, x):
        """
        encoding part
        :param x: input image
        :return: mu and log_var
        """
        x = self.encoder(x)
        x = torch.flatten(x, 1)  # (bs, C, h, w) -> (bs, C*h*w)
        mu = self.mean_linear(x)
        log_var = self.var_linear(x)

        return mu, log_var

    def reparameterization(self, mu, log_var):
        """
        Given a standard gaussian distribution epsilon ~ N(0,1),
        we can sample the random variable z as per z = mu + sigma * epsilon
        :param mu: mean of the latent Gaussian
        :param log_var: log of the variance of the latent Gaussian
        :return: sampled z
        """
        sigma = torch.exp(log_var * 0.5)  # standard deviation from log-variance
        eps = torch.randn_like(sigma)
        return mu + sigma * eps  # "*" is element-wise here

    def decode(self, z):
        """
        Given a projected and reshaped latent feature map, decode it back to
        an image.
        :param z: tensor shaped ``(bs, *self.decoder_in_chw)``
        :return: reconstructed image batch
        """
        x_hat = self.decoder(z)
        return x_hat



In [41]:
import sys
from PIL import Image
import glob
import time
from pathlib import Path
from typing import Iterable,Optional
import math
import torch
#import torch.multiprocessing
#torch.multiprocessing.set_sharing_strategy('file_system')
import torch.nn as nn
from torchvision.utils import save_image
import torchvision
import argparse
import os
#import timm
#from timm.utils import accuracy
from torch.utils.tensorboard import SummaryWriter
from util import misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
# Notebook-wide compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_args_parser():
    """Build the command-line parser for VAE training.

    Returns:
        argparse.ArgumentParser: parser exposing batch/epoch settings, model
        size, optimizer hyper-parameters, dataset/output paths, checkpoint
        resume options and DataLoader settings.  ``pin_mem`` defaults to
        ``True`` and can be switched off with ``--no_pin_mem``.
    """
    parser = argparse.ArgumentParser(description="Variational Auto-Encoder Example")
    parser.add_argument('--batch_size',default=32,type=int,help='Batch size per GPU (effective batch size is batch_size*accum_iter* #gpus)')
    parser.add_argument('--epochs',default=20,type=int)
    parser.add_argument('--accum_iter',default=1,type=int)

    # Model parameters.  Help texts use argparse's %(default)s so they can no
    # longer drift away from the real defaults (they used to claim 20 and 1).
    parser.add_argument('--image_size', type=int, default=128 , metavar='N', help='Image size')
    parser.add_argument('--z_dim', type=int, default=128, metavar='N', help='the dim of latent variable z (default: %(default)s)')
    parser.add_argument('--input_channel', type=int, default=3, metavar='N', help='input channel (default: %(default)s, i.e. RGB)')

    # Optimizer parameters
    parser.add_argument('--weight_decay',type=float,default=0.0001)
    parser.add_argument('--lr',type=float,default=0.0001,metavar='LR')

    # Paths
    # NOTE(review): machine-specific absolute default; override --root_path
    # when running anywhere else.
    parser.add_argument('--root_path',default='D:\\jiao\\datasets\\celeba')
    parser.add_argument('--output_dir',default='./output_dir_pretrained',help='path to save,empty for no saving')
    parser.add_argument('--log_dir',default='./output_dir_pretrained',help='path to tensorboard log')

    # Checkpointing / data loading
    parser.add_argument('--resume',default='',help='resume from checkpoint')
    parser.add_argument('--start_epoch',default=0,type=int,metavar='N')
    parser.add_argument('--num_workers',default=5,type=int)
    parser.add_argument('--pin_mem',action='store_true')
    parser.add_argument('--no_pin_mem',action='store_false',dest='pin_mem')
    parser.set_defaults(pin_mem=True)
    return parser
'''创建预处理的transform'''
def build_transform(is_train,args):
    """Create the image preprocessing pipeline.

    The same pipeline is returned for training and evaluation; ``is_train``
    is accepted but not used.

    :param is_train: unused
    :param args: namespace providing ``image_size``
    :return: a ``Compose`` of center-crop(168) -> resize to
             ``image_size`` x ``image_size`` -> ``ToTensor``
    """
    target_hw = (args.image_size, args.image_size)
    steps = [
        torchvision.transforms.CenterCrop(168),
        torchvision.transforms.Resize(target_hw),
        torchvision.transforms.ToTensor(),
    ]
    return torchvision.transforms.Compose(steps)
 
'''创建数据集 返回dataset'''
def build_dataset(is_train,args):
    """Create the ``ImageFolder`` dataset for the requested split.

    :param is_train: ``True`` reads ``<root_path>/train``, otherwise
                     ``<root_path>/test``
    :param args: namespace providing ``root_path`` and ``image_size``
    :return: the dataset, with the transform from ``build_transform`` applied
    """
    transform = build_transform(is_train,args)
    path = os.path.join(args.root_path,'train' if is_train else 'test')
    dataset = torchvision.datasets.ImageFolder(path,transform= transform)
    # ImageFolder already indexed the class folders while constructing the
    # dataset; reuse that mapping instead of scanning the directory again
    # with find_classes().
    print(f"mapping classes from {path} to indexes:{dataset.class_to_idx}")
    return dataset


def vae_loss(x_hat, x, mu, log_var):
    """
    Compute the VAE objective as reconstruction error plus KL divergence.

    :param x_hat: reconstructed images
    :param x: original images
    :param mu: mean of the approximate posterior
    :param log_var: log-variance of the approximate posterior, log(sigma^2)
    :return: (total loss, reconstruction term, KL term); every term is a
             sum over all elements, not a mean
    """
    # Reconstruction term: summed squared error.  (Binary cross-entropy only
    # suits binary-like images such as MNIST.)
    recon_term = F.mse_loss(x_hat, x, reduction='sum')

    # KL(Q(z|X) || N(0, I)) in closed form for a diagonal Gaussian:
    #   0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2))
    kl_per_element = log_var.exp() + mu.pow(2) - 1. - log_var
    kl_term = 0.5 * kl_per_element.sum()

    total = recon_term + kl_term
    return total, recon_term, kl_term
''' 验证函数
    输入：
    输出：
'''
@torch.no_grad()
def evaluate(data_loader,model,device,epoch,output_dir=None,num_samples=32):
    """Evaluate the model on ``data_loader`` and save decoded random latents.

    :param data_loader: iterable of batches whose first element is the image
                        tensor
    :param model: VAE exposing ``decoder_projection``, ``decoder_in_chw`` and
                  ``decode``
    :param device: device to run on
    :param epoch: epoch index, used only in the sample image file name
    :param output_dir: where to write ``random_sampled-<epoch>.png``.  When
                       ``None`` the notebook-level ``args.output_dir`` is
                       used, as before.  An empty string disables saving.
    :param num_samples: number of random latent vectors to decode
    :return: dict mapping each meter name to its global average
    """
    criterion = vae_loss
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = 'Test:'
    model = model.to(device)
    model.eval()

    # Roughly `for batch in data_loader`, with a progress line printed by
    # the logger (100 = print interval, header = line prefix).
    for batch in metric_logger.log_every(data_loader, 100, header):
        images = batch[0].to(device,non_blocking=True)
        # Forward pass and loss.  The loss is a sum over the batch, not a
        # per-image mean.
        test_x_hat, test_mu, test_log_var = model(images)
        test_loss, _, _ = criterion(test_x_hat, images, test_mu, test_log_var)
        metric_logger.update(loss = test_loss.item())

    # Decode latents drawn from the prior to inspect generation quality.
    if output_dir is None:
        # Backward-compatible fallback to the notebook-level namespace.
        output_dir = args.output_dir
    if output_dir:
        # The directory may not exist yet on the first evaluation.
        os.makedirs(output_dir, exist_ok=True)
        # Take the latent size from the model so it always matches its weights.
        z_dim = model.decoder_projection.in_features
        z = torch.randn(num_samples, z_dim, device=device)  # one latent per row
        z = model.decoder_projection(z)
        z = torch.reshape(z,(-1, *model.decoder_in_chw))
        random_res = model.decode(z)
        save_image(random_res, os.path.join(output_dir, f"random_sampled-{epoch}.png"))

    metric_logger.synchronize_between_processes()
    print('loss {losses.global_avg:.3f}'.format(losses=metric_logger.loss))
    return {k:meter.global_avg for k,meter in metric_logger.meters.items()}


'''
    训练函数
'''
def train_one_epoch(model:torch.nn.Module,criterion:torch.nn.Module,
                    data_loader:Iterable,optimizer:torch.optim.Optimizer,
                    device:torch.device,epoch:int,loss_scaler,max_norm: float=0,
                    log_writer=None,args=None):
    """Train ``model`` for one pass over ``data_loader``.

    :param model: network returning ``(x_hat, mu, log_var)``
    :param criterion: callable ``(x_hat, x, mu, log_var) -> (loss, ...)``
    :param data_loader: iterable of ``(samples, targets)``; targets are unused
    :param optimizer: optimizer whose first param group gets ``args.lr``
    :param device: device the samples are moved to
    :param epoch: current epoch index (used for logging only)
    :param loss_scaler: callable performing backward and, when
                        ``update_grad`` is true, the optimizer step
    :param max_norm: gradient-clipping norm; ``0`` or ``None`` means no
                     clipping and is forwarded as ``clip_grad=None``
    :param log_writer: optional tensorboard writer
    :param args: namespace providing ``accum_iter`` and ``lr`` (required)
    :return: dict with the mean per-step ``loss`` (already divided by
             ``accum_iter``; NaN if the loader was empty) and the ``lr`` used
    :raises ValueError: if ``args`` is ``None``
    """
    if args is None:
        raise ValueError("train_one_epoch requires `args` providing `accum_iter` and `lr`")

    model.train(True)
    accum_iter = args.accum_iter
    warmup_lr = args.lr
    # A clipping norm of 0 would scale every gradient to zero, so treat a
    # falsy max_norm as "clipping disabled".
    clip_grad = max_norm if max_norm else None

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # Constant learning rate (no warm-up schedule); set once, not every step.
    optimizer.param_groups[0]["lr"] = warmup_lr
    # Drop gradients left over from an incomplete accumulation window.
    optimizer.zero_grad()

    loss_sum = 0.0
    num_steps = 0
    for data_iter_step,(samples,_targets) in enumerate(data_loader):
        samples = samples.to(device,non_blocking=True)

        x_hat, mu, log_var = model(samples)
        loss , _ , _ = criterion(x_hat,samples, mu,log_var)
        loss = loss / accum_iter
        loss_value = loss.item()

        # Abort before back-propagating a NaN/inf loss into the weights.
        if not math.isfinite(loss_value):
            print(f"loss is {loss_value}, stopping training")
            sys.exit(1)

        # Step the optimizer only once every `accum_iter` batches.
        is_update_step = (data_iter_step+1)%accum_iter == 0
        loss_scaler(loss,optimizer,clip_grad=clip_grad,
                    parameters=model.parameters(),create_graph=False,
                    update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()

        loss_sum += loss_value
        num_steps += 1

        if log_writer is not None and (data_iter_step+1)% (accum_iter*100) == 0 :
            # x-axis in units of 1/1000 epoch
            epoch_1000x = int((data_iter_step/len(data_loader)+epoch)*1000)
            log_writer.add_scalar('loss',loss_value,epoch_1000x)
            log_writer.add_scalar('lr',warmup_lr,epoch_1000x)
            # Print the Python float, not the tensor repr.
            print(f"Epoch: {epoch}, Step: {data_iter_step}, Loss: {loss_value}, Lr: {warmup_lr}")

    mean_loss = loss_sum / num_steps if num_steps else float("nan")
    return {"loss": mean_loss, "lr": warmup_lr}


def main(args,mode='train',test_image_path=''):
    """Run the training loop: evaluate, train one epoch, checkpoint, repeat.

    :param args: parsed namespace from ``get_args_parser``
    :param mode: only ``'train'`` is implemented; any other value prints the
                 mode and returns without doing anything
    :param test_image_path: currently unused
    """
    print(f"当前mode: {mode}")
    if mode != 'train':
        return

    # Datasets and loaders
    dataset_train = build_dataset(is_train=True,args=args)
    dataset_val = build_dataset(is_train=False,args=args)
    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    data_loader_train = torch.utils.data.DataLoader(
        dataset=dataset_train,sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset=dataset_val,sampler=sampler_val,
        batch_size=32,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    # Model
    model = VAE(z_dim = args.z_dim,image_size=args.image_size,ch=args.input_channel)
    model = model.to(device)
    n_parameters = sum([p.numel() for p in model.parameters() if p.requires_grad])
    print(f"number of trainable parameters(M):{n_parameters/1.e6:.2f}")
    criterion = vae_loss

    # weight_decay is passed to AdamW as regularisation against over-fitting.
    optimizer = torch.optim.AdamW(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)

    # evaluate() writes sample images into output_dir before the first
    # checkpoint is saved, so make sure the directory exists up front.
    if args.output_dir:
        os.makedirs(args.output_dir,exist_ok=True)

    # Tensorboard logging
    os.makedirs(args.log_dir,exist_ok=True)
    log_writer = SummaryWriter(log_dir=args.log_dir)
    try:
        # The loss scaler performs backward and the optimizer step.
        loss_scaler = NativeScaler()

        # Loads the checkpoint named by args.resume (an empty string means
        # start fresh).  Per the original author's note, loading also advances
        # the start epoch stored in args - TODO confirm against util.misc.
        misc.load_model(args=args,model_without_ddp=model,optimizer=optimizer,loss_scaler=loss_scaler)

        for epoch in range(args.start_epoch,args.epochs):
            print(f"Epoch {epoch}")
            print(f"length of data_loader_train is {len(data_loader_train)}")  # number of batches

            # Evaluate before every training epoch.  evaluate() switches the
            # model to eval mode and train_one_epoch() switches it back.
            print("Evaluating...")
            test_stats = evaluate(data_loader_val,model,device,epoch)
            print(f"loss on the {len(dataset_val)} test images {test_stats['loss']:.2f}")
            # add_scalar(tag, scalar_value, global_step): tag is the chart
            # name, global_step the x-axis position.
            log_writer.add_scalar('perf/test_loss',test_stats['loss'],epoch)

            print("Training...")
            train_stats = train_one_epoch(
                model,criterion,data_loader_train,
                optimizer,device,epoch,
                loss_scaler,None,
                log_writer=log_writer,args=args
            )
            if args.output_dir:
                print("Saving checkpoint...")
                misc.save_model(args=args,model=model,model_without_ddp=model,optimizer=optimizer,
                               loss_scaler=loss_scaler,epoch=epoch)
    finally:
        # Flush pending events and release the event file even when training
        # is interrupted.
        log_writer.close()
        
           
'''main'''      
if __name__ == '__main__':
    # A notebook has no usable command line, so the arguments are passed as
    # an explicit list.  `args` stays a module-level name because evaluate()
    # reads it as a global.
    # (An earlier ad-hoc smoke test pushed a random (128, 3, 128, 128) batch
    # through VAE(z_dim=64, image_size=128, ch=3) and printed the output
    # shape.)
    args = get_args_parser().parse_args(args=[
        '--batch_size', '256',
        '--epochs', '100',
        '--num_workers', '2',
        '--resume', './output_dir_pretrained/checkpoint-29.pth',
    ])
    main(args=args, mode='train')


当前mode: train
mapping classes from D:\jiao\datasets\celeba\train to indexes:{'face': 0}
mapping classes from D:\jiao\datasets\celeba\test to indexes:{'face': 0}
number of trainable parameters(M):6.99
Resume checkpoint ./output_dir_pretrained/checkpoint-4.pth
With optim & sched!
Epoch 5
length of data_loader_train is 790
Evaluating...
Test: [0/9] eta: 0:00:31 loss: 18838.9277 (18838.9277) time: 3.5117 data: 3.3702 max mem: 1909
Test: [8/9] eta: 0:00:00 loss: 18838.9277 (18438.4068) time: 0.4743 data: 0.4389 max mem: 1909
Test: Total time: 0:00:04 (0.5179 s / it)
loss 18438.407
loss on the 276 test images 18438.41
Training...
log_dir: ./output_dir_pretrained
Epoch: 5, Step: 99, Loss: 155902.671875, Lr: 0.0001
Epoch: 5, Step: 199, Loss: 160758.15625, Lr: 0.0001
Epoch: 5, Step: 299, Loss: 155493.25, Lr: 0.0001
Epoch: 5, Step: 399, Loss: 151976.40625, Lr: 0.0001
Epoch: 5, Step: 499, Loss: 149445.890625, Lr: 0.0001
Epoch: 5, Step: 599, Loss: 150650.953125, Lr: 0.0001
Epoch: 5, Step: 699, Los

Test: Total time: 0:00:02 (0.3061 s / it)
loss 16355.861
loss on the 276 test images 16355.86
Training...
log_dir: ./output_dir_pretrained
Epoch: 15, Step: 99, Loss: 139127.09375, Lr: 0.0001
Epoch: 15, Step: 199, Loss: 139026.5, Lr: 0.0001
Epoch: 15, Step: 299, Loss: 138588.5, Lr: 0.0001
Epoch: 15, Step: 399, Loss: 132955.125, Lr: 0.0001
Epoch: 15, Step: 499, Loss: 132687.15625, Lr: 0.0001
Epoch: 15, Step: 599, Loss: 133745.734375, Lr: 0.0001
Epoch: 15, Step: 699, Loss: 141462.375, Lr: 0.0001
Saving checkpoint...
Epoch 16
length of data_loader_train is 790
Evaluating...
Test: [0/9] eta: 0:00:19 loss: 16625.9492 (16625.9492) time: 2.1382 data: 2.1053 max mem: 1909
Test: [8/9] eta: 0:00:00 loss: 16637.4023 (16308.3953) time: 0.2598 data: 0.2367 max mem: 1909
Test: Total time: 0:00:02 (0.2977 s / it)
loss 16308.395
loss on the 276 test images 16308.40
Training...
log_dir: ./output_dir_pretrained
Epoch: 16, Step: 99, Loss: 137840.875, Lr: 0.0001
Epoch: 16, Step: 199, Loss: 134635.0, Lr: 0.

Epoch: 25, Step: 699, Loss: 130889.453125, Lr: 0.0001
Saving checkpoint...
Epoch 26
length of data_loader_train is 790
Evaluating...
Test: [0/9] eta: 0:00:19 loss: 16070.5625 (16070.5625) time: 2.1938 data: 2.1578 max mem: 1909
Test: [8/9] eta: 0:00:00 loss: 16070.5625 (15658.8483) time: 0.2691 data: 0.2410 max mem: 1909
Test: Total time: 0:00:02 (0.3063 s / it)
loss 15658.848
loss on the 276 test images 15658.85
Training...
log_dir: ./output_dir_pretrained
Epoch: 26, Step: 99, Loss: 130021.390625, Lr: 0.0001
Epoch: 26, Step: 199, Loss: 128552.21875, Lr: 0.0001
Epoch: 26, Step: 299, Loss: 133231.40625, Lr: 0.0001
Epoch: 26, Step: 399, Loss: 127433.6328125, Lr: 0.0001
Epoch: 26, Step: 499, Loss: 130199.84375, Lr: 0.0001
Epoch: 26, Step: 599, Loss: 128951.0078125, Lr: 0.0001
Epoch: 26, Step: 699, Loss: 126919.84375, Lr: 0.0001
Saving checkpoint...
Epoch 27
length of data_loader_train is 790
Evaluating...
Test: [0/9] eta: 0:00:19 loss: 16013.4531 (16013.4531) time: 2.1533 data: 2.1156 max

In [14]:
# Scratch cell: prints the default encoder channel widths twice.
# NOTE(review): the output saved below this cell shows a trailing 3 that this
# code cannot produce, so the recorded output is presumably stale (left over
# from an earlier version of the cell) - re-run to confirm.
hiddens=[16,32,128,256]
print(hiddens)

print( hiddens)

[16, 32, 128, 256, 3]
[16, 32, 128, 256, 3]
