In [None]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from utils.DatasetUtils import DatasetUtils

data_utils = DatasetUtils()

# Preprocessing: convert images to tensors, then normalise the single channel
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Load (downloading if necessary) the MNIST training split, then the test split
# that serves as the evaluation set; both share the same transform.
train_dataset, eval_dataset = (
    MNIST(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Report the number of samples in each split.
print(f"训练集的数据个数: {len(train_dataset)}")
print(f"验证集的数据个数: {len(eval_dataset)}")

In [None]:
# 模型组网

# from model.base.UNet import UNet
# from model.Diffusion.UNet import UNet
from model.Diffusion.UNet_v2 import UNet

from model.BetaScheduler import LinearBetaScheduler
from model.CD_Model import CD_Model
from utils.ModelWorker.CDModelWorker import CDModelWorker
import torch
from torchkeras import summary

input_shape = (1, 28, 28)

# Assemble the model: a UNet wrapped, together with a linear beta scheduler,
# inside CD_Model.
model = CD_Model(
    UNet(input_shape=input_shape, init_features=32, num_classes=10, embed_dim=128, num_heads=1),
    LinearBetaScheduler(),
)
model_worker = CDModelWorker(model)

print(f"{input_shape=}")

# Dummy timestep and condition inputs (one zero each, int64) that are forwarded
# to the summary call alongside the input shape.
time = torch.zeros(1, dtype=torch.long)
condition = torch.zeros(1, dtype=torch.long)
model_info = summary(model, input_shape=input_shape, time=time, condition=condition)

In [None]:
# 模型准备
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
from model.MinSNRLoss import MinSNRLoss

# Batches of 64 samples; only the training loader shuffles its data.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=64, shuffle=False)

# Training criterion. Alternative kept for reference: nn.MSELoss()
loss = MinSNRLoss()

# AdamW optimiser with an exponentially decaying learning rate (factor 0.98
# per scheduler step).
optimizer = optim.AdamW(params=model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)

In [None]:
from torch.profiler import profile, schedule, tensorboard_trace_handler, ProfilerActivity

def get_profiler(
    log_dir: str = "./logs",
    use_cuda: bool = True,
    wait_steps: int = 1,
    warmup_steps: int = 1,
    active_steps: int = 3,
    record_shapes: bool = True,
    profile_memory: bool = True,
    with_stack: bool = True
) -> profile:
    """
    Build a configured (not yet started) PyTorch profiler.

    Args:
        log_dir (str): Directory handed to the TensorBoard trace handler.
        use_cuda (bool): Request CUDA activity profiling; it is only enabled
            when CUDA is actually available.
        wait_steps (int): Number of steps the schedule skips before warm-up.
        warmup_steps (int): Number of warm-up steps in the schedule.
        active_steps (int): Number of actively recorded steps in the schedule.
        record_shapes (bool): Forwarded to the profiler's ``record_shapes``.
        profile_memory (bool): Forwarded to the profiler's ``profile_memory``.
        with_stack (bool): Forwarded to the profiler's ``with_stack``.

    Returns:
        torch.profiler.profile: The configured profiler instance.
    """
    # CPU activity is always traced; CUDA is added only if requested AND present.
    cuda_enabled = use_cuda and torch.cuda.is_available()
    activities = (
        [ProfilerActivity.CPU, ProfilerActivity.CUDA]
        if cuda_enabled
        else [ProfilerActivity.CPU]
    )

    step_schedule = schedule(
        wait=wait_steps, warmup=warmup_steps, active=active_steps
    )
    trace_handler = tensorboard_trace_handler(log_dir)

    return profile(
        activities=activities,
        schedule=step_schedule,
        on_trace_ready=trace_handler,
        record_shapes=record_shapes,
        profile_memory=profile_memory,
        with_stack=with_stack,
    )

In [None]:
def take_profile(
    model,
    profiler,
    optimizer,
    criterion,
    train_loader: DataLoader,
    steps=5,
):
    """
    Run a few training steps under ``profiler`` and advance it once per step.

    The profiler is entered as a context manager for the duration of the loop,
    so it is started before the first step and stopped afterwards even when
    the loop ends early (fewer batches than ``steps``, or an exception). Pass
    a profiler that has NOT been started yet.

    Args:
        model: Module exposing ``timesteps``, ``forward_process(x, time)``
            (returning the noised input and the noise) and
            ``scheduler.get_alpha_bar(time)``; it is called as
            ``model(x=..., time=..., condition=...)``.
        profiler: Context-manager profiler with a ``step()`` method, e.g. a
            ``torch.profiler.profile`` instance.
        optimizer: Optimizer updating the model's parameters.
        criterion: Loss callable. An ``nn.MSELoss`` is called as
            ``criterion(pred_noise, noise)``; any other criterion is called as
            ``criterion(pred_noise, noise, alpha_bar_t)``.
        train_loader (DataLoader): Iterable yielding ``(datas, labels)`` batches.
        steps (int): Maximum number of training steps to run. Values ``<= 0``
            run no steps at all.
    """
    # Put the model in training mode on the best available device.
    model.train()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Only nn.MSELoss uses the two-argument call; other criteria also get alpha_bar_t.
    uses_plain_mse = isinstance(criterion, nn.MSELoss)

    # The context manager guarantees start()/stop() around the loop, so a trace
    # that is still being recorded when the loop ends is flushed.
    with profiler:
        # zip() with range() first stops after `steps` batches without
        # fetching an extra batch from the loader.
        for _, (datas, labels) in zip(range(steps), train_loader):
            datas, labels = datas.to(device), labels.to(device)
            batch_size = datas.size(0)

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()

            # Draw one random timestep per sample from [1, model.timesteps].
            # NOTE(review): assumes the model/scheduler use 1-based timesteps — confirm.
            time = torch.randint(
                1, model.timesteps + 1, (batch_size,), device=device
            ).long()

            # Forward (noising) process.
            xt, noise = model.forward_process(datas, time)

            # Reverse process: predict the noise from the noised input.
            pred_noise = model(x=xt, time=time, condition=labels)

            # Compute the loss.
            if uses_plain_mse:
                loss = criterion(pred_noise, noise)
            else:
                alpha_bar_t = model.scheduler.get_alpha_bar(time)
                loss = criterion(pred_noise, noise, alpha_bar_t)

            # Backpropagate and update the parameters.
            loss.backward()
            optimizer.step()

            # Tell the profiler that one training step has finished.
            profiler.step()

In [None]:
# Build a profiler with the default settings, then profile five training steps.
prof = get_profiler()

take_profile(
    model=model,
    profiler=prof,
    optimizer=optimizer,
    criterion=loss,
    train_loader=train_loader,
    steps=5,
)