### 人类活动识别——条件扩散

In [None]:
# Data preprocessing
from torchvision import transforms
from model.Normalization import RobustNorm

# RobustNorm is constructed with the bounds (-68.0, 68.0).
# NOTE(review): presumably the expected range of the raw signal — confirm
# against model.Normalization.RobustNorm.
normalizer = RobustNorm(-68.0, 68.0)

# Single-step pipeline; handed to both datasets as their `transform`.
transform = transforms.Compose([normalizer])

In [None]:
# Load the datasets
import os

from model.RFID_Dataset import RFID_Dataset

# Build the paths with os.path.join instead of raw strings with hard-coded
# "\" separators, so the notebook also runs on Linux/macOS. On Windows the
# resulting strings are identical to the original ones.
dataset_root = os.path.join("data", "RFID_multi_628", "dataset")
train_dir = os.path.join(dataset_root, "train")
eval_dir = os.path.join(dataset_root, "eval")

# Both splits are built with exactly the same configuration.
# NOTE(review): T=32 and num_channels=3 match the first two entries of the
# model's input_shape (3, 32, 12) — presumably they must stay in sync.
dataset_kwargs = dict(
    T=32,
    step=1,
    num_channels=3,
    transform=transform,
)

train_dataset = RFID_Dataset(train_dir, **dataset_kwargs)
eval_dataset = RFID_Dataset(eval_dir, **dataset_kwargs)

print(f"训练集的数据个数: {len(train_dataset)}")
print(f"验证集的数据个数: {len(eval_dataset)}")

In [None]:
# Model construction

from model.base.UNet import UNet
# Alternative UNet variants; swap the import above for one of these to try it:
# from model.v1.UNet import UNet
# from model.v2.UNet import UNet
# from model.v3.UNet import UNet
# from model.v4.UNet import UNet

from model.BetaScheduler import LinearBetaScheduler
from model.CD_Model import CD_Model
from model.ModelWorker.CDModelWorker import CDModelWorker
import torch
from torchkeras import summary

input_shape = (3, 32, 12)

# Noise-prediction backbone.
backbone = UNet(
    input_shape=input_shape,
    init_features=64,
    embed_dim=128,
    num_heads=1,
    num_groups=16,
)

# Linear beta schedule over 1000 timesteps.
beta_scheduler = LinearBetaScheduler(timesteps=1000, beta_end=0.02)

# Conditional diffusion model: 6 classes, guidance enabled. embed_dim is
# passed as 128 here and to the UNet above.
model = CD_Model(
    backbone,
    beta_scheduler,
    num_classes=6,
    embed_dim=128,
    enable_guidance=True,
)

# Worker wrapper used later for evaluation, save/load and sampling.
model_worker = CDModelWorker(model)

print(f"{input_shape=}")

# Dummy timestep and class label handed to summary() as keyword arguments
# (presumably forwarded to the model's forward pass — confirm with torchkeras).
time = torch.tensor([0], dtype=torch.long)
condition = torch.tensor([0], dtype=torch.long)
model_info = summary(
    model,
    input_shape=input_shape,
    time=time,
    condition=condition,
)

In [None]:
# Model preparation: data loaders, loss and the Lightning wrapper
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
# NOTE(review): the star import provides MinSNRLoss below and presumably also
# ConstantLoss, which a later evaluation cell uses without importing it —
# keep it (or replace it with explicit names everywhere at once).
from model.Loss import *
from model.LightningModel.CDPLMpdel import CDPLModel

# Training batches are shuffled; evaluation batches keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=16)

# Training criterion; the commented lines are alternatives that were tried.
# loss = nn.MSELoss()
loss = MinSNRLoss()
# loss=SigmoidLoss()

# Lightning module built from the diffusion model and the chosen criterion.
pl_model = CDPLModel(
    model,
    loss
)

In [None]:
# Build the PyTorch Lightning Trainer
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint,EarlyStopping

trainer = pl.Trainer(
    max_epochs=100,
    min_epochs=5,
    logger=True,
    callbacks=[
        # Weights-only checkpoints. No `monitor` is set, so per the Lightning
        # docs only the last epoch's checkpoint is kept and `best_model_path`
        # points at it rather than at the best-scoring epoch.
        # NOTE(review): consider monitor="val/loss" once the logged metric
        # name is confirmed.
        ModelCheckpoint(save_weights_only=True),
        # Early stopping is currently disabled; candidate configurations:
        # EarlyStopping(monitor="val/loss", patience=5),
        # EarlyStopping(monitor="val/loss", patience=5,mode="min"),
    ],
    # Root directory for logs and checkpoints.
    default_root_dir="./output/HAR",
)

In [None]:
# Model training

# Fit on train_loader; eval_loader is passed as the validation dataloader.
trainer.fit(
    pl_model,
    train_loader,
    eval_loader,
)
# Checkpoint path recorded by the ModelCheckpoint callback; the checkpoint
# loading cell further down reads this variable.
best_model_path=trainer.checkpoint_callback.best_model_path
print(best_model_path)

In [None]:
# Model evaluation

# Runs a validation pass over eval_loader with the current in-memory pl_model
# (the checkpoint is only reloaded in a later cell).
trainer.validate(
    pl_model,
    eval_loader,
)

In [None]:
# Load the Lightning module from the checkpoint

# `model` and `criterion` are passed explicitly as constructor arguments
# (presumably because they are not stored in the checkpoint's hyperparameters).
pl_model = CDPLModel.load_from_checkpoint(best_model_path, model=model, criterion=loss)

In [None]:
# Model evaluation over a sequence of timesteps
import torch
from torch import nn

# 11 evenly spaced integers from 0 to 1000 with the leading 0 dropped,
# i.e. the 10 timesteps 100, 200, ..., 1000.
# NOTE(review): the scheduler is created with timesteps=1000; if timesteps
# are 0-indexed, 1000 is one past the last valid index — confirm.
sequence=torch.linspace(0, 1000,10+1,dtype=torch.long).tolist()[1:]

# Criterion used only for this evaluation. ConstantLoss is not imported in
# this cell; presumably it comes from the earlier `from model.Loss import *`.
# eval_loss=nn.MSELoss()
eval_loss=ConstantLoss()
# NOTE(review): the training loader is passed as `eval_loader` here — confirm
# this is intentional and not meant to be the `eval_loader` built earlier.
loss_group=model_worker.evaluate_sequence(
    eval_loader=train_loader,
    criterion=eval_loss,
    time=sequence,
    verbose=1
)

print(sequence)

In [None]:
from utils.DataUtils.Visualization import plot_curves,plot_scatter

# Print one "<timestep>: <loss>" line per evaluated timestep.
for timestep, value in loss_group.items():
    print(f"{timestep:4d}: {value:.6f}")

# Scatter plot of the same timestep -> loss mapping.
plot_scatter(loss_group)

In [None]:
# Save the model
import os

save_path = './output/HAR_CD/base/HAR_CD.pth'
# Create the target directory first so saving does not fail on a fresh
# checkout, in case model_worker.save() does not create it itself.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
model_worker.save(save_path)

In [None]:
# Load the model from the same path the save cell writes to
model_worker.load('./output/HAR_CD/base/HAR_CD.pth')

In [None]:
# DDPM sampling

from model.RFID_Dataset import save_samples

num_classes = 6
# Three rounds; each round makes one generate/save pass per class label.
for _ in range(3):
    for condition in range(num_classes):
        # Generate data for this class label. 20 is the first positional
        # argument (presumably the number of samples per call — confirm).
        datas = model_worker.generate_sample_batch(20, condition, guidance_scale=2)

        # Save the generated data into a per-class output directory.
        save_samples(datas, output_dir=f"./output/base/{condition}", merge=True)

In [None]:
# DDIM sampling

from model.RFID_Dataset import save_samples

num_classes = 6
for condition in range(num_classes):
    # Label tensor of 20 entries, all equal to the current class.
    cond = torch.full((20,), condition, dtype=torch.long)

    # Sampling sequence requested from the worker with argument 50
    # (presumably the number of DDIM steps — confirm).
    sampling_steps = model_worker.get_linear_sampling_sequence(50)

    # Generate data for this class label.
    datas = model_worker.generate_sample_DDIM(
        cond,
        time=sampling_steps,
        eta=0.0,
        guidance_scale=2,
    )

    # Save the generated data into a per-class output directory.
    save_samples(datas, output_dir=f"./output/base_DDIM/{condition}", merge=True)