In [None]:
"""
依赖自检
"""

import sys
import os
from pathlib import Path

project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / 'src'))

env_path = project_root / '.env'
if env_path.exists():
    for line in env_path.read_text().splitlines():
        line = line.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, value = line.split('=', 1)
        value = value.strip().strip("'").strip('\"')
        os.environ.setdefault(key.strip(), value)


from diffusion.env import ensure_dependencies

# Project helper from diffusion.env, called before the third-party imports
# below -- presumably so missing packages are reported/installed first.
# NOTE(review): confirm the exact behavior in diffusion/env.py.
ensure_dependencies()

import torch
from torch.nn import functional as f
from torchvision import transforms
from datasets import load_dataset


In [None]:
"""
Device 自检
"""

from diffusion.env import select_device

device = select_device(torch)


In [None]:
"""
加载数据集
"""
from diffusion.data import create_dataloader
from diffusion.hf import login_hf

login_hf()

dataset = load_dataset('huggan/smithsonian_butterflies_subset', split='train')

image_size = 32
batch_size = 64

# 图像预处理
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # 统一图像大小（宽x高）
        transforms.RandomHorizontalFlip(),            # 随机水平翻转图像
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),           # 归一化到 [-1, 1]
    ]
)

def transform(examples):
    """Batch transform registered via ``dataset.set_transform``.

    Converts each entry of ``examples['image']`` (expected to be PIL images) to
    RGB, runs it through the module-level ``preprocess`` pipeline, and returns
    the results as a list under the ``'images'`` key.
    """
    rgb_images = (image.convert("RGB") for image in examples['image'])
    return {'images': list(map(preprocess, rgb_images))}

# Register `transform` lazily: it is applied whenever samples are read from the
# dataset, not eagerly over the whole dataset.
dataset.set_transform(transform)

train_dataloader = create_dataloader(dataset, batch_size=batch_size)


In [None]:
"""
可视化图像数据
"""

from PIL import Image
from diffusion.visualize import show_images

xbatch = next(iter(train_dataloader))['images'].to(device)[:8]
print(f'批量图像张量形状: {xbatch.shape}')  # torch.Size([8, 3, 32, 32])
show_images(xbatch).resize((8 * 64, 64), resample=Image.NEAREST)


In [None]:
"""
为图像添加噪声
"""

from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000
)

timesteps = torch.linspace(0, 999, 8).long().to(device)  # 8 个时间步
noise = torch.randn_like(xbatch)  # 生成随机噪声
noisy_xbatch = noise_scheduler.add_noise(xbatch, noise, timesteps)  # 添加噪声
print(f'添加噪声后的图像张量形状: {noisy_xbatch.shape}')  # torch.Size([8, 3, 32, 32])
show_images(noisy_xbatch).resize((8 * 64, 64), resample=Image.NEAREST)

In [None]:
"""
创建扩散模型
"""

from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=image_size,        # 输入图像的大小（宽和高）
    in_channels=3,                 # 输入图像的通道数（RGB 图像为 3）
    out_channels=3,                # 输出图像的通道数
    layers_per_block=2,            # 每个块中的层 ResNet 层数
    block_out_channels=(64, 128, 128, 256),  # 每个块的输出通道数
    down_block_types=(             # 下采样块类型
        "DownBlock2D", 
        "DownBlock2D", 
        "AttnDownBlock2D", 
        "AttnDownBlock2D"
    ),
    up_block_types=(               # 上采样块类型
        "AttnUpBlock2D", 
        "AttnUpBlock2D", 
        "UpBlock2D", 
        "UpBlock2D"
    )
).to(device)


In [None]:
"""
创建训练循环
"""

import numpy as np
from matplotlib import pyplot as plt

noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_schedule="squaredcos_cap_v2"
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

losses = []

num_epochs = 50

for epoch in range(num_epochs):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch['images'].to(device)
        # 1. 生成噪声
        noise = torch.randn(clean_images.shape).to(clean_images.device)
        bsz = clean_images.shape[0]
        # 2. 为每张图像随机选择时间步
        timesteps = torch.randint(
            0,                                   # 最小时间步
            noise_scheduler.num_train_timesteps, # 最大时间步
            (bsz,),                              # 生成 bsz 个时间步
            device=clean_images.device
        ).long()
        # 3. 根据每个时间步的噪声大小，添加噪声
        noisy_images = noise_scheduler.add_noise(
            clean_images, 
            noise, 
            timesteps
        )
        # 4. 预测噪声
        noise_pred = model(
            noisy_images,
            timesteps,
            return_dict=False
        )[0]
        # 5. 计算损失
        loss = f.mse_loss(noise_pred, noise)
        loss.backward(loss)
        losses.append(loss.item())
        # 6. 优化模型参数
        optimizer.step()
        optimizer.zero_grad()
    # 每 5 个周期打印一次损失
    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
        print(f'Epoch {epoch+1}, Loss: {loss_last_epoch:.4f}')

# 绘制损失曲线
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))
plt.show()

In [None]:
"""
使用模型生成图像
"""

from diffusers import DDPMPipeline

# 创建图像生成管线
image_pipeline = DDPMPipeline(
    unet=model,
    scheduler=noise_scheduler
)

# 生成图像
pipeline_output = image_pipeline()
pipeline_output.images[0].resize((64, 64), resample=Image.NEAREST)



In [None]:
"""
保存模型和管线
"""

image_pipeline.save_pretrained("../models/generate_butterflies")