### 条件扩散

In [None]:
import os

from torchvision import transforms

from utils.RFID_Dataset import RFID_Dataset
from utils.DatasetUtils import DatasetUtils

data_utils = DatasetUtils()

# Empty Compose: samples pass through unchanged. Kept as a hook for
# adding preprocessing later.
transform = transforms.Compose([
])

# Build dataset paths portably. The previous Windows-only backslash literals
# broke on POSIX systems. On Windows os.path.join yields the same string.
train_dir = os.path.join("data", "RFID", "dataset", "train")
eval_dir = os.path.join("data", "RFID", "dataset", "eval")

# Shared RFID_Dataset options so the train and eval splits cannot drift apart.
# NOTE(review): T presumably is the window/sequence length and step the
# window stride — confirm against RFID_Dataset.
dataset_kwargs = dict(
    T=32,
    step=1,
    transform=transform,
)

# Load the datasets.
train_dataset = RFID_Dataset(train_dir, **dataset_kwargs)
eval_dataset = RFID_Dataset(eval_dir, **dataset_kwargs)

print(f"训练集的数据个数: {len(train_dataset)}")
print(f"验证集的数据个数: {len(eval_dataset)}")

In [None]:
# Build the network.

import torch
from model.BetaScheduler import LinearBetaScheduler
from model.UNet import UNet
from model.CD_Model import CD_Model
from utils.ModelWorker.CDModelWorker import CDModelWorker
from torchkeras import summary

# Per-sample input shape, as reported by DatasetUtils for the training set.
input_shape = data_utils.get_data_shape(train_dataset)

# UNet backbone. num_classes=3 is presumably the number of condition
# labels — confirm against UNet.
backbone = UNet(
    input_shape=input_shape,
    num_classes=3,
    init_features=64,
    embed_dim=64
)
beta_scheduler = LinearBetaScheduler()
model = CD_Model(backbone, beta_scheduler)

model_worker = CDModelWorker(model)

print(f"{input_shape=}")

# Dummy time/condition inputs, used only for the torchkeras summary.
# NOTE(review): `time` shadows the stdlib module name within this notebook.
# NOTE(review): torch.Tensor([0]) is a float tensor; if CD_Model embeds the
# condition as a class index it may need dtype=torch.long — confirm.
time = torch.Tensor([0])
condition = torch.Tensor([0])
model_info = summary(
    model,
    input_shape=input_shape,
    time=time,
    condition=condition,
)

In [None]:
# Training setup: data loaders, loss criterion and optimizer.
from torch import nn, optim
from torch.utils.data import DataLoader

# Only the training loader shuffles; evaluation keeps the dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
eval_loader = DataLoader(eval_dataset, batch_size=8)

# `loss` holds the criterion object (not a loss value). Later cells pass it
# as `criterion=`.
loss = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)


In [None]:
# Train the model. eval_loader is handed to the worker as well; how it is
# used during training depends on CDModelWorker.train.
num_epochs = 10
model_worker.train(
    train_loader=train_loader,
    eval_loader=eval_loader,
    criterion=loss,
    optimizer=optimizer,
    epochs=num_epochs,
)

In [None]:
# Evaluate the model on the evaluation loader with the same criterion.
model_worker.evaluate(criterion=loss, eval_loader=eval_loader)

In [None]:
# Save the model checkpoint.
# Create the parent directory first so saving does not fail on a fresh
# checkout where ./output does not exist yet.
# NOTE(review): assumes model_worker.save does not create parent dirs itself.
import os

checkpoint_path = './output/ACR_CD_model.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
model_worker.save(checkpoint_path)

In [None]:
# Load the model checkpoint from the same path the save cell writes to.
checkpoint_path = './output/ACR_CD_model.pth'
model_worker.load(checkpoint_path)

In [None]:
# Generate samples for one condition label.
# NOTE(review): presumably a class index in [0, 3) given num_classes=3 in the
# UNet — confirm.
condition = 2
num_samples = 20
datas = model_worker.generate_sample(num_samples, condition, add_noise=True)

In [None]:
# Save the generated samples under a per-condition output directory
# (presumably as CSV files, per the helper's name).

from utils.CSVUtils import save_samples_as_csv

sample_dir = f"./output/data/{condition}"
save_samples_as_csv(datas, output_dir=sample_dir)