In [1]:
import os
import math
import time
import torch
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from pathlib import Path
from datetime import timedelta
import csv

In [2]:
# Configs
DATA_DIR = './data'
CHECKPOINT_DIR = './checkpoints_diffusion'
LOG_PATH = './logs/diffusion_log.csv'
IMAGE_SIZE = 64
BATCH_SIZE = 8
EPOCHS = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TIMESTEPS = 1000  # noise steps

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)

In [3]:
# Dataset with noisy-clean pairs using random Gaussian noise
class FaceDataset(Dataset):
    def __init__(self, root_dir, image_size=64):
        self.image_paths = list(Path(root_dir).rglob("*.jpg")) + list(Path(root_dir).rglob("*.png"))
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        x_start = self.transform(img)
        return x_start

In [4]:
# Positional encoding for timestep embedding
def sinusoidal_embedding(n, d):
    pe = torch.zeros(n, d)
    position = torch.arange(0, n, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d, 2).float() * -(math.log(10000.0) / d))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

In [5]:
# Basic UNet block for noise prediction
class DenoiseModel(nn.Module):
    def __init__(self, img_channels=3, base_channels=64, time_dim=256):
        super().__init__()
        self.time_embed = nn.Sequential(
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim)
        )
        self.conv1 = nn.Conv2d(img_channels, base_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(base_channels, base_channels, 3, padding=1)
        self.conv3 = nn.Conv2d(base_channels, img_channels, 3, padding=1)
        self.act = nn.ReLU()
        self.time_to_channel = nn.Linear(time_dim, base_channels)

    def forward(self, x, t_emb):
        t = self.time_embed(t_emb)
        t = self.time_to_channel(t).unsqueeze(2).unsqueeze(3)
        x = self.act(self.conv1(x) + t)
        x = self.act(self.conv2(x))
        return self.conv3(x)

In [9]:
# Scheduler and noise methods
class Diffusion:
    def __init__(self, timesteps=1000):
        self.timesteps = timesteps
        self.betas = torch.linspace(1e-4, 0.02, timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_hat = torch.cumprod(self.alphas, dim=0)

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alpha_hat = self.alpha_hat[t].sqrt().view(-1, 1, 1, 1).to(DEVICE)
        sqrt_one_minus = (1 - self.alpha_hat[t]).sqrt().view(-1, 1, 1, 1).to(DEVICE)
        return sqrt_alpha_hat * x_start + sqrt_one_minus * noise

    def p_losses(self, model, x_start, t):
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)
        t_emb = sinusoidal_embedding(t, 256).to(x_start.device)
        predicted_noise = model(x_noisy, t_emb)
        return nn.MSELoss()(predicted_noise, noise)

In [10]:
# PSNR and SSIM for evaluation
def psnr(img1, img2):
    mse = nn.functional.mse_loss(img1, img2)
    return 20 * torch.log10(1.0 / torch.sqrt(mse))

def ssim(img1, img2):
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    mu1 = nn.functional.avg_pool2d(img1, 3, 1, 1)
    mu2 = nn.functional.avg_pool2d(img2, 3, 1, 1)
    sigma1 = nn.functional.avg_pool2d(img1 * img1, 3, 1, 1) - mu1 ** 2
    sigma2 = nn.functional.avg_pool2d(img2 * img2, 3, 1, 1) - mu2 ** 2
    sigma12 = nn.functional.avg_pool2d(img1 * img2, 3, 1, 1) - mu1 * mu2
    ssim_map = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1 ** 2 + mu2 ** 2 + C1) * (sigma1 + sigma2 + C2))
    return ssim_map.mean()

In [11]:
# Train loop
def train():
    dataset = FaceDataset(DATA_DIR, IMAGE_SIZE)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    model = DenoiseModel().to(DEVICE)
    diffusion = Diffusion(timesteps=TIMESTEPS)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    if not os.path.exists(LOG_PATH):
        with open(LOG_PATH, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["Epoch", "Loss", "PSNR", "SSIM", "Time"])

    for epoch in range(EPOCHS):
        model.train()
        losses = []
        psnrs, ssims = [], []
        start_time = time.time()
        for x in dataloader:
            x = x.to(DEVICE)
            t = torch.randint(0, TIMESTEPS, (x.size(0),), device=DEVICE).long()
            loss = diffusion.p_losses(model, x, t)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        avg_loss = np.mean(losses)
        elapsed = timedelta(seconds=int(time.time() - start_time))
        print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | Time: {elapsed}")

        torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/model_epoch_{epoch+1}.pth")

        with open(LOG_PATH, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([epoch+1, avg_loss, "-", "-", str(elapsed)])

if __name__ == "__main__":
    train()

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)