In [5]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.utils import logging as diffusers_logging
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms


# Silence diffusers' informational/warning log output; only errors are reported.
# (Python's stdlib `logging` has no `set_verbosity_error`, and it was never imported.)
diffusers_logging.set_verbosity_error()

class CatsDataset(Dataset):
    """Unlabelled image dataset built from the image files in one directory.

    Each item is a single RGB image, optionally passed through ``transform``.
    No label is returned, so a DataLoader over this dataset yields bare image
    batches rather than ``(images, labels)`` pairs.
    """

    # Recognised image extensions; compared case-insensitively so ``.JPG`` is included.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir: Directory scanned (non-recursively) for image files.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.data_dir = data_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so index -> file is stable across runs.
        self.image_files = sorted(
            os.path.join(data_dir, file)
            for file in os.listdir(data_dir)
            if file.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``data_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = self.image_files[idx]
        # Use a context manager so the file handle is released immediately instead
        # of whenever the lazily-loaded source image is garbage-collected.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
    
# Load the cats training set.
# Resize to 64x64, convert to a float tensor in [0, 1], then Normalize with a
# single mean/std of 0.5 (broadcast over all channels) to rescale into [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

train_dataset = CatsDataset(
    data_dir='./DATA/train/cats',
    transform=transform,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
)

In [4]:
# Create the TensorBoard writer (exist_ok keeps re-running this cell safe).
log_dir = 'runs/DDPM_experiment'
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir)

# A reduced UNet: one layer per block and narrower channel widths.
model = UNet2DModel(
    sample_size=64,                    # input image size
    in_channels=3,                     # input channels
    out_channels=3,                    # output channels
    layers_per_block=1,                # layers per block, reduced to 1
    block_out_channels=(64, 128, 256)  # reduced per-block output channels
)

# Noise scheduler for the forward (noising) process.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Dataset and loader.
# NOTE(review): this replaces the cats `transform` / `train_dataset` /
# `train_dataloader` built in the data-loading cell with CIFAR-10.
# Normalize maps ToTensor's [0, 1] output to [-1, 1]; DDPMPipeline converts
# samples back with `image / 2 + 0.5`, so training data must be in [-1, 1].
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Training parameters.
epochs = 10
steps_per_epoch = len(train_dataloader)
# Cumulative alpha products, used to turn a noise prediction into a clean-image estimate.
alphas_cumprod = noise_scheduler.alphas_cumprod.to(device)

model.train()
try:
    for epoch in range(epochs):
        epoch_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # CIFAR-10 yields (images, labels); CatsDataset yields bare image batches.
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            global_step = epoch * steps_per_epoch + step

            optimizer.zero_grad()

            clean_images = images.to(device)
            noise = torch.randn_like(clean_images)  # sampled directly on the right device
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (clean_images.shape[0],), device=device,
            ).long()

            # Noise the batch, then train the UNet to predict the added noise.
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            noise_pred = model(noisy_images, timesteps).sample

            loss = torch.nn.functional.mse_loss(noise_pred, noise)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            epoch_loss += loss_value

            if step % 100 == 0:
                print(f"Epoch {epoch+1}, Step {step}, Loss: {loss_value}")
                writer.add_scalar('Training Loss', loss_value, global_step)

                with torch.no_grad():
                    # The UNet outputs predicted noise, not images. Invert the forward
                    # process for a one-step estimate of the clean batch:
                    #   x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
                    # This reuses noise_pred instead of running a second forward pass.
                    alpha_bar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
                    pred_clean = (
                        noisy_images - (1 - alpha_bar).sqrt() * noise_pred.detach()
                    ) / alpha_bar.sqrt()
                    img_grid = torchvision.utils.make_grid(
                        pred_clean.clamp(-1, 1), normalize=True, value_range=(-1, 1)
                    )
                    writer.add_image('Generated Images', img_grid, global_step=global_step)

                    # Save the grid to disk as well.
                    save_image_path = f"generated_images/epoch_{epoch+1}_step_{step+1}.png"
                    os.makedirs(os.path.dirname(save_image_path), exist_ok=True)
                    plt.imsave(save_image_path, img_grid.permute(1, 2, 0).cpu().numpy())
                    print(f"Saved generated image to {save_image_path}")

        # Log the mean loss of each epoch.
        avg_epoch_loss = epoch_loss / max(steps_per_epoch, 1)
        writer.add_scalar('Average Epoch Loss', avg_epoch_loss, epoch)
        print(f"Epoch {epoch+1}, Average Loss: {avg_epoch_loss}")

    # Save the trained weights (only reached when training runs to completion).
    torch.save(model.state_dict(), 'ddpm_model.pth')
finally:
    # Flush and close TensorBoard even if training is interrupted (e.g. KeyboardInterrupt).
    writer.close()

print("Training complete.")

Epoch 1, Step 0, Loss: 1.156358242034912


KeyboardInterrupt: 

In [None]:
# Wrap the trained UNet and the training scheduler in a sampling pipeline.
model.eval()
pipeline = DDPMPipeline(unet=model, scheduler=noise_scheduler)
pipeline.to(device)

n_samples = 10

# Generate images. A dedicated seeded generator makes sampling reproducible
# without reseeding torch's global RNG.
generator = torch.Generator(device="cpu").manual_seed(42)
# output_type="np" returns a float array of shape (batch, H, W, C) already mapped
# to [0, 1]. The default "pil" output has no `.permute`, and current diffusers
# releases expose the batch as `.images` rather than a "sample" key.
images = pipeline(batch_size=n_samples, generator=generator, output_type="np").images

# Show the generated images side by side.
fig, axes = plt.subplots(1, n_samples, figsize=(2 * n_samples, 2), squeeze=False)
for ax, img in zip(axes[0], images):
    ax.imshow(img)  # already in [0, 1]; no further rescaling needed
    ax.axis('off')
fig.tight_layout()
plt.show()