### 条件扩散

In [1]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from utils.DatasetUtils import DatasetUtils

# Project-local helper used below for inspecting / plotting the datasets.
data_utils = DatasetUtils()

# Load the MNIST dataset. ToTensor scales pixel values to [0, 1]; Normalize with
# mean=0.5 and std=0.5 then maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Both splits are downloaded into (and read from) the same directory.
data_root = './data'
train_dataset = MNIST(root=data_root, train=True, download=True, transform=transform)
eval_dataset = MNIST(root=data_root, train=False, download=True, transform=transform)

print(f"训练集的数据个数: {len(train_dataset)}")
print(f"验证集的数据个数: {len(eval_dataset)}")

训练集的数据个数: 60000
验证集的数据个数: 10000


In [None]:
# Preview a handful of images from the training set.
preview_count = 4
data_utils.show_image_simple(train_dataset, count=preview_count)

In [2]:
# 模型组网

import torch
from model.BetaScheduler import LinearBetaScheduler
from model.UNet import UNet
from model.CD_Model import CD_Model
from utils.ModelWorker.CDModelWorker import CDModelWorker
from torchkeras import summary

# Per-sample image shape as reported by the dataset helper
# (the captured output shows (1, 28, 28) for MNIST).
input_shape = data_utils.get_data_shape(train_dataset)

# Class-conditional UNet backbone: 10 classes (one per MNIST digit) and a
# 128-dimensional embedding.
unet = UNet(
    input_shape=input_shape,
    init_features=32,
    num_classes=10,
    embed_dim=128,
)
# Wrap the backbone with a linear beta (noise) schedule, then hand it to the
# worker that drives training / evaluation / sampling.
model = CD_Model(unet, LinearBetaScheduler())
model_worker = CDModelWorker(model)

print(f"{input_shape=}")

# Dummy timestep and class-label tensors so `summary` can run a forward pass.
# NOTE(review): `time` shadows the stdlib module name in the notebook namespace.
time = torch.tensor([0], dtype=torch.long)
condition = torch.tensor([0], dtype=torch.long)
model_info = summary(
    model,
    input_shape=input_shape,
    time=time,
    condition=condition,
)

input_shape=(1, 28, 28)
--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
PositionalEncoding-1                        [-1, 64]                    0
Linear-2                                    [-1, 32]                2,080
SiLU-3                                      [-1, 32]                    0
Linear-4                                    [-1, 64]                2,112
Embedding-5                                 [-1, 32]                  320
SiLU-6                                      [-1, 32]                    0
Linear-7                                    [-1, 64]                2,112
Identity-8                                 [-1, 128]                    0
Conv2d-9                            [-1, 32, 28, 28]                  320
GELU-10                             [-1, 32, 28, 28]                    0
GroupNorm-11                        [-1, 32, 28, 28]                   64
Conv2d-12    

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# 模型准备
from torch.utils.data import DataLoader
from torch import nn
from torch import optim

# Mini-batch size shared by the training and evaluation loaders.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# No shuffling for evaluation, so its iteration order is fixed.
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

# `loss` holds the criterion object (a callable), not a loss value; later cells
# pass it to the worker as `criterion=`.
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [None]:
# Train the model for 10 epochs, evaluating on `eval_loader` as configured by the worker.
model_worker.train(
    train_loader=train_loader,
    eval_loader=eval_loader,
    criterion=loss,
    optimizer=optimizer,
    epochs=10,
)

In [None]:
# Evaluate the model on the held-out split with the same criterion used in training.
model_worker.evaluate(
    criterion=loss,
    eval_loader=eval_loader,
)

In [None]:
# Save the trained model. The output directory is created first: writing a file
# (e.g. via torch.save / open) fails with FileNotFoundError when its parent
# directory does not exist. exist_ok=True makes re-running this cell safe.
import os

model_path = './output/HDR_CD_model/HDR_CD_model.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
model_worker.save(model_path)

In [3]:
# Restore model weights from the checkpoint file written by the save cell.
checkpoint_path = './output/HDR_CD_model/HDR_CD_model.pth'
model_worker.load(checkpoint_path)

In [None]:
# Generate 16 samples conditioned on class label 1. `add_noise` is forwarded to the
# worker's sampler; its exact meaning is defined in CDModelWorker.
datas = model_worker.generate_sample(
    count=16,
    condition=1,
    add_noise=True,
)

In [None]:
# Move the generated batch to the CPU for plotting. detach() drops any autograd
# graph first; a tensor that still requires grad cannot be converted with .numpy(),
# which plotting helpers typically call. detach() is a no-op otherwise.
datas = datas.detach().cpu()
data_utils.show_image_batch(datas)