In [2]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from utils.DatasetUtils import DatasetUtils 

# Project helper for dataset introspection (later used to query the per-sample shape).
data_utils = DatasetUtils()

# Load the MNIST dataset.
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Both splits share the same root, download flag and preprocessing.
mnist_common = dict(root='./data', download=True, transform=transform)
train_dataset = MNIST(train=True, **mnist_common)
eval_dataset = MNIST(train=False, **mnist_common)

print(
    f"训练集的数据个数: {len(train_dataset)}",
    f"验证集的数据个数: {len(eval_dataset)}",
    sep="\n",
)

训练集的数据个数: 60000
验证集的数据个数: 10000


In [3]:
# 模型组网

from model.BetaScheduler import LinearBetaScheduler
from model.UNet import UNet
from model.CD_Model import CD_Model

import torch

# Seed the global torch RNG before any random work (weight init below, DataLoader
# shuffling, timestep sampling) so that Restart & Run All reproduces the same results.
SEED = 42
torch.manual_seed(SEED)

# Per-sample input shape as reported by the project dataset helper.
input_shape = data_utils.get_data_shape(train_dataset)

# Derive the number of conditioning classes from the dataset instead of hard-coding 10.
num_classes = len(train_dataset.classes)

# NOTE(review): CD_Model presumably wraps a class-conditioned UNet denoiser with a
# linear beta (noise) schedule; confirm against model/CD_Model.py.
model = CD_Model(
    UNet(
        input_shape=input_shape,
        num_classes=num_classes,
        init_features=64,
        embed_dim=128,
    ),
    LinearBetaScheduler(),
)

print(f"{input_shape=}")

input_shape=(1, 28, 28)


In [4]:
# 模型准备
from torch.utils.data import DataLoader
from torch import nn
from torch import optim

# Training hyperparameters, kept as named constants so they are easy to find and tune.
BATCH_SIZE = 64
LEARNING_RATE = 1e-3

# Only the training split is shuffled; evaluation keeps the dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

# `loss` holds the criterion (an nn.Module), not a loss value — do not rebind this name
# to a tensor, or the criterion is lost for subsequent cells.
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [5]:
import torch.profiler
# Profile a few forward/backward passes and write traces for TensorBoard (./logs).

# Profiler schedule: each cycle is `wait` idle steps, `warmup` traced-but-discarded
# steps, then `active` recorded steps; the cycle runs `repeat` times.
PROFILE_WAIT = 1
PROFILE_WARMUP = 1
PROFILE_ACTIVE = 3
PROFILE_REPEAT = 2
# Run exactly enough steps to complete every scheduled cycle.
PROFILE_STEPS = (PROFILE_WAIT + PROFILE_WARMUP + PROFILE_ACTIVE) * PROFILE_REPEAT

# Always trace CPU; request CUDA tracing only when a GPU is actually available.
profile_activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    profile_activities.append(torch.profiler.ProfilerActivity.CUDA)

with torch.profiler.profile(
    activities=profile_activities,
    schedule=torch.profiler.schedule(
        wait=PROFILE_WAIT,
        warmup=PROFILE_WARMUP,
        active=PROFILE_ACTIVE,
        repeat=PROFILE_REPEAT,
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"),  # TensorBoard trace files
    record_shapes=True,  # record input tensor shapes
    profile_memory=True,  # track tensor memory allocations
    with_stack=True,  # record source call stacks for each op
) as profiler:
    batch_iter = iter(train_loader)
    for _ in range(PROFILE_STEPS):
        datas, labels = next(batch_iter)
        batch_size = datas.shape[0]
        # Random timesteps in [1, model.timesteps) (upper bound exclusive); int64 by default.
        # NOTE(review): confirm this range matches CD_Model's timestep convention.
        t = torch.randint(1, model.timesteps, (batch_size,))

        # Clear gradients so each profiled step is an independent forward/backward
        # rather than an accumulation across steps.
        optimizer.zero_grad()
        outputs = model(datas, t, labels)
        # Surrogate scalar used only to drive backward(); kept under its own name so the
        # `loss` criterion (nn.MSELoss) is not overwritten.
        profile_loss = outputs.sum()
        profile_loss.backward()

        profiler.step()  # advance the profiler schedule

# Do not leave profiling gradients on the parameters for subsequent training.
optimizer.zero_grad()