In [1]:
import os
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from diffusers import UNet2DModel, DDPMScheduler
import imageio
import torchvision.utils as vutils
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATASET_PATH = "img_align_celeba"   

SAVE_DIR = "checkpoints"
SAMPLE_DIR = "samples"

os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)

In [3]:
class CelebADataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.image_files = sorted(os.listdir(root_dir))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image

In [4]:
image_size = 64
batch_size = 256

transform = transforms.Compose([
    transforms.CenterCrop(178),
    transforms.Resize((64,64)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3,[0.5]*3)
])


dataset = CelebADataset(DATASET_PATH, transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

print("Total images:", len(dataset))

Total images: 202599


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

model = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(64, 128, 128, 256),
    down_block_types=("DownBlock2D","DownBlock2D","AttnDownBlock2D","DownBlock2D"),
    up_block_types=("UpBlock2D","AttnUpBlock2D","UpBlock2D","UpBlock2D"),
).to(device)

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

cuda


In [6]:
torch.manual_seed(42)
fixed_noise = torch.randn((9, 3, 64, 64)).to(device)

In [7]:
@torch.no_grad()
def sample_images(epoch):

    model.eval()

    sample = fixed_noise.clone()

    # reverse diffusion
    for t in reversed(range(1000)):

        t_tensor = torch.full((9,), t, device=device, dtype=torch.long)

        noise_pred = model(sample, t_tensor).sample
        sample = noise_scheduler.step(noise_pred, t, sample).prev_sample

    # [-1,1] -> [0,1]
    imgs = (sample.clamp(-1, 1) + 1) / 2
    imgs = imgs.cpu()

    # 3x3 GRID
    grid = vutils.make_grid(imgs, nrow=3, padding=2)
    grid = grid.permute(1, 2, 0).numpy()

    plt.figure(figsize=(6,6))
    plt.axis("off")
    plt.imshow(grid)
    plt.savefig(f"{SAMPLE_DIR}/epoch_{epoch:03d}.png", bbox_inches='tight', pad_inches=0)
    plt.close()

    print(f"Saved samples for epoch {epoch}")

In [None]:
epochs = 60

for epoch in range(epochs):

    model.train()
    pbar = tqdm(dataloader)

    for batch in pbar:

        # Load clean images
        clean_images = batch.to(device, non_blocking=True)

        # Sample noise
        noise = torch.randn_like(clean_images)

        # Random timestep for each image
        timesteps = torch.randint(
            0, 1000, (clean_images.shape[0],), device=device
        ).long()

    
        # Forward diffusion (add noise)
        noisy_images = noise_scheduler.add_noise(
            clean_images, noise, timesteps
        )

        noise_pred = model(noisy_images, timesteps).sample

        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f"Epoch {epoch} | Loss: {loss.item():.4f}")

    model.save_pretrained(f"{SAVE_DIR}/epoch_{epoch}")

    sample_images(epoch)

print("Training Finished!")

Epoch 0 | Loss: 0.0310: 100%|██████████| 792/792 [05:57<00:00,  2.21it/s]


Saved samples for epoch 0


Epoch 1 | Loss: 0.0168: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 1


Epoch 2 | Loss: 0.0162: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 2


Epoch 3 | Loss: 0.0153: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 3


Epoch 4 | Loss: 0.0211: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 4


Epoch 5 | Loss: 0.0161: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 5


Epoch 6 | Loss: 0.0182: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 6


Epoch 7 | Loss: 0.0148: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 7


Epoch 8 | Loss: 0.0228: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 8


Epoch 9 | Loss: 0.0191: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 9


Epoch 10 | Loss: 0.0204: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 10


Epoch 11 | Loss: 0.0176: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 11


Epoch 12 | Loss: 0.0191: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 12


Epoch 13 | Loss: 0.0200: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 13


Epoch 14 | Loss: 0.0169: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 14


Epoch 15 | Loss: 0.0236: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 15


Epoch 16 | Loss: 0.0197: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 16


Epoch 17 | Loss: 0.0161: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 17


Epoch 18 | Loss: 0.0189: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 18


Epoch 19 | Loss: 0.0164: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 19


Epoch 20 | Loss: 0.0215: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 20


Epoch 21 | Loss: 0.0146: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 21


Epoch 22 | Loss: 0.0193: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 22


Epoch 23 | Loss: 0.0135: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 23


Epoch 24 | Loss: 0.0145: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 24


Epoch 25 | Loss: 0.0151: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 25


Epoch 26 | Loss: 0.0167: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 26


Epoch 27 | Loss: 0.0166: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 27


Epoch 28 | Loss: 0.0193: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 28


Epoch 29 | Loss: 0.0166: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 29


Epoch 30 | Loss: 0.0241: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 30


Epoch 31 | Loss: 0.0135: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 31


Epoch 32 | Loss: 0.0124: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 32


Epoch 33 | Loss: 0.0169: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 33


Epoch 34 | Loss: 0.0180: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 34


Epoch 35 | Loss: 0.0192: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 35


Epoch 36 | Loss: 0.0154: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 36


Epoch 37 | Loss: 0.0187: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 37


Epoch 38 | Loss: 0.0161: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 38


Epoch 39 | Loss: 0.0153: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 39


Epoch 40 | Loss: 0.0156: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 40


Epoch 41 | Loss: 0.0131: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 41


Epoch 42 | Loss: 0.0164: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 42


Epoch 43 | Loss: 0.0161: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 43


Epoch 44 | Loss: 0.0255: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 44


Epoch 45 | Loss: 0.0193: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 45


Epoch 46 | Loss: 0.0153: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 46


Epoch 47 | Loss: 0.0144: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 47


Epoch 48 | Loss: 0.0143: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 48


Epoch 49 | Loss: 0.0212: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 49


Epoch 50 | Loss: 0.0131: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 50


Epoch 51 | Loss: 0.0114: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 51


Epoch 52 | Loss: 0.0131: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 52


Epoch 53 | Loss: 0.0107: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 53


Epoch 54 | Loss: 0.0099: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 54


Epoch 55 | Loss: 0.0138: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 55


Epoch 56 | Loss: 0.0162: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 56


Epoch 57 | Loss: 0.0158: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 57


Epoch 58 | Loss: 0.0149: 100%|██████████| 792/792 [05:56<00:00,  2.22it/s]


Saved samples for epoch 58


Epoch 59 | Loss: 0.0177: 100%|██████████| 792/792 [05:57<00:00,  2.22it/s]


Saved samples for epoch 59
Training Finished!
