In [None]:
# Imports
# Core PyTorch libraries for neural network operations
import torch
from torch.utils.data import DataLoader

# Our DDPM components
from Noise import NoiseScheduler
from DatasetLoader import ImageDataset
from model import U_net

# Utilities
import os
from tqdm import tqdm

# Device setup
# Prefer CUDA when PyTorch reports a usable GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Data loading
# Target image side length (28 -> 28x28). The value is forwarded to
# ImageDataset; how it is applied (e.g. resizing) is defined there.
image_size = 28

# Build the dataset from the image directory (placeholder path: replace
# "Path/to/images" with the real location before running).
data = ImageDataset("Path/to/images", image_size=image_size)
print(f"Dataset size: {len(data)} images")

# Wrap the dataset in a DataLoader that yields reshuffled batches of up to
# 64 samples on every pass.
data_loader = DataLoader(
    dataset=data,
    batch_size=64,
    shuffle=True,
)

# Noise scheduler setup
# The scheduler is configured with the number of diffusion steps and the two
# endpoints of the beta range (1e-4 and 2e-2). Presumably beta_start applies
# to the earliest (least noisy) step and beta_end to the last one -- see
# NoiseScheduler for the exact schedule.
num_timesteps = 1000
noise_scheduler = NoiseScheduler(
    device=device,
    num_timesteps=num_timesteps,
    beta_start=1e-4,
    beta_end=2e-2,
)

# Model initialization
# in_channels=1 configures the network for single-channel input (presumably
# grayscale images). For color images, use in_channels=3.
model = U_net(device, in_channels=1)
model.to(device)

# Display model architecture and parameter count
print("Model architecture:")
print(model)
# Sum the element counts of every parameter tensor registered on the model.
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params:,}")

In [None]:
# Training configuration
# Number of full passes over the data loader.
num_epochs = 100

# Loss function: mean squared error, used by the training loop to compare
# the model's predicted noise against the noise that was actually added.
criterion = torch.nn.MSELoss()

# Optimizer: AdamW (Adam with decoupled weight decay) over all model
# parameters, with a learning rate of 1e-4.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)

# Training loop
# Each step corrupts a clean batch with known Gaussian noise at a randomly
# chosen timestep, asks the model to predict that noise, and minimizes the
# MSE between the prediction and the true noise.

# Create the checkpoint directory once, up front, so an unusable path fails
# before any training time is spent rather than after the first epoch.
checkpoint_dir = "your/path"
os.makedirs(checkpoint_dir, exist_ok=True)

# Explicitly select training mode so that re-running this cell after an
# evaluation/sampling cell does not silently train in eval mode.
model.train()

for epoch in range(num_epochs):
    # Per-batch loss values (Python floats) collected for the epoch summary.
    losses = []

    # Batch loop
    for images in tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Move images to device
        images = images.to(device)

        # Forward diffusion process
        # Number of samples in this batch (the last batch may be smaller).
        batch_size = images.size(0)

        # Target noise the model must predict. randn_like already matches the
        # shape, dtype and device of `images`, so no extra transfer is needed.
        # Assumes images are (batch_size, channels, height, width) -- confirm
        # against ImageDataset.
        noise = torch.randn_like(images)

        # One random integer timestep per sample, in [0, num_timesteps - 1].
        timestep = torch.randint(
            0, num_timesteps, (batch_size,), device=device, dtype=torch.long
        )

        # Add noise to images according to the timestep
        noisy_images = noise_scheduler.add_noise(images, noise, timestep)

        # Reverse diffusion (model training)
        # Model predicts what noise was added
        noise_pred = model(noisy_images, timestep)

        # Compute loss
        loss = criterion(noise_pred, noise)
        losses.append(loss.item())

        # Backpropagation: clear stale gradients, compute new ones, then
        # update the model parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # An empty loader means nothing was trained; fail with a clear message
    # instead of a NameError on the never-assigned `loss`.
    if not losses:
        raise RuntimeError(
            "data_loader produced no batches; check the dataset path and contents"
        )

    # Print epoch statistics. losses[-1] is the last batch's loss, already
    # converted to a float above.
    avg_loss = sum(losses) / len(losses)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {losses[-1]:.4f}, Avg Loss: {avg_loss:.4f}")

    # Checkpoint saving
    # Save the model weights (state_dict) after each epoch
    save_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")