# pytorch版本

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import logging

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# UNet 模型定义
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()
        def double_conv(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )
        
        self.down1 = double_conv(in_channels, 64)
        self.down2 = double_conv(64, 128)
        self.down3 = double_conv(128, 256)
        self.up1 = double_conv(256 + 128, 128)
        self.up2 = double_conv(128 + 64, 64)
        self.up3 = nn.Conv2d(64, out_channels, 1)
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.time_emb = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

    def forward(self, x, t):
        t_emb = self.time_emb(t.view(-1, 1)).view(-1, 64, 1, 1)
        x1 = self.down1(x)
        x2 = self.pool(x1)
        x2 = self.down2(x2)
        x3 = self.pool(x2)
        x3 = self.down3(x3)
        x = self.upsample(x3)
        x = torch.cat([x, x2], dim=1)
        x = self.up1(x)
        x = self.upsample(x)
        x = torch.cat([x, x1], dim=1)
        x = self.up2(x)
        x = self.up3(x)
        return x

# 扩散模型类
class DiffusionModel:
    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02):
        self.T = T
        self.beta = torch.linspace(beta_start, beta_end, T).to(device)
        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
    
    def forward_process(self, x0, t, noise=None):
        """正向过程：添加噪声"""
        if noise is None:
            noise = torch.randn_like(x0).to(device)
        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t]).view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar[t]).view(-1, 1, 1, 1)
        xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise
        return xt, noise
    
    def sample_step(self, model, xt, t):
        """反向过程单步"""
        model.eval()
        with torch.no_grad():
            beta_t = self.beta[t].view(-1, 1, 1, 1)
            alpha_t = self.alpha[t].view(-1, 1, 1, 1)
            alpha_bar_t = self.alpha_bar[t].view(-1, 1, 1, 1)
            noise_pred = model(xt, t / self.T)
            x_prev = (xt - beta_t / torch.sqrt(1 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_t)
            if t > 0:
                x_prev += torch.sqrt(beta_t) * torch.randn_like(xt)
        return x_prev

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

# 训练设置
model = UNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-4)
diffusion = DiffusionModel(T=1000)
num_epochs = 10
os.makedirs('samples', exist_ok=True)
os.makedirs('logs', exist_ok=True)

# 日志设置
logging.basicConfig(filename='logs/training.log', level=logging.INFO, 
                    format='%(asctime)s - %(message)s')
losses = []

# 训练循环
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for batch_idx, (images, _) in enumerate(tqdm(trainloader)):
        images = images.to(device)
        optimizer.zero_grad()
        
        # 随机时间步
        t = torch.randint(0, diffusion.T, (images.size(0),), device=device).float()
        xt, noise = diffusion.forward_process(images, t.long())
        
        # 预测噪声
        noise_pred = model(xt, t / diffusion.T)
        loss = nn.MSELoss()(noise_pred, noise)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(trainloader)
    losses.append(avg_loss)
    logging.info(f'Epoch {epoch+1}, Loss: {avg_loss:.6f}')
    print(f'Epoch {epoch+1}, Loss: {avg_loss:.6f}')
    
    # 每5个 epoch 保存采样图像
    if (epoch + 1) % 5 == 0:
        model.eval()
        sample = torch.randn(16, 3, 32, 32).to(device)
        for t in reversed(range(diffusion.T)):
            sample = diffusion.sample_step(model, sample, torch.tensor([t]).to(device))
        sample = (sample.clamp(-1, 1) + 1) / 2
        grid = torchvision.utils.make_grid(sample, nrow=4)
        plt.figure(figsize=(8, 8))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.savefig(f'samples/epoch_{epoch+1}.png')
        plt.close()

# 绘制 loss 曲线
plt.figure(figsize=(10, 5))
plt.plot(losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.legend()
plt.savefig('logs/loss_curve.png')
plt.close()

# 保存模型
torch.save(model.state_dict(), 'unet_diffusion_cifar10.pth')

Files already downloaded and verified


100%|██████████| 391/391 [00:08<00:00, 45.25it/s]


Epoch 1, Loss: 0.124706


100%|██████████| 391/391 [00:08<00:00, 44.69it/s]


Epoch 2, Loss: 0.062822


100%|██████████| 391/391 [00:08<00:00, 45.64it/s]


Epoch 3, Loss: 0.055694


100%|██████████| 391/391 [00:08<00:00, 44.75it/s]


Epoch 4, Loss: 0.052523


100%|██████████| 391/391 [00:08<00:00, 44.76it/s]


Epoch 5, Loss: 0.049559


100%|██████████| 391/391 [00:08<00:00, 45.15it/s]


Epoch 6, Loss: 0.047112


100%|██████████| 391/391 [00:08<00:00, 44.64it/s]


Epoch 7, Loss: 0.045910


100%|██████████| 391/391 [00:08<00:00, 44.38it/s]


Epoch 8, Loss: 0.044645


100%|██████████| 391/391 [00:08<00:00, 45.18it/s]


Epoch 9, Loss: 0.043990


100%|██████████| 391/391 [00:08<00:00, 45.20it/s]


Epoch 10, Loss: 0.043917


# jittor版本

In [1]:
import jittor as jt
from jittor import nn, optim, dataset
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import logging
import jittor.transform as transforms
# 设置设备
jt.flags.use_cuda = 1  # jt.has_cuda
jt.flags.amp_reg = 0  # 禁用自动混合精度（若未使用AMP）

# UNet 模型定义
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()

        def double_conv(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm(out_ch),
                nn.ReLU(),  # CBL
                nn.Conv(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm(out_ch),
                nn.ReLU()
            )

        self.down1 = double_conv(in_channels, 64)
        self.down2 = double_conv(64, 128)
        self.down3 = double_conv(128, 256)
        self.up1 = double_conv(256 + 128, 128)
        self.up2 = double_conv(128 + 64, 64)
        self.up3 = nn.Conv(64, out_channels, 1)
        self.pool = nn.Pool(2, op='maximum')
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        self.time_emb = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

    def execute(self, x, t):
        t_emb = self.time_emb(t.reshape(-1, 1)).reshape(-1, 64, 1, 1)
        x1 = self.down1(x)
        x2 = self.pool(x1)
        x2 = self.down2(x2)
        x3 = self.pool(x2)
        x3 = self.down3(x3)
        x = self.upsample(x3)
        x = jt.concat([x, x2], dim=1)
        x = self.up1(x)
        x = self.upsample(x)
        x = jt.concat([x, x1], dim=1)
        x = self.up2(x)
        x = self.up3(x)
        return x


# 扩散模型类
class DiffusionModel:
    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02):
        self.T = T
        self.beta = jt.linspace(beta_start, beta_end, T)
        self.alpha = 1.0 - self.beta
        self.alpha_bar = jt.cumprod(self.alpha, dim=0)

    def forward_process(self, x0, t, noise=None):
        """正向过程：添加噪声"""
        if noise is None:
            noise = jt.randn_like(x0)
        sqrt_alpha_bar = jt.sqrt(self.alpha_bar[t]).reshape(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar = jt.sqrt(1 - self.alpha_bar[t]).reshape(-1, 1, 1, 1)
        xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise
        return xt, noise

    def sample_step(self, model, xt, t):
        """反向过程单步"""
        model.eval()
        with jt.no_grad():
            beta_t = self.beta[t].reshape(-1, 1, 1, 1)
            alpha_t = self.alpha[t].reshape(-1, 1, 1, 1)
            alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
            noise_pred = model(xt, t / self.T)
            x_prev = (xt - beta_t / jt.sqrt(1 - alpha_bar_t) * noise_pred) / jt.sqrt(alpha_t)
        if t > 0:
            x_prev += jt.sqrt(beta_t) * jt.randn_like(xt)
        jt.gc() #触发垃圾回收
        return x_prev


# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.ImageNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = dataset.CIFAR10(root='./data', train=True, download=True, transform=transform)
batch_size = 128
total_batches = (len(trainset) + batch_size - 1) // batch_size
trainloader = dataset.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# 训练设置
model = UNet()
optimizer = optim.Adam(model.parameters(), lr=2e-4)
diffusion = DiffusionModel(T=1000)
num_epochs = 10
os.makedirs('samples', exist_ok=True)
os.makedirs('logs', exist_ok=True)

# 日志设置
logging.basicConfig(filename='logs/training.log', level=logging.INFO,
                    format='%(asctime)s - %(message)s')
losses = []


# 保存图像的函数（Jittor 没有直接的 make_grid，需手动实现）
def save_image_grid(images, filename, nrow=4):
    images = (images.clamp(-1, 1) + 1) / 2
    images = images.numpy().transpose(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
    n = images.shape[0]
    ncols = nrow
    nrows = (n + ncols - 1) // ncols
    grid = np.zeros((nrows * 32, ncols * 32, 3))
    for i in range(n):
        row = i // ncols
        col = i % ncols
        grid[row * 32:(row + 1) * 32, col * 32:(col + 1) * 32] = images[i]
    plt.figure(figsize=(8, 8))
    plt.imshow(grid)
    plt.axis('off')
    plt.savefig(filename)
    plt.close()


# 训练循环
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for batch_idx, (images, _) in enumerate(tqdm(trainloader, total=total_batches)):
        images = jt.array(images)

        # 随机时间步
        t = jt.randint(0, diffusion.T, (images.shape[0],)).float()
        xt, noise = diffusion.forward_process(images, t.int())

        # 预测噪声
        noise_pred = model(xt, t / diffusion.T)
        loss = nn.mse_loss(noise_pred, noise)

        optimizer.step(loss)

        epoch_loss += loss.item()

    avg_loss = epoch_loss / total_batches
    losses.append(avg_loss)
    logging.info(f'Epoch {epoch + 1}, Loss: {avg_loss:.6f}')
    print(f'Epoch {epoch + 1}, Loss: {avg_loss:.6f}')

    # 每5个 epoch 保存采样图像
    if (epoch + 1) % 5 == 0:
        model.eval()
        sample = jt.randn(16, 3, 32, 32)
        with jt.no_grad():  # 添加此行，禁用梯度
            for t in reversed(range(diffusion.T)):
                sample = diffusion.sample_step(model, sample, jt.array([t]))
                jt.gc() #主动垃圾回收
        save_image_grid(sample, f'samples/epoch_{epoch + 1}.png')

# 绘制 loss 曲线
plt.figure(figsize=(10, 5))
plt.plot(losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.legend()
plt.savefig('logs/loss_curve.png')
plt.close()

# 保存模型
jt.save(model.state_dict(), 'unet_diffusion_cifar10.jt')

[38;5;2m[i 0502 17:33:30.463300 36 log.cc:351] Load log_sync: 1[m
[38;5;2m[i 0502 17:33:30.565032 36 compiler.py:956] Jittor(1.3.9.14) src: /root/miniconda3/lib/python3.8/site-packages/jittor[m
[38;5;2m[i 0502 17:33:30.578478 36 compiler.py:957] g++ at /usr/bin/g++(9.4.0)[m
[38;5;2m[i 0502 17:33:30.580292 36 compiler.py:958] cache_path: /root/.cache/jittor/jt1.3.9/g++9.4.0/py3.8.10/Linux-5.15.0-9xd1/IntelRXeonRPlaxda/480a/default[m
[38;5;2m[i 0502 17:33:30.593694 36 __init__.py:412] Found nvcc(11.8.89) at /usr/local/cuda/bin/nvcc.[m
[38;5;2m[i 0502 17:33:30.606103 36 __init__.py:412] Found addr2line(2.34) at /usr/bin/addr2line.[m
[38;5;2m[i 0502 17:33:30.906809 36 compiler.py:1013] cuda key:cu11.8.89_sm_89[m
[38;5;2m[i 0502 17:33:31.540729 36 __init__.py:227] Total mem: 1007.51GB, using 16 procs for compiling.[m
[38;5;2m[i 0502 17:33:32.129632 36 jit_compiler.cc:28] Load cc_path: /usr/bin/g++[m
[38;5;2m[i 0502 17:33:32.375481 36 init.cc:63] Found cuda archs: [89,][m

Files already downloaded and verified


  0%|          | 0/391 [00:00<?, ?it/s][38;5;3m[w 0502 17:33:37.558300 36 grad.cc:81] grads[42] 'time_emb.0.weight' doesn't have gradient. It will be set to zero: Var(638:1:1:1:i0:o0:s1:n1:g1,float32,time_emb.0.weight,7ff0ef6fdc00)[64,1,][m
[38;5;3m[w 0502 17:33:37.561556 36 grad.cc:81] grads[43] 'time_emb.0.bias' doesn't have gradient. It will be set to zero: Var(657:1:1:1:i0:o0:s1:n0:g1,float32,time_emb.0.bias,7ff0ef6fde00)[64,][m
[38;5;3m[w 0502 17:33:37.562850 36 grad.cc:81] grads[44] 'time_emb.2.weight' doesn't have gradient. It will be set to zero: Var(676:1:1:1:i0:o0:s1:n1:g1,float32,time_emb.2.weight,7ff220e52800)[64,64,][m
[38;5;3m[w 0502 17:33:37.563833 36 grad.cc:81] grads[45] 'time_emb.2.bias' doesn't have gradient. It will be set to zero: Var(695:1:1:1:i0:o0:s1:n0:g1,float32,time_emb.2.bias,7ff0ef6fe000)[64,][m
100%|██████████| 391/391 [00:12<00:00, 32.14it/s]


Epoch 1, Loss: 0.125921


100%|██████████| 391/391 [00:10<00:00, 38.13it/s]


Epoch 2, Loss: 0.062329


100%|██████████| 391/391 [00:10<00:00, 38.61it/s]


Epoch 3, Loss: 0.055939


100%|██████████| 391/391 [00:10<00:00, 38.69it/s]


Epoch 4, Loss: 0.051543


100%|██████████| 391/391 [00:10<00:00, 38.65it/s]


Epoch 5, Loss: 0.049637


100%|██████████| 391/391 [00:10<00:00, 38.97it/s]


Epoch 6, Loss: 0.048435


100%|██████████| 391/391 [00:10<00:00, 39.03it/s]


Epoch 7, Loss: 0.046073


100%|██████████| 391/391 [00:10<00:00, 38.74it/s]


Epoch 8, Loss: 0.044620


100%|██████████| 391/391 [00:10<00:00, 38.94it/s]


Epoch 9, Loss: 0.044917


100%|██████████| 391/391 [00:10<00:00, 38.50it/s]


Epoch 10, Loss: 0.042958


# DDPM创新版本

In [1]:
import jittor as jt
from jittor import nn, optim, dataset
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import logging
import jittor.transform as transforms
from jittor.models import vgg16

# 环境配置
jt.flags.use_cuda = 1  # 启用CUDA
jt.flags.amp_reg = 0  # 禁用自动混合精度
print("Jittor CUDA状态:", jt.flags.use_cuda)

# 创建必要目录
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("samples", exist_ok=True)
os.makedirs("logs", exist_ok=True)

# 日志配置
logging.basicConfig(
    filename='logs/training.log',
    level=logging.INFO,
    format='%(asctime)s - %(message)s'
)


# ------------------------ 模型定义 ------------------------#
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_classes=10):
        super(UNet, self).__init__()

        # 时间+类别嵌入
        self.time_emb = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )
        self.label_emb = nn.Embedding(num_classes, 64)  # 条件嵌入

        # 下采样层
        def double_conv(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm(out_ch),
                nn.ReLU(),
                nn.Conv(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm(out_ch),
                nn.ReLU()
            )

        self.down1 = double_conv(in_channels, 64)
        self.down2 = double_conv(64, 128)
        self.down3 = double_conv(128, 256)

        # 上采样层
        self.up1 = double_conv(256 + 128, 128)
        self.up2 = double_conv(128 + 64, 64)
        self.up3 = nn.Conv(64, out_channels, 1)

        # 辅助层
        self.pool = nn.Pool(2, op='maximum')
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def execute(self, x, t, labels=None):
        # 嵌入处理
        t_emb = self.time_emb(t.reshape(-1, 1)).reshape(-1, 64, 1, 1)
        if labels is not None:
            label_emb = self.label_emb(labels).reshape(-1, 64, 1, 1)
            t_emb += label_emb  # 融合条件信息

        # 编码器
        x1 = self.down1(x) + t_emb  # [B,64,H,W]
        x2 = self.pool(x1)
        x2 = self.down2(x2)  # [B,128,H/2,W/2]
        x3 = self.pool(x2)
        x3 = self.down3(x3)  # [B,256,H/4,W/4]

        # 解码器
        x = self.upsample(x3)
        x = jt.concat([x, x2], dim=1)
        x = self.up1(x)
        x = self.upsample(x)
        x = jt.concat([x, x1], dim=1)
        x = self.up2(x)
        return self.up3(x)


# ----------------------- 扩散模型核心 -----------------------#
class DiffusionModel:
    def __init__(self, T=1000, beta_start=1e-4, beta_end=0.02):
        self.T_initial = T
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.current_T = T  # 渐进式采样
        self.loss_sum = jt.zeros(T)
        self.loss_count = jt.zeros(T)
        self.update_beta_schedule()

    def update_beta_schedule(self):
        """更新alpha和beta序列"""
        self.beta = jt.linspace(self.beta_start, self.beta_end, self.current_T)
        self.alpha = 1.0 - self.beta
        self.alpha_bar = jt.cumprod(self.alpha, dim=0)

    def adjust_beta(self):
        avg_loss = (self.loss_sum + 1e-8) / (self.loss_count + 1e-8)
        weights = avg_loss / avg_loss.sum()

        # 动态调整current_T和beta
        self.current_T = max(100, int(self.current_T * 0.9))
        self.beta = self.beta_start + (self.beta_end - self.beta_start) * weights.cumsum()[:self.current_T]

        self.update_beta_schedule()
        self.loss_sum = jt.zeros_like(self.loss_sum)
        self.loss_count = jt.zeros_like(self.loss_count)

    def forward_process(self, x0, t, noise=None):
        """前向扩散过程"""
        if noise is None:
            noise = jt.randn_like(x0)
        sqrt_alpha_bar = jt.sqrt(self.alpha_bar[t]).reshape(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar = jt.sqrt(1 - self.alpha_bar[t]).reshape(-1, 1, 1, 1)
        xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise
        return xt, noise

    def sample_step(self, model, xt, t, labels=None):
        """单步采样"""
        model.eval()
        with jt.no_grad():
            beta_t = self.beta[t].reshape(-1, 1, 1, 1)
            alpha_t = self.alpha[t].reshape(-1, 1, 1, 1)
            alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)

            # 条件生成支持
            noise_pred = model(xt, t / self.T_initial, labels=labels)

            # 逆向过程计算
            x_prev = (xt - beta_t / jt.sqrt(1 - alpha_bar_t) * noise_pred) / jt.sqrt(alpha_t)
            if t > 0:
                x_prev += jt.sqrt(beta_t) * jt.randn_like(xt)
            return x_prev


# --------------------- 感知损失模块 ---------------------#
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = vgg16(pretrained=True).features
        self.slice = nn.Sequential()
        for i in range(9):  # 使用前4个卷积层
            self.slice.add_module(str(i), vgg[i])
        self.eval()

    def execute(self, pred, target):
        pred = (pred + 1) / 2  # 归一化到[0,1]
        target = (target + 1) / 2
        pred = nn.interpolate(pred, (224, 224), mode='bilinear')
        target = nn.interpolate(target, (224, 224), mode='bilinear')
        return nn.mse_loss(self.slice(pred), self.slice(target))


# --------------------- 数据加载 ----------------------#
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.ImageNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = dataset.CIFAR10(root='./data', train=True, download=True, transform=transform)
batch_size = 128
trainloader = dataset.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
total_batches = (len(trainset) + batch_size - 1) // batch_size

# --------------------- 训练初始化 ---------------------#
model = UNet(num_classes=10)
optimizer = optim.Adam(model.parameters(), lr=2e-4)
diffusion = DiffusionModel(T=1000)
vgg_loss = VGGPerceptualLoss()


# --------------------- 条件采样函数 ----------------------#
def conditional_sample(model, diffusion, labels, num_samples=16):
    x = jt.randn(num_samples, 3, 32, 32)
    indices = list(range(diffusion.current_T))[::-1]

    for t in tqdm(indices, desc="Sampling"):
        x = diffusion.sample_step(model, x, jt.array([t]), labels=labels)
        x = x.clamp(-1, 1)
    return x


# --------------------- 辅助函数 ----------------------#
def save_image_grid(images, filename, nrow=4):
    """ 保存图像网格 """
    n = images.shape[0]
    ncols = min(nrow, n)
    nrows = (n + ncols - 1) // ncols

    grid = np.zeros((nrows * 32, ncols * 32, 3))
    for i in range(n):
        row = i // ncols
        col = i % ncols
        grid[row * 32:(row + 1) * 32, col * 32:(col + 1) * 32] = images[i]

    plt.figure(figsize=(8, 8))
    plt.imshow(grid)
    plt.axis('off')
    plt.savefig(filename, bbox_inches='tight')
    plt.close()


# --------------------- 训练循环 ----------------------#
num_epochs = 10
loss_history = []

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    with tqdm(total=total_batches, desc=f"Epoch {epoch + 1}/{num_epochs}", unit='batch') as pbar:
        for batch_idx, (images, labels) in enumerate(trainloader):
            images = jt.array(images)
            labels = jt.array(labels)

            # 扩散过程
            t = jt.randint(0, diffusion.current_T, (images.shape[0],)).float()
            xt, noise = diffusion.forward_process(images, t.int())

            # 前向计算
            noise_pred = model(xt, t / diffusion.T_initial, labels)
            loss_mse = nn.mse_loss(noise_pred, noise)
            loss_perceptual = vgg_loss(noise_pred, noise)
            loss = loss_mse + 0.05 * loss_perceptual

            # 反向传播
            optimizer.step(loss)
            epoch_loss += loss.item()

            # 更新进度条
            pbar.set_postfix({
                "Batch Loss": f"{loss.item():.4f}",
                "Avg Loss": f"{epoch_loss / (batch_idx + 1):.4f}"
            })
            pbar.update(1)

            # 记录时间步损失
            loss_per_sample = (noise_pred - noise).sqr().mean(dims=[1, 2, 3])
            for ti, li in zip(t.numpy(), loss_per_sample.numpy()):
                diffusion.loss_sum[ti] += li
                diffusion.loss_count[ti] += 1

            # 内存管理
            if batch_idx % 100 == 0:
                jt.gc()
                jt.sync_all()

    # 记录epoch损失
    avg_loss = epoch_loss / total_batches
    loss_history.append(avg_loss)
    logging.info(f"Epoch {epoch + 1} Average Loss: {avg_loss:.4f}")
    print(f"\nEpoch {epoch + 1} 平均损失: {avg_loss:.4f}\n")

    # 动态调整beta计划
    if (epoch + 1) % 5 == 0:
        diffusion.adjust_beta()
        print(f"已调整beta计划，当前T={diffusion.current_T}")

    # 保存模型和生成样本
    if epoch == num_epochs - 1:  # 最后一个epoch保存
        save_path = f"unet_diffusion_cifar10.jt"  # 最外层路径
        jt.save(model.state_dict(), save_path)
        print(f"已保存模型至 {save_path}")

        # 条件采样示例
        sample_labels = jt.randint(0, 10, (16,))
        samples = conditional_sample(model, diffusion, sample_labels, num_samples=16)
        samples = (samples.clamp(-1, 1) + 1) / 2  # 转换为[0,1]
        samples_np = samples.numpy().transpose(0, 2, 3, 1)  # (B,C,H,W) -> (B,H,W,C)
        save_image_grid(samples_np, f"samples/epoch_{epoch + 1}.png")

# 绘制训练曲线
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.savefig('logs/training_curve.png')
plt.close()

[38;5;2m[i 0503 17:13:38.702224 60 log.cc:351] Load log_sync: 1[m
[38;5;2m[i 0503 17:13:38.761555 60 compiler.py:956] Jittor(1.3.9.14) src: /root/miniconda3/lib/python3.8/site-packages/jittor[m
[38;5;2m[i 0503 17:13:38.771698 60 compiler.py:957] g++ at /usr/bin/g++(9.4.0)[m
[38;5;2m[i 0503 17:13:38.772799 60 compiler.py:958] cache_path: /root/.cache/jittor/jt1.3.9/g++9.4.0/py3.8.10/Linux-5.15.0-9xc6/IntelRXeonRGolx95/480a/default[m
[38;5;2m[i 0503 17:13:38.782032 60 __init__.py:412] Found nvcc(11.8.89) at /usr/local/cuda/bin/nvcc.[m
[38;5;2m[i 0503 17:13:38.791479 60 __init__.py:412] Found addr2line(2.34) at /usr/bin/addr2line.[m
[38;5;2m[i 0503 17:13:39.073639 60 compiler.py:1013] cuda key:cu11.8.89_sm_89[m
[38;5;2m[i 0503 17:13:39.684431 60 __init__.py:227] Total mem: 1007.52GB, using 16 procs for compiling.[m
[38;5;2m[i 0503 17:13:39.898776 60 jit_compiler.cc:28] Load cc_path: /usr/bin/g++[m
[38;5;2m[i 0503 17:13:40.214279 60 init.cc:63] Found cuda archs: [89,][m

Jittor CUDA状态: 1
Files already downloaded and verified


Epoch 1/10: 100%|██████████| 391/391 [01:11<00:00,  5.49batch/s, Batch Loss=0.0899, Avg Loss=0.1719]



Epoch 1 平均损失: 0.1719



Epoch 2/10: 100%|██████████| 391/391 [01:06<00:00,  5.85batch/s, Batch Loss=0.1068, Avg Loss=0.0830]



Epoch 2 平均损失: 0.0830



Epoch 3/10: 100%|██████████| 391/391 [01:06<00:00,  5.86batch/s, Batch Loss=0.0698, Avg Loss=0.0720]



Epoch 3 平均损失: 0.0720



Epoch 4/10: 100%|██████████| 391/391 [01:06<00:00,  5.86batch/s, Batch Loss=0.0833, Avg Loss=0.0664]



Epoch 4 平均损失: 0.0664



Epoch 5/10: 100%|██████████| 391/391 [01:06<00:00,  5.86batch/s, Batch Loss=0.0507, Avg Loss=0.0627]



Epoch 5 平均损失: 0.0627

已调整beta计划，当前T=900


Epoch 6/10: 100%|██████████| 391/391 [01:06<00:00,  5.89batch/s, Batch Loss=0.0385, Avg Loss=0.0628]



Epoch 6 平均损失: 0.0628



Epoch 7/10: 100%|██████████| 391/391 [01:06<00:00,  5.85batch/s, Batch Loss=0.0734, Avg Loss=0.0601]



Epoch 7 平均损失: 0.0601



Epoch 8/10: 100%|██████████| 391/391 [01:07<00:00,  5.83batch/s, Batch Loss=0.0654, Avg Loss=0.0575]



Epoch 8 平均损失: 0.0575



Epoch 9/10: 100%|██████████| 391/391 [01:06<00:00,  5.85batch/s, Batch Loss=0.0495, Avg Loss=0.0565]



Epoch 9 平均损失: 0.0565



Epoch 10/10: 100%|██████████| 391/391 [01:06<00:00,  5.86batch/s, Batch Loss=0.0916, Avg Loss=0.0550]



Epoch 10 平均损失: 0.0550

已调整beta计划，当前T=810
已保存模型至 unet_diffusion_cifar10.jt


Sampling: 100%|██████████| 810/810 [00:02<00:00, 311.56it/s]
