In [9]:
import os
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# -------------------------------
# 1. 配置路径和参数
# -------------------------------
save_dir = "./denoised_results"
os.makedirs(save_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 32
T_steps = 1000  # DDPM 时间步数
# 注意：betas, alphas, alpha_bars 必须与训练一致
beta_start, beta_end = 1e-4, 0.02
betas = torch.linspace(beta_start, beta_end, T_steps).to(device)
alphas = 1. - betas
alpha_bars = torch.cumprod(alphas, dim=0)

# -------------------------------
# 2. 加载测试集
# -------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),                 # [0,255] -> [0,1]
    transforms.Lambda(lambda x: x*2-1)     # [0,1] -> [-1,1]
])

test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# -------------------------------
# 3. 定义模型并加载权重
# -------------------------------
from ddpm_model import SimpleUNet  # 替换为你训练模型定义文件

model = SimpleUNet().to(device)
checkpoint_path = "./checkpoints/ddpm_step50000.pt"  # 替换为你保存的模型
model.load_state_dict(torch.load(checkpoint_path))
model.eval()

# -------------------------------
# 4. 定义去噪函数
# -------------------------------
@torch.no_grad()
def denoise(model, x_noisy, T_steps=T_steps):
    x_t = x_noisy.clone().to(device)
    for t in reversed(range(T_steps)):
        B = x_t.size(0)
        t_batch = torch.full((B,), t, device=device, dtype=torch.long)

        eps = model(x_t, t_batch)
        alpha_t = alphas[t]
        alpha_bar_t = alpha_bars[t]
        beta_t = betas[t]

        mean = (1/torch.sqrt(alpha_t)) * (x_t - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * eps)
        if t > 0:
            z = torch.randn_like(x_t)
            x_prev = mean + torch.sqrt(beta_t) * z
        else:
            x_prev = mean
        x_t = x_prev
    return x_t

# -------------------------------
# 5. 测试集去噪并保存对比图
# -------------------------------
for batch_idx, (x, _) in enumerate(test_loader):
    x = x.to(device)

    # 生成带噪声图像（可模拟噪声污染）
    noise = torch.randn_like(x)
    x_noisy = torch.sqrt(alpha_bars[-1]) * x + torch.sqrt(1 - alpha_bars[-1]) * noise

    # 去噪
    x_denoised = denoise(model, x_noisy)

    # 转回 [0,1] 并 CPU
    x = (x + 1)/2
    x_noisy = (x_noisy + 1)/2
    x_denoised = (x_denoised + 1)/2

    x = x.cpu()
    x_noisy = x_noisy.cpu()
    x_denoised = x_denoised.cpu()

    # 拼接原始、噪声、降噪图像
    combined = torch.cat([x, x_noisy, x_denoised], dim=0)
    grid = torchvision.utils.make_grid(combined, nrow=batch_size)
    grid = torch.clamp(grid, 0., 1.)

    # 可视化
    plt.figure(figsize=(16,6))
    plt.imshow(grid.permute(1,2,0).numpy())
    plt.axis("off")

    # 保存图片
    save_path = os.path.join(save_dir, f"denoise_batch_{batch_idx}.png")
    plt.savefig(save_path)
    plt.close()

    print(f"Saved batch {batch_idx} to {save_path}")

    # 只处理前 5 批，可修改
    if batch_idx >= 4:
        break

print("All done!")


Files already downloaded and verified


  model.load_state_dict(torch.load(checkpoint_path))


Saved batch 0 to ./denoised_results/denoise_batch_0.png
Saved batch 1 to ./denoised_results/denoise_batch_1.png
Saved batch 2 to ./denoised_results/denoise_batch_2.png
Saved batch 3 to ./denoised_results/denoise_batch_3.png
Saved batch 4 to ./denoised_results/denoise_batch_4.png
All done!
