In [None]:
"""
依赖自检
"""

import os
import sys
import importlib.util

# Set before torch is imported for the first time: the version probe loop below
# already imports torch via __import__, so this must come before that loop.
# NOTE(review): per PyTorch's MPS environment-variable docs this makes matmul prefer
# Metal kernels over MPSGraph — confirm for the installed PyTorch version.
os.environ["PYTORCH_MPS_PREFER_METAL"] = "1"

print(f"{'='*10} 执行依赖自检 {'='*10}")

# (importable module name, human-readable display name) for each required package.
required_packages = [
    ("torch", "PyTorch"),
    ("torchvision", "TorchVision"),
    ("diffusers", "Diffusers"),
    ("matplotlib", "Matplotlib")
]

missing_packages = []

for package_name, display_name in required_packages:
    if importlib.util.find_spec(package_name) is None:
        missing_packages.append(package_name)
    else:
        # Import the top-level module to read its __version__ (fallback: '未知版本').
        module = __import__(package_name)
        version = getattr(module, '__version__', '未知版本')
        print(f"{display_name}: {version}")

if missing_packages:
    print(f"\n【ERROR】缺少以下依赖包: {', '.join(missing_packages)}")
    sys.exit(1)
else:
    try:
        import torch
        import torchvision
        from torch import nn
        from torch.utils.data import DataLoader
        from diffusers import UNet2DModel
        from matplotlib import pyplot as plt
    except Exception as e:
        # Report the underlying error so the failing import can actually be diagnosed.
        print(f"【ERROR】导入发生未知错误: {e!r}")
        sys.exit(1)

print(f"{'='*10} 依赖自检完成 {'='*10}")


# "high" allows float32 matmuls to use faster, lower-precision internal math
# (e.g. TF32) where the backend supports it: a speed-for-precision trade, not a
# precision gain. NOTE(review): documented mainly for CUDA; effect on MPS unverified.
torch.set_float32_matmul_precision("high")


In [None]:
"""
Device 自检
"""

# Choose a backend in priority order (Apple MPS, then NVIDIA CUDA, then CPU) and
# smoke-test it by copying a tiny tensor onto it. Any failure falls back to CPU.
try:
    if torch.backends.mps.is_available():
        device_name = "mps"
        # The log line is only emitted when the PyTorch build reports MPS support.
        banner = "【INFO】Use Apple Silicon (MPS)" if torch.backends.mps.is_built() else None
    elif torch.cuda.is_available():
        device_name, banner = "cuda", "【INFO】Use NVIDIA"
    else:
        device_name, banner = "cpu", "【INFO】Use CPU"

    if banner is not None:
        print(banner)

    device = torch.device(device_name)
    # Allocate on the CPU and move to the chosen device; raises if the device is unusable.
    x = torch.ones(1).to(device)
    print(f"{device} 可以使用")

except Exception as e:
    print(f"【ERROR】{e} 设备无法使用")
    device = torch.device("cpu")

print("========== 硬件自检完成 ==========")


In [None]:
"""
数据集测试
"""

# MNIST training split; ToTensor converts each image to a tensor.
dataset = torchvision.datasets.MNIST(
    root="../data/datasets",
    train=True,                                   # training split (False -> test split)
    download=True,                                # download the dataset if missing
    transform=torchvision.transforms.ToTensor(),  # image -> tensor
)

batch_size = 16

# Small shuffled loader, used here only to preview one batch.
dataset_loader = DataLoader(
    dataset,                  # dataset object to load from
    batch_size=batch_size,    # samples per iteration
    shuffle=True,             # shuffle sample order
    num_workers=0,            # load in the main process
    persistent_workers=False,
)

# Take the first batch from the loader.
x, y = next(iter(dataset_loader))
print('Input shape:', x.shape)
print('Labels :', y)
# make_grid tiles the batch into one image; channel 0 is shown in grayscale.
# Trailing ';' suppresses the AxesImage repr in the cell output.
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='gray');

In [None]:
"""
添加噪声，并对输出结果进行可视化
"""

def corrupt(x, amount):
    """
    Blend ``x`` with uniform noise: ``x * (1 - amount) + noise * amount``.

    Args:
        x: Batch tensor whose first dimension is the batch (e.g. (B, C, H, W));
            any rank is accepted.
        amount: Per-sample noise level of shape (B,), or a scalar / 0-d tensor
            applied to every sample. 0 keeps a sample unchanged, 1 replaces it
            with pure noise.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Accept Python scalars too, and align the device with x so the blend can broadcast.
    amount = torch.as_tensor(amount, device=x.device)
    # Reshape to (B, 1, ..., 1) so each sample's level broadcasts over all of its
    # remaining dims whatever x's rank; reshape (unlike view) also handles
    # non-contiguous inputs.
    amount = amount.reshape(-1, *([1] * (x.dim() - 1)))
    noise = torch.rand_like(x)  # uniform noise in [0, 1), same shape/dtype/device as x
    return x * (1 - amount) + noise * amount

# One noise level per image, evenly spaced from 0 (clean) to 1 (pure noise).
amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)

# Two stacked rows: the clean batch on top, the progressively corrupted batch below.
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 5))
plt.subplots_adjust(hspace=0.4)  # extra vertical room between the titled rows

panels = [
    ('Input Images', x),
    ('Corrupted Images (--- amount increases --->)', noised_x),
]
for panel_ax, (panel_title, panel_batch) in zip(axs, panels):
    panel_ax.set_title(panel_title)
    # make_grid tiles the batch into one image; channel 0 is displayed in grayscale.
    panel_ax.imshow(torchvision.utils.make_grid(panel_batch)[0], cmap='gray')

In [None]:
"""
UNet2DModel 配置
"""

# Small UNet for 28x28 single-channel images: three resolution levels, with
# attention blocks in the two deeper down/up levels.
unet_config = dict(
    sample_size=28,                   # input image size
    in_channels=1,                    # input image channels
    out_channels=1,                   # output image channels
    layers_per_block=2,               # layers in each block
    block_out_channels=[32, 64, 64],  # output channels of each block
    down_block_types=["DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"],
    up_block_types=["AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"],
)

# Size of the integer timestep range; later cells map a noise amount in [0, 1]
# to a timestep via amount * (num_train_timesteps - 1).
num_train_timesteps = 1000
net = UNet2DModel(**unet_config).to(device)


In [None]:
"""
验证输入和输出的形状是否相同，并查看 UNet2DModel 的参数量
"""

# Dummy batch of 8 single-channel 28x28 images, all at timestep 0.
x = torch.rand(8, 1, 28, 28).to(device)
timesteps = torch.zeros(x.shape[0], dtype=torch.long, device=device)
with torch.no_grad():  # shape check only: no need to build an autograd graph
    y = net(x, timesteps).sample.shape
print(y)
# Actually verify what this cell is for: the UNet must preserve the input shape.
assert y == x.shape, f"UNet output shape {tuple(y)} != input shape {tuple(x.shape)}"

# Total parameter count (displayed as the cell's output).
sum(p.numel() for p in net.parameters())


In [None]:
"""
开始训练模型
"""

batch_size = 64

# Training loader (larger batches than the 16-sample preview loader).
dataset_loader = DataLoader(
    dataset,                  # dataset object to load from
    batch_size=batch_size,    # samples per iteration
    shuffle=True,             # shuffle sample order
    num_workers=0,
    persistent_workers=False,
)

# Number of passes over the dataset.
num_epochs = 3

# Reuses `net`, the UNet2DModel created above.

# Mean-squared error between the prediction and the clean image.
loss_fn = nn.MSELoss()

# Optimizer that updates the network weights from the loss gradients.
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)

# Per-batch training loss.
loss_history = []

net.train()  # make sure training mode is on even if an earlier cell switched to eval
for epoch in range(num_epochs):
    for x, _ in dataset_loader:
        x = x.to(device)
        # One random noise level in [0, 1) per sample.
        noise_amount = torch.rand(x.shape[0], device=device)
        noisy_x = corrupt(x, noise_amount)

        # Condition on the noise level with the same amount -> timestep mapping
        # used at inference (amount * (num_train_timesteps - 1)); otherwise the
        # model is later queried at timesteps it never saw during training.
        timesteps = (noise_amount * (num_train_timesteps - 1)).long()
        # The network predicts the clean image (not the noise).
        predicted_image = net(noisy_x, timesteps).sample

        loss = loss_fn(predicted_image, x)  # compare prediction to the original image

        optimizer.zero_grad()  # clear old gradients
        loss.backward()        # backpropagate
        optimizer.step()       # update weights
        loss_history.append(loss.item())

    # Each epoch appends exactly len(dataset_loader) losses, so this is the epoch mean.
    avg_loss = sum(loss_history[-len(dataset_loader):]) / len(dataset_loader)
    print(f"Finished epoch {epoch}. Average loss: {avg_loss:05f}")

# Loss curve; trailing ';' suppresses the repr in the cell output.
plt.plot(loss_history)
plt.ylim(0, 0.1);


In [None]:
"""
观察训练结果
"""

x, _ = next(iter(dataset_loader))  # one batch from the training loader
x = x[:8]  # keep the first 8 samples

# Noise amounts evenly spaced from 0 (clean) to 1 (pure noise), one per sample.
amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)

# Inference: eval mode plus no gradient tracking.
# (Call net.train() again before any further training.)
net.eval()
with torch.no_grad():
    timesteps = (amount.to(device) * (num_train_timesteps - 1)).long()
    # Move predictions back to the CPU so matplotlib/NumPy can plot them.
    predicted_image = net(noised_x.to(device), timesteps).sample.cpu()

# Rows: clean inputs, corrupted inputs, model reconstructions.
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input Images')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='gray')
axs[1].set_title('Corrupted Images (--- amount increases --->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='gray')
axs[2].set_title('Predicted Output')
axs[2].imshow(torchvision.utils.make_grid(predicted_image)[0].clip(0, 1), cmap='gray');


In [None]:
"""
拆解采样步骤
"""

step = 5  # number of sampling steps
x = torch.rand(8, 1, 28, 28).to(device)  # start from uniform random noise
step_history = [x.detach().cpu()]  # model input at each step (plus the final result)
predicted_output = []  # model prediction at each step

net.eval()  # inference only; call net.train() again before any further training
with torch.no_grad():
    for i in range(step):
        # Noise level assumed for x: 1 at the first step, falling linearly to 0 at
        # the last. A single-step run uses level 1 (avoids dividing by step - 1 == 0).
        scale = 1 - (i / (step - 1)) if step > 1 else 1.0
        timestep = int(scale * (num_train_timesteps - 1))
        timesteps = torch.full((x.shape[0],), timestep, dtype=torch.long, device=device)
        predicted_image = net(x, timesteps).sample
        predicted_output.append(predicted_image.detach().cpu())

        # Move 1/(remaining steps) of the way toward the prediction; on the last
        # step mix_factor == 1, so x lands exactly on the prediction.
        mix_factor = 1 / (step - i)
        x = x * (1 - mix_factor) + predicted_image * mix_factor
        step_history.append(x.detach().cpu())

# One row per step: model input (left) and its prediction (right).
# squeeze=False keeps axs 2-D even when step == 1.
fig, axs = plt.subplots(step, 2, figsize=(9, 4), sharex=True, squeeze=False)
axs[0, 0].set_title('Input Image')
axs[0, 1].set_title('Predicted Output')
for i in range(step):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='gray')
    axs[i, 1].imshow(torchvision.utils.make_grid(predicted_output[i])[0].clip(0, 1), cmap='gray')
