In [1]:
# Handle potential OpenMP library duplication issues, common in some environments
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

# Core PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F # Often imported as F for functional API

# Mathematical operations
import math
import numpy as np

# Torchvision for datasets, transforms, and utilities
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import Dataset # Specific import for custom datasets

# Image handling and plotting
from PIL import Image # For image manipulation, often used with custom datasets
import matplotlib.pyplot as plt

In [2]:
# Pick the compute device once: the first CUDA GPU when one is visible, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}") # Report the choice so the run log records it.

Using device: cuda:0


In [3]:
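# U-Net architecture for noise prediction in diffusion models, with skip connections (encoder features are concatenated into the decoder).
# 1. nn.Conv2d(in_channels, 64, 3, padding=1): the form used in this U-Net. The positional arguments are input channels, output channels (64 feature maps) and kernel size (3); padding=1 keeps the spatial size. Time embeddings are not a Conv2d argument; they are added to the feature maps in forward().
# 2. nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False): a strided convolution as used in a regular CNN architecture, shown for comparison; it halves an even spatial size.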


import torch
import torch.nn as nn
import math

class UNet(nn.Module):
    """U-Net noise predictor for diffusion models.

    The network has three encoder stages (64, 128 and 256 channels) separated
    by 2x max pooling, followed by two decoder stages. Each decoder stage
    bilinearly upsamples its input and concatenates the matching encoder
    output (skip connection) before convolving. A sinusoidal timestep
    embedding is projected by a separate MLP per stage and added to that
    stage's output. The final 1x1 convolution maps back to ``input_channels``.

    The notebook feeds 128x128 RGB images; the resolutions quoted in the
    comments below assume that size. Because of the two pooling steps the
    input height and width must be divisible by 4.

    Args:
        input_channels: Number of channels of the input image; the output has
            the same number of channels.
        time_embedding_dim: Size of the sinusoidal timestep embedding. It must
            be a positive even integer because half of it holds sines and the
            other half cosines.

    Raises:
        ValueError: If ``time_embedding_dim`` is not a positive even integer.
    """

    def __init__(self, input_channels=3, time_embedding_dim=256):
        super().__init__() # Initialize the base PyTorch module.

        # get_time_embedding builds (dim // 2) sines and (dim // 2) cosines. With an odd
        # dim the result would be one element short and the first Linear layer would
        # fail with an opaque shape error at forward time, so reject it up front.
        if time_embedding_dim < 2 or time_embedding_dim % 2 != 0:
            raise ValueError(
                f"time_embedding_dim must be a positive even integer, got {time_embedding_dim}"
            )

        # --- Shared Utility Layers (parameter-free, reused by every stage) ---
        self.downsample_pool = nn.MaxPool2d(2) # Halves height and width.
        self.upscale_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # Doubles height and width.
        self.activation_fn = nn.ReLU()

        # --- Encoder Path Layers (Downsampling) ---
        # Encoder Block 1 (full resolution, 128x128 for a 128x128 input)
        self.encoder1_conv_a = nn.Conv2d(input_channels, 64, 3, padding=1)
        self.encoder1_norm_a = nn.BatchNorm2d(64)
        self.encoder1_conv_b = nn.Conv2d(64, 64, 3, padding=1)
        self.encoder1_norm_b = nn.BatchNorm2d(64)

        # Encoder Block 2 (half resolution, 64x64)
        self.encoder2_conv_a = nn.Conv2d(64, 128, 3, padding=1)
        self.encoder2_norm_a = nn.BatchNorm2d(128)
        self.encoder2_conv_b = nn.Conv2d(128, 128, 3, padding=1)
        self.encoder2_norm_b = nn.BatchNorm2d(128)

        # Encoder Block 3 (quarter resolution, 32x32; this is the bottleneck)
        self.encoder3_conv_a = nn.Conv2d(128, 256, 3, padding=1)
        self.encoder3_norm_a = nn.BatchNorm2d(256)
        self.encoder3_conv_b = nn.Conv2d(256, 256, 3, padding=1)
        self.encoder3_norm_b = nn.BatchNorm2d(256)

        # --- Decoder Path Layers (Upsampling) ---
        # Decoder Block 3: 384 input channels = 256 (upsampled encoder 3) + 128 (encoder 2 skip).
        self.decoder3_conv_a = nn.Conv2d(384, 128, 3, padding=1)
        self.decoder3_norm_a = nn.BatchNorm2d(128)
        self.decoder3_conv_b = nn.Conv2d(128, 128, 3, padding=1)
        self.decoder3_norm_b = nn.BatchNorm2d(128)

        # Decoder Block 2: 192 input channels = 128 (upsampled decoder 3) + 64 (encoder 1 skip).
        self.decoder2_conv_a = nn.Conv2d(192, 64, 3, padding=1)
        self.decoder2_norm_a = nn.BatchNorm2d(64)
        self.decoder2_conv_b = nn.Conv2d(64, 64, 3, padding=1)
        self.decoder2_norm_b = nn.BatchNorm2d(64)

        # --- Output Layer ---
        self.output_conv = nn.Conv2d(64, input_channels, kernel_size=1) # 1x1 conv back to the input channel count.

        # --- Time Embedding Components ---
        self.time_embedding_dim = time_embedding_dim

        # One MLP per stage, each mapping the shared sinusoidal embedding to that stage's channel count.
        # Attribute names and creation order are kept stable so saved state_dicts keep loading.
        self.time_proj_enc1 = self._make_time_projection(time_embedding_dim, 64)
        self.time_proj_enc2 = self._make_time_projection(time_embedding_dim, 128)
        self.time_proj_enc3 = self._make_time_projection(time_embedding_dim, 256)
        self.time_proj_dec3 = self._make_time_projection(time_embedding_dim, 128)
        self.time_proj_dec2 = self._make_time_projection(time_embedding_dim, 64)

    @staticmethod
    def _make_time_projection(embedding_dim, channels):
        """Build the Linear -> SiLU -> Linear MLP that maps an embedding to ``channels`` values."""
        return nn.Sequential(nn.Linear(embedding_dim, channels), nn.SiLU(), nn.Linear(channels, channels))

    def _conv_pair(self, features, conv_a, norm_a, conv_b, norm_b, final_activation):
        """Apply conv -> norm -> ReLU -> conv -> norm, followed by ReLU only if ``final_activation``."""
        features = self.activation_fn(norm_a(conv_a(features)))
        features = norm_b(conv_b(features))
        if final_activation:
            features = self.activation_fn(features)
        return features

    def get_time_embedding(self, time_step):
        """Build the sinusoidal timestep embedding and project it for every stage.

        Args:
            time_step: Float tensor of timesteps with a trailing dimension of
                size 1 (``forward`` passes ``timestep.unsqueeze(-1).float()``,
                i.e. shape ``(batch, 1)`` for a 1-D batch of timesteps).

        Returns:
            Dict with keys ``'enc1'``, ``'enc2'``, ``'enc3'``, ``'dec3'`` and
            ``'dec2'``. Each value is the projected embedding for that stage
            with two trailing singleton dimensions appended so it broadcasts
            over height and width.
        """
        half_embedding_size = self.time_embedding_dim // 2
        embedding_indices = torch.arange(half_embedding_size, device=time_step.device).float()
        # Geometric frequency ladder from 1 down towards 1/10000.
        frequency_scales = torch.exp(-math.log(10000) * embedding_indices / half_embedding_size)
        scaled_time = time_step * frequency_scales.unsqueeze(0) # Broadcasts to (batch, half_embedding_size).
        full_embeddings = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)

        stage_projections = (
            ('enc1', self.time_proj_enc1),
            ('enc2', self.time_proj_enc2),
            ('enc3', self.time_proj_enc3),
            ('dec3', self.time_proj_dec3),
            ('dec2', self.time_proj_dec2),
        )
        return {
            stage: projection(full_embeddings).unsqueeze(-1).unsqueeze(-1)
            for stage, projection in stage_projections
        }

    def forward(self, input_image, timestep):
        """Predict a tensor shaped like ``input_image`` (used as the noise estimate).

        Args:
            input_image: Image batch; the last two dimensions are height and
                width and must both be divisible by 4.
            timestep: Tensor of timesteps, one per image (converted to float
                internally).

        Returns:
            Tensor with ``input_channels`` channels and the same height and
            width as ``input_image``.

        Raises:
            ValueError: If the height or width is not divisible by 4.
        """
        height, width = input_image.shape[-2:]
        if height % 4 != 0 or width % 4 != 0:
            # Two 2x poolings followed by two 2x upsamplings only return to the original
            # size when it is a multiple of 4; otherwise the skip-connection concat fails.
            raise ValueError(
                f"input height and width must be divisible by 4, got {height}x{width}"
            )

        timestep_input = timestep.unsqueeze(-1).float()
        block_time_embeddings = self.get_time_embedding(timestep_input)

        # --- Encoder Path Execution ---
        # Each stage: two conv/norm/ReLU layers, then the stage's time embedding is added.
        enc_out_1 = self._conv_pair(
            input_image,
            self.encoder1_conv_a, self.encoder1_norm_a,
            self.encoder1_conv_b, self.encoder1_norm_b,
            final_activation=True,
        ) + block_time_embeddings['enc1']

        enc_out_2 = self._conv_pair(
            self.downsample_pool(enc_out_1),
            self.encoder2_conv_a, self.encoder2_norm_a,
            self.encoder2_conv_b, self.encoder2_norm_b,
            final_activation=True,
        ) + block_time_embeddings['enc2']

        enc_out_3 = self._conv_pair(
            self.downsample_pool(enc_out_2),
            self.encoder3_conv_a, self.encoder3_norm_a,
            self.encoder3_conv_b, self.encoder3_norm_b,
            final_activation=True,
        ) + block_time_embeddings['enc3']

        # --- Decoder Path Execution (with Skip Connections) ---
        # Decoder stages apply no ReLU after their second normalization, unlike the encoder stages.
        decoder_input_3 = torch.cat([self.upscale_layer(enc_out_3), enc_out_2], dim=1)
        dec_out_3 = self._conv_pair(
            decoder_input_3,
            self.decoder3_conv_a, self.decoder3_norm_a,
            self.decoder3_conv_b, self.decoder3_norm_b,
            final_activation=False,
        ) + block_time_embeddings['dec3']

        decoder_input_2 = torch.cat([self.upscale_layer(dec_out_3), enc_out_1], dim=1)
        dec_out_2 = self._conv_pair(
            decoder_input_2,
            self.decoder2_conv_a, self.decoder2_norm_a,
            self.decoder2_conv_b, self.decoder2_norm_b,
            final_activation=False,
        ) + block_time_embeddings['dec2']

        # --- Final Output ---
        return self.output_conv(dec_out_2)

In [4]:
# Noise-schedule hyperparameters: endpoints of the beta range passed to torch.linspace
# when the forward-diffusion schedule is built.
BETA_START = 1e-4  # First (smallest) beta of the schedule
BETA_END = 2e-2  # Last (largest) beta of the schedule

In [5]:
import torch
import math # Assuming math is needed for BETA_START/BETA_END calculations if not constants
# BETA_START and BETA_END are module-level constants read by add_noise_at_timestep; they must be defined before it is called.
# Example: BETA_START = 0.0001; BETA_END = 0.02

def add_noise_at_timestep(x_start, t, timesteps=1000):
    """
    Noise a batch of images to timestep ``t`` with the DDPM forward process.

    Computes ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise`` where
    ``alpha_bar`` is the cumulative product of ``1 - beta`` over a linear schedule
    from the module-level ``BETA_START`` to ``BETA_END`` with ``timesteps`` entries.

    Args:
        x_start: Clean images. The per-timestep coefficients are reshaped to
            (-1, 1, 1, 1), so a 4-D (batch, channel, height, width) tensor is
            presumably expected - TODO confirm against callers.
        t: Index (or tensor of indices, one per image) into the schedule.
        timesteps: Number of entries in the noise schedule.

    Returns:
        Tuple ``(noisy_images, noise)``: the noised batch and the Gaussian noise
        that was mixed in.
    """
    # Build the schedule on the same device as the images; it is rebuilt on every call.
    beta_schedule = torch.linspace(BETA_START, BETA_END, timesteps, device=x_start.device)
    alpha_bar = torch.cumprod(1.0 - beta_schedule, dim=0)

    # Look up alpha_bar for the requested timesteps once and derive both mixing weights.
    alpha_bar_t = alpha_bar[t]
    signal_weight = alpha_bar_t.sqrt().view(-1, 1, 1, 1) # Scales the clean image.
    noise_weight = (1.0 - alpha_bar_t).sqrt().view(-1, 1, 1, 1) # Scales the noise.

    noise = torch.randn_like(x_start) # Standard normal noise shaped like the input.
    noisy_images = signal_weight * x_start + noise_weight * noise

    return noisy_images, noise

In [6]:
import torch
import torch.nn.functional as F # Ensure F is imported for F.mse_loss

def diffusion_loss_fn(model, x_start, timesteps=1000, num_timestep_samples=10):
    """
    Compute the per-image noise-prediction loss, averaged over several random timesteps.

    In each of ``num_timestep_samples`` rounds every image in the batch draws its
    own random timestep, is noised with ``add_noise_at_timestep``, and the model's
    prediction is compared with the true noise using mean squared error.

    Args:
        model: Callable invoked as ``model(noisy_images, timesteps)``; its output
            must have the same shape as the noise it is compared against.
        x_start: Batch of clean images. Dimension 0 is the batch; the loss is
            averaged over dimensions 1, 2 and 3.
        timesteps: Length of the noise schedule. Timesteps are drawn uniformly
            from ``[1, timesteps)``.
        num_timestep_samples: Number of independent timestep draws (and model
            forward passes) averaged per image. Defaults to 10, the value that
            was previously hard-coded.

    Returns:
        1-D tensor of length ``batch_size`` holding each image's mean loss.

    Raises:
        ValueError: If ``num_timestep_samples`` is smaller than 1.
    """
    if num_timestep_samples < 1:
        # Zero rounds would silently return 0/0 = NaN for every image.
        raise ValueError(f"num_timestep_samples must be >= 1, got {num_timestep_samples}")

    batch_size = x_start.shape[0]
    total_batch_loss = torch.zeros(batch_size, device=x_start.device) # Running per-image sum.

    # Every round adds a full forward pass whose autograd graph stays alive until the
    # caller backpropagates, so memory use grows with num_timestep_samples.
    for _ in range(num_timestep_samples):
        # One random timestep per image.
        # NOTE(review): the lower bound of 1 means schedule index 0 is never trained on -
        # confirm this matches the timestep range used by the sampler.
        sampled_timesteps = torch.randint(1, timesteps, (batch_size,), device=x_start.device)

        # Noise the clean images and keep the exact noise that was added.
        noisy_x, true_noise = add_noise_at_timestep(x_start, sampled_timesteps, timesteps)

        predicted_noise = model(noisy_x, sampled_timesteps)

        # reduction='none' keeps the element-wise errors so they can be reduced per image.
        per_element_loss = F.mse_loss(predicted_noise, true_noise, reduction='none')
        per_image_step_loss = per_element_loss.mean(dim=(1, 2, 3))

        total_batch_loss += per_image_step_loss

    mean_loss_per_image = total_batch_loss / num_timestep_samples

    return mean_loss_per_image

## Data preparation

In [7]:
from torchvision import datasets, transforms

# Preprocessing applied to every image as it is loaded.
preprocessing_steps = [
    transforms.Resize((128, 128)),  # Force a fixed 128x128 size.
    transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1].
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # (x - 0.5) / 0.5 per channel, i.e. [0, 1] -> [-1, 1].
]
transform = transforms.Compose(preprocessing_steps)

In [8]:
# Root directory of the training images.
# ImageFolder requires the images to sit inside at least one subdirectory of this
# root; every subdirectory name becomes a class label. If your images are all in a
# single flat folder, create one subfolder (e.g. 'all_images') and move them into it.
# The location can be overridden without editing the notebook by setting the
# ANIMAL_FACE_DATA_DIR environment variable; the default is the original path.
data_dir = os.environ.get('ANIMAL_FACE_DATA_DIR', '/home/vishwa/data_large/GenAI/animal_face/train')

# Images per mini-batch. The value is 8; the earlier note about reducing 32 -> 24
# for an 11GB GPU no longer matched it.
batch_size = 8

# ImageFolder discovers images in the subdirectories and labels them by folder name.
# The labels are loaded but the training loop discards them.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# Batches are reshuffled every epoch.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

print(f"Loaded {len(dataset)} images.") # Display the total number of images found.
print(f"Detected classes (subfolders): {dataset.classes}") # Show the class names (subdirectory names) ImageFolder found.

# Example of how to get a batch:
# for images, labels in dataloader:
#     print(f"Batch images shape: {images.shape}") # Should be (batch_size, 3, 128, 128)
#     print(f"Batch labels: {labels}") # Labels will correspond to the subfolder index (e.g., 0 for 'all_images')
#     break

Loaded 14630 images.
Detected classes (subfolders): ['cat', 'dog', 'wild']


## Model Training

In [9]:
import torch
import torch.nn as nn # For nn.Module, if not already imported
from torch.optim import Adam # For the Adam optimizer

# Requires UNet, diffusion_loss_fn, dataloader and device from the cells above.

# --- Model and optimizer setup ---
model = UNet()
model = model.to(device) # Place the parameters on the selected device.
model.train() # Training mode: the BatchNorm layers use batch statistics and update their running averages.

num_epochs = 500 # Full passes over the dataset.
learning_rate = 1e-4 # Step size for Adam.

optimizer = Adam(model.parameters(), lr=learning_rate)

# --- Training loop ---
for epoch in range(num_epochs):
    # Per-epoch bookkeeping for the average loss reported at the end of the epoch.
    total_epoch_loss = 0.0
    num_batches = 0

    for batch_idx, (images, _) in enumerate(dataloader): # The class labels are not used.
        images = images.to(device)

        # Standard optimisation step: reset gradients, compute loss, backpropagate, update.
        optimizer.zero_grad()
        batch_losses = diffusion_loss_fn(model, images) # One loss value per image.
        loss = batch_losses.mean() # Scalar loss for the batch.
        loss.backward()
        optimizer.step()

        # Record the batch loss as a Python float so no autograd graph is retained.
        total_epoch_loss += loss.item()
        num_batches += 1

        # Every 10th batch (including the first), hand cached, unused GPU memory back to the driver.
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()

    # Report the epoch's mean batch loss; an empty dataloader is reported explicitly.
    if num_batches == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, No batches processed.")
    else:
        avg_epoch_loss = total_epoch_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Epoch Loss: {avg_epoch_loss:.4f}")

    # Write the model weights to the working directory after every epoch (one file per epoch).
    torch.save(model.state_dict(), f'unet_model_epoch_{epoch+1}.pth')

Epoch 1/500, Average Epoch Loss: 0.0733
Epoch 2/500, Average Epoch Loss: 0.0349
Epoch 3/500, Average Epoch Loss: 0.0302
Epoch 4/500, Average Epoch Loss: 0.0277
Epoch 5/500, Average Epoch Loss: 0.0260
Epoch 6/500, Average Epoch Loss: 0.0245
Epoch 7/500, Average Epoch Loss: 0.0239
Epoch 8/500, Average Epoch Loss: 0.0234
Epoch 9/500, Average Epoch Loss: 0.0231
Epoch 10/500, Average Epoch Loss: 0.0226
Epoch 11/500, Average Epoch Loss: 0.0221
Epoch 12/500, Average Epoch Loss: 0.0222
Epoch 13/500, Average Epoch Loss: 0.0218
Epoch 14/500, Average Epoch Loss: 0.0218
Epoch 15/500, Average Epoch Loss: 0.0217
Epoch 16/500, Average Epoch Loss: 0.0215
Epoch 17/500, Average Epoch Loss: 0.0212
Epoch 18/500, Average Epoch Loss: 0.0215
Epoch 19/500, Average Epoch Loss: 0.0210
Epoch 20/500, Average Epoch Loss: 0.0210
Epoch 21/500, Average Epoch Loss: 0.0208
Epoch 22/500, Average Epoch Loss: 0.0208
Epoch 23/500, Average Epoch Loss: 0.0208
Epoch 24/500, Average Epoch Loss: 0.0209
Epoch 25/500, Average Epo

KeyboardInterrupt: 