## Preliminaries

In [None]:
import os

# Allow a second copy of the Intel OpenMP runtime to be loaded into this process
# (KMP_DUPLICATE_LIB_OK). Without it, importing torch next to other OpenMP-linked
# libraries can abort with an OpenMP initialization error on some setups.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [None]:
# Importing the necessary libraries
import torch  # Import PyTorch library
import torch.nn as nn  # Import neural network module
import torch.optim as optim  # Import optimization module
import math  # Import math module for log calculations
from torchvision import datasets, transforms  # Import datasets and transforms
from torchvision.utils import save_image, make_grid  # Import utility to save images
import torchvision  # Import torchvision library
import matplotlib.pyplot as plt  # Import plotting library
import os  # Import os module for file operations
import numpy as np  # Import numpy library
from torch.utils.data import Dataset  # Add this import at the top
from PIL import Image  # Import PIL Image module for image handling
import torch.nn.functional as F  # Import PyTorch's functional API for loss functions

In [None]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch can see
# a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Define models

In [None]:
class UNet(nn.Module):
    """Time-conditioned U-Net that predicts diffusion noise on VQ-VAE latents.

    Two max-pool levels (8x8 -> 4x4 -> 2x2 for the default latents) with skip
    connections back up to the input resolution. Every stage receives its own
    learned projection of a sinusoidal timestep embedding, added per channel.

    Args:
        in_channels: channels of the input latent and of the predicted noise
            (default 128).
        time_dim: length of the sinusoidal timestep embedding (default 256).
    """

    # Stage names in the order their time-embedding MLPs are built and returned.
    _STAGE_NAMES = ("enc1", "enc2", "enc3", "dec3", "dec2")

    def __init__(self, in_channels=128, time_dim=256):
        super().__init__()

        # Parameter-free helpers shared by every stage.
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.relu = nn.ReLU()

        # Each stage is conv3x3 -> BN -> conv3x3 -> BN. Attribute names
        # (enc1_conv1, dec3_bn2, ...) and creation order are kept stable so saved
        # state_dicts still load and seeded initialisation is unchanged.
        stage_channels = (
            ("enc1", in_channels, 256),  # input resolution
            ("enc2", 256, 512),          # after first pool
            ("enc3", 512, 1024),         # after second pool (bottleneck)
            ("dec3", 1024 + 512, 512),   # upsampled enc3 concatenated with enc2 skip
            ("dec2", 512 + 256, 256),    # upsampled dec3 concatenated with enc1 skip
        )
        for name, c_in, c_out in stage_channels:
            setattr(self, f"{name}_conv1", nn.Conv2d(c_in, c_out, 3, padding=1))
            setattr(self, f"{name}_bn1", nn.BatchNorm2d(c_out))
            setattr(self, f"{name}_conv2", nn.Conv2d(c_out, c_out, 3, padding=1))
            setattr(self, f"{name}_bn2", nn.BatchNorm2d(c_out))

        # 1x1 projection back to the input channel count: the noise estimate.
        self.final_conv = nn.Conv2d(256, in_channels, kernel_size=1)

        self.time_dim = time_dim
        # One small MLP per stage maps the shared sinusoidal embedding to that
        # stage's channel width (time_enc1, time_enc2, time_enc3, time_dec3, time_dec2).
        for name, _, width in stage_channels:
            setattr(
                self,
                f"time_{name}",
                nn.Sequential(nn.Linear(time_dim, width), nn.SiLU(), nn.Linear(width, width)),
            )

    def get_time_embedding(self, t):
        """Build the sinusoidal embedding of ``t`` and project it for every stage.

        Args:
            t: float timestep tensor, shape (batch_size, 1) as prepared by forward().

        Returns:
            dict mapping 'enc1', 'enc2', 'enc3', 'dec3', 'dec2' to tensors of shape
            (batch_size, stage_channels, 1, 1) that broadcast over feature maps.
        """
        half = self.time_dim // 2
        # Geometric frequencies 10000^(-k/half) for k = 0 .. half-1.
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device).float() / half)
        angles = t * freqs.unsqueeze(0)
        sinusoid = torch.cat([angles.sin(), angles.cos()], dim=-1)  # (batch_size, 2 * half)

        return {
            name: getattr(self, f"time_{name}")(sinusoid)[..., None, None]
            for name in self._STAGE_NAMES
        }

    def _double_conv(self, x, name, relu_last):
        """Run stage ``name``: conv -> BN -> ReLU -> conv -> BN, plus a ReLU if ``relu_last``."""
        conv1, bn1 = getattr(self, f"{name}_conv1"), getattr(self, f"{name}_bn1")
        conv2, bn2 = getattr(self, f"{name}_conv2"), getattr(self, f"{name}_bn2")
        h = self.relu(bn1(conv1(x)))
        h = bn2(conv2(h))
        return self.relu(h) if relu_last else h

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timesteps ``t``.

        Args:
            x: latent batch (N, in_channels, H, W); H and W must be divisible by 4
                so the pooled/upsampled skip connections line up.
            t: timesteps, one per sample (e.g. shape (N,)).

        Returns:
            Tensor with the same shape as ``x``.
        """
        t_embs = self.get_time_embedding(t.unsqueeze(-1).float())

        # Encoder: ReLU after both convs, then add the stage's time embedding.
        e1 = self._double_conv(x, "enc1", relu_last=True) + t_embs['enc1']
        e2 = self._double_conv(self.pool(e1), "enc2", relu_last=True) + t_embs['enc2']
        e3 = self._double_conv(self.pool(e2), "enc3", relu_last=True) + t_embs['enc3']

        # Decoder: upsample, concatenate the skip on channels, two convs with no
        # trailing ReLU, then add the stage's time embedding.
        d3 = self._double_conv(torch.cat([self.upsample(e3), e2], dim=1), "dec3", relu_last=False) + t_embs['dec3']
        d2 = self._double_conv(torch.cat([self.upsample(d3), e1], dim=1), "dec2", relu_last=False) + t_embs['dec2']

        return self.final_conv(d2)

## Diffusion process

The forward diffusion process gradually adds Gaussian noise to a tensor over multiple timesteps. However, we can obtain the noisy tensor at any timestep directly from the original tensor — without iterating through the intermediate steps and without any model — as follows:

$$x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon$$

where:
- $x_t$ is the noisy tensor at timestep $t$
- $x_0$ is the original clean tensor  
- $\bar{\alpha}_t$ (alpha_bar) is the cumulative product of $\alpha_i = 1-\beta_i$ up to timestep $t$, i.e. $\bar{\alpha}_t = \prod_{i=1}^t \alpha_i$
- $\epsilon$ (epsilon) is random Gaussian noise

This process gradually transforms the data distribution into (approximately) pure Gaussian noise at $t=T$, since $\bar{\alpha}_T \approx 0$.

In [None]:
# Endpoints of the linear noise schedule: beta rises from BETA_START at the first
# diffusion step to BETA_END at the last (see the torch.linspace calls that use them).
BETA_START = 1e-4
BETA_END = 2e-2

In [None]:
def add_noise_at_timestep(x_start, t, timesteps=1000, beta_start=None, beta_end=None):
    """
    Sample x_t ~ q(x_t | x_0) in closed form:
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.

    Args:
        x_start (torch.Tensor): Clean batch, shape (batch_size, ...) of any rank.
        t (torch.Tensor | int): 0-based timestep index per sample (shape (batch_size,))
            or a single index shared by the whole batch.
        timesteps (int): Total number of diffusion steps T in the schedule.
        beta_start (float | None): First beta of the linear schedule; defaults to
            the module-level BETA_START.
        beta_end (float | None): Last beta of the linear schedule; defaults to the
            module-level BETA_END.

    Returns:
        tuple: (noisy tensor, noise), both shaped like ``x_start``.
    """
    beta_start = BETA_START if beta_start is None else beta_start
    beta_end = BETA_END if beta_end is None else beta_end

    device = x_start.device
    noise = torch.randn_like(x_start)  # Drawn first so the RNG stream matches earlier runs

    # Linear schedule built on CPU then moved, matching how the sampler builds it.
    betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)  # alpha_bar for every step
    alpha_bar_t = alphas_cumprod[t]

    # Reshape the per-sample coefficients to (batch, 1, ..., 1) so they broadcast
    # over every non-batch dimension of x_start, whatever its rank. A fixed
    # view(-1, 1, 1, 1) silently produced a wrong shape for non-4-D inputs.
    coeff_shape = (-1,) + (1,) * (x_start.dim() - 1)
    sqrt_alpha_bar_t = alpha_bar_t.sqrt().view(coeff_shape)
    sqrt_one_minus_alpha_bar_t = (1 - alpha_bar_t).sqrt().view(coeff_shape)

    noisy_images = sqrt_alpha_bar_t * x_start + sqrt_one_minus_alpha_bar_t * noise
    return noisy_images, noise

## Define the loss function

It is simply the mean squared error (MSE) between the noise predicted by the model and the noise that was actually added.

In [None]:
def diffusion_loss_fn(model, x_start, timesteps=1000):
    """
    Compute the per-sample DDPM noise-prediction loss for a batch.

    For each sample, a timestep is drawn uniformly. The sample is noised to that
    timestep with add_noise_at_timestep, and the model's noise prediction is
    scored against the true noise with mean squared error.

    Args:
        model (nn.Module): Noise-prediction network, called as model(x_t, t).
        x_start (torch.Tensor): Clean batch, shape (batch_size, ...).
        timesteps (int): Number of diffusion steps T. Timesteps are drawn from
            0 .. T-1, the same 0-based indices the reverse sampler visits.

    Returns:
        torch.Tensor: Loss per sample, shape (batch_size,).
    """
    batch_size = x_start.shape[0]

    # Sample over the full 0-based range [0, timesteps). A lower bound of 1 meant
    # index 0 — the final reverse step taken by the sampler — was never trained.
    t = torch.randint(0, timesteps, (batch_size,), device=x_start.device)

    noisy_images, noise = add_noise_at_timestep(x_start, t, timesteps)
    predicted_noise = model(noisy_images, t)

    # Element-wise squared error, then averaged over every non-batch dimension
    # (channels, height, width for 4-D latents).
    loss = F.mse_loss(predicted_noise, noise, reduction='none')
    reduce_dims = tuple(range(1, loss.dim()))
    return loss.mean(dim=reduce_dims) if reduce_dims else loss

## Data preparation

### Define VQ VAE

In [None]:
class Encoder(nn.Module):
    """Convolutional VQ-VAE encoder: RGB image batch -> latent feature map.

    Four stride-2 4x4 convolutions each halve the resolution (a 128x128 input
    becomes 8x8). A final 3x3 convolution then maps 512 channels to
    ``latent_dim`` channels.

    Args:
        latent_dim: number of channels in the output latent map.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # First downsampling conv has no BatchNorm; the next three are conv -> BN -> LeakyReLU.
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        # A single Sequential named `encoder` keeps the state_dict keys
        # (encoder.0, encoder.2, encoder.3, ...) identical for saved checkpoints.
        self.encoder = nn.Sequential(*layers)

        # 3x3 "same" convolution: changes channel count, not resolution.
        self.final_conv = nn.Conv2d(in_channels=512, out_channels=latent_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Encode (N, 3, H, W) images into an (N, latent_dim, H/16, W/16) latent map."""
        return self.final_conv(self.encoder(x))

In [None]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook quantiser used as the VQ-VAE bottleneck.

    Args:
        num_embeddings: number of codebook vectors K.
        embedding_dim: length D of each codebook vector; must equal the channel
            count of the feature maps passed to forward().
        commitment_cost: weight of the commitment term in the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Codebook of K rows of length D, initialised uniformly in [-1/K, 1/K].
        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        bound = 1 / self.num_embeddings
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, inputs):
        """Snap every spatial position of a (B, D, H, W) map to its nearest code.

        Returns:
            tuple: (vq_loss, quantized, perplexity, encodings).
            ``quantized`` has the input's (B, D, H, W) shape and passes gradients
            straight through to ``inputs``. ``encodings`` is the (B*H*W, K)
            one-hot code assignment.
        """
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()  # (B, H, W, D)
        flat = channels_last.view(-1, self.embedding_dim)         # (B*H*W, D)
        codebook = self.embedding.weight                          # (K, D)

        # Squared Euclidean distance expanded as ||x||^2 + ||e||^2 - 2 x.e
        distances = (
            flat.pow(2).sum(dim=1, keepdim=True)
            + codebook.pow(2).sum(dim=1)
            - 2 * (flat @ codebook.t())
        )

        # One-hot assignment of each vector to its closest code.
        nearest = distances.argmin(dim=1)
        encodings = torch.zeros(nearest.shape[0], self.num_embeddings, device=inputs.device)
        encodings.scatter_(1, nearest[:, None], 1)

        quantized = (encodings @ codebook).view(channels_last.shape)

        # Codebook term moves codes toward encoder outputs. The (weighted)
        # commitment term moves encoder outputs toward their assigned codes.
        embedding_loss = F.mse_loss(quantized, channels_last.detach())
        commitment_loss = F.mse_loss(quantized.detach(), channels_last)
        vq_loss = embedding_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: the forward pass uses the codes, while the
        # backward pass treats quantisation as the identity.
        quantized = channels_last + (quantized - channels_last).detach()
        quantized = quantized.permute(0, 3, 1, 2).contiguous()  # back to (B, D, H, W)

        # Perplexity = exp(entropy of average code usage); 1e-10 guards log(0).
        avg_probs = encodings.mean(dim=0)
        perplexity = (-(avg_probs * (avg_probs + 1e-10).log()).sum()).exp()

        return vq_loss, quantized, perplexity, encodings

In [None]:
class Decoder(nn.Module):
    """VQ-VAE decoder: latent map -> RGB image with values in [0, 1].

    Four stride-2 transposed convolutions each double the resolution, so an
    8x8 latent becomes a 128x128 image.

    Args:
        latent_dim: channel count of the incoming latent map.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Three upsampling stages of ConvTranspose(x2) -> ReLU -> BatchNorm, narrowing
        # latent_dim -> 64 -> 32 -> 16 channels.
        widths = (latent_dim, 64, 32, 16)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(c_out),
            ])
        # Last upsampling stage produces 3 channels squashed into [0, 1].
        layers.extend([
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ])
        # One Sequential named `decoder` keeps state_dict keys (decoder.0, decoder.2, ...) stable.
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode an (N, latent_dim, H, W) map into (N, 3, 16H, 16W) images."""
        return self.decoder(x)

In [None]:
class VQVAE(nn.Module):
    """Encoder -> VectorQuantizer -> Decoder, sharing one latent width.

    Args:
        embedding_dim: codebook vector length; also used as the encoder's output
            channels and the decoder's input channels.
        num_embeddings: codebook size (default 512).
        commitment_cost: weight of the commitment loss (default 0.25).
    """

    def __init__(self, embedding_dim, num_embeddings=512, commitment_cost=0.25):
        super().__init__()
        # Submodules are created in this order so seeded initialisation and state_dict keys are stable.
        self.encoder = Encoder(embedding_dim)
        self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        self.decoder = Decoder(embedding_dim)

    def forward(self, x):
        """Return (vq_loss, reconstruction, perplexity) for an image batch ``x``."""
        vq_loss, quantized, perplexity, _encodings = self.vq_layer(self.encoder(x))
        return vq_loss, self.decoder(quantized), perplexity

### Load VQVAE

In [None]:
# Load the saved VQ-VAE weights from the .pth file (read from disk once, reused below).
model_path = r'D:\Users\VICTOR\Desktop\ADRL\Assignment 3\vqvae_model_epoch_30.pth'  # Absolute path to the checkpoint
try:
    # The loaded object is handed to load_state_dict below, so it is expected to be a state_dict.
    # map_location puts tensors saved from a GPU onto the active device, so this also works on CPU-only machines.
    # NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
    loaded_model = torch.load(model_path, map_location=device)
    print(f"Model loaded successfully from {model_path}")  # Confirm successful load
except OSError as e:
    print(f"Error loading model: {e}")  # Print the specific error message
    print("Please check if the file path is correct and the file exists.")
    raise  # Stop here: nothing below can work without the weights

# Architecture parameters; these must match the saved checkpoint.
latent_dim = 128  # Not passed to the constructor: VQVAE uses embedding_dim as the encoder's latent width
num_embeddings = 1024  # Codebook size
embedding_dim = 128  # Length of each codebook vector

# Rebuild the architecture and fill it with the already-loaded weights.
VQVAE_model = VQVAE(embedding_dim=embedding_dim, num_embeddings=num_embeddings)
VQVAE_model.load_state_dict(loaded_model)

VQVAE_model = VQVAE_model.to(device)  # Move model to GPU if available
VQVAE_model.eval()  # Inference only: BatchNorm layers use their stored running statistics

print(VQVAE_model)  # Print the architecture to verify it loaded correctly

### Get latent representations

In [None]:
def get_vq_representations(dataloader, VQVAE_model, device):
    """
    Encode and quantize every batch of a dataloader with a VQ-VAE.

    The model is switched to eval mode and run without gradients. Labels from
    the dataloader are ignored. Progress is printed every 10 batches.

    Args:
        dataloader: Iterable of (images, labels) batches.
        VQVAE_model: Model exposing ``encoder`` and ``vq_layer``.
        device: Device the images are moved to before encoding.

    Returns:
        numpy.ndarray: Quantized latents of all batches, concatenated along axis 0.
    """
    chunks = []

    VQVAE_model.eval()
    with torch.no_grad():
        for step, (images, _labels) in enumerate(dataloader, start=1):
            latents = VQVAE_model.encoder(images.to(device))
            _, quantized, _, _ = VQVAE_model.vq_layer(latents)
            chunks.append(quantized.cpu().numpy())

            if step % 10 == 0:
                print(f"Processed {step} batches...")

    stacked = np.concatenate(chunks, axis=0)
    print(f"Final shape of VQ representations: {stacked.shape}")
    return stacked

In [None]:
class CustomImageDataset(Dataset):
    """Unlabelled image dataset read from a single flat directory.

    Every ``.png``/``.jpg``/``.jpeg`` file (extension matched case-insensitively)
    directly inside ``root_dir`` is one sample. Items are ``(image, 0)``, where the
    0 is a dummy label so the dataset fits loops expecting (input, label) pairs.

    Args:
        root_dir: directory containing the image files.
        transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir  # Directory containing all images
        self.transform = transform  # Transformations to apply to images
        # os.listdir returns files in arbitrary, filesystem-dependent order. Sorting
        # makes the index -> file mapping reproducible across machines and runs.
        self.image_files = sorted(
            f for f in os.listdir(root_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    def __len__(self):
        return len(self.image_files)  # Number of image files found

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, 0  # Dummy label

# Per-image preprocessing: force the 128x128 resolution the VQ-VAE encoder is
# built for, then convert the PIL image to a float tensor with values in [0, 1].
_preprocessing_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

In [None]:
# Local copy of the Butterfly images: a flat folder read by CustomImageDataset.
data_dir = r'D:\Users\VICTOR\Desktop\ADRL\Assignment 3\Butterfly dataset'
batch_size = 64  # Images per batch fed to the VQ-VAE encoder

dataset = CustomImageDataset(root_dir=data_dir, transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

print(f"Loaded {len(dataset)} images.")

In [None]:
# Get VQ representations for all images by reusing get_vq_representations
# rather than duplicating its encode -> quantize -> concatenate loop here.
# The helper puts the model in eval mode, disables gradients, prints progress
# every 10 batches, and prints the final shape itself.
print("Getting VQ representations for all images...")  # Print status message

final_vq_representations = get_vq_representations(dataloader, VQVAE_model, device)

## Training loop

In [None]:
# Noise-prediction U-Net for the latent diffusion model, placed on the chosen device
# (Module.to returns the module itself, so the call can be chained).
u_net_model = UNet().to(device)

In [None]:
u_net_model.train(mode=True)  # Training mode: the U-Net's BatchNorm layers use batch statistics and update running stats

In [None]:
# Optimisation settings for the U-Net training loop:
# number of full passes over the cached VQ latents, and the Adam step size.
num_epochs, learning_rate = 100, 1e-4

In [None]:
# Adam over all U-Net parameters (`optim` is the torch.optim module imported above).
optimizer = optim.Adam(u_net_model.parameters(), lr=learning_rate)

In [None]:
# Wrap the cached NumPy latents in a dataset and serve them in shuffled batches of 32.
# torch.from_numpy shares memory with the array instead of copying it.
_vq_tensor = torch.from_numpy(final_vq_representations)
vq_dataset = torch.utils.data.TensorDataset(_vq_tensor)
vq_dataloader = torch.utils.data.DataLoader(vq_dataset, shuffle=True, batch_size=32)

In [None]:
# Training loop: fit the U-Net to predict the noise added to the cached VQ latents.
u_net_model.train()  # BatchNorm layers normalise with batch statistics while training

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    epoch_losses = []  # One scalar loss per mini-batch of this epoch

    for batch_number, (vq_batch,) in enumerate(vq_dataloader, start=1):
        optimizer.zero_grad()
        # The latents come from a NumPy array; ensure float32 on the training device.
        vq_batch = vq_batch.to(device).float()

        # diffusion_loss_fn returns one loss per sample; optimise their mean.
        loss = diffusion_loss_fn(u_net_model, vq_batch).mean()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        epoch_losses.append(loss_value)
        if batch_number % 10 == 0:
            print(f"Batch {batch_number}, Loss: {loss_value:.4f}")

    avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
    print(f"Epoch {epoch+1} average loss: {avg_epoch_loss:.4f}")

    # Write a separate checkpoint file after every epoch.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': u_net_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_epoch_loss,
    }
    torch.save(checkpoint, f'unet_model_epoch_{epoch+1}.pth')

## Inference

The inference process in DDPM involves reversing the diffusion process to generate new samples. The steps are as follows:

1. Sample from the prior: Start by sampling $x_T \sim \mathcal{N}(0, I)$, which is the prior distribution.

2. Reverse the diffusion process: sequentially compute $x_{t-1}$ from $x_t$ for $t = T, T-1, \ldots, 1$ using
   
    $$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon(x_t, t)) + \sigma_t z$$

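    where $z \sim \mathcal{N}(0, I)$ for $t > 1$ and $z = 0$ at the final step $t = 1$. The variance $\sigma_t^2$ is a design choice: either $\sigma_t^2 = \beta_t = 1-\alpha_t$, or the posterior variance $\sigma_t^2 = \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$. These two values are not equal, but both work well in practice.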

3. Obtain the final sample: the final sample $x_0$ is obtained after completing the reverse process.


In [None]:
def get_beta_schedule(num_timesteps, device=None, beta_start=None, beta_end=None):
    """
    Get the linear beta schedule for the diffusion process.

    Args:
        num_timesteps (int): Number of timesteps in the diffusion process.
        device (str | torch.device | None): Device for the returned tensor.
            Defaults to 'cuda' when a GPU is available, otherwise 'cpu'. A
            hard-coded 'cuda' used to crash on CPU-only machines.
        beta_start (float | None): First beta; defaults to module-level BETA_START.
        beta_end (float | None): Last beta; defaults to module-level BETA_END.

    Returns:
        torch.Tensor: Shape (num_timesteps,), linearly spaced from beta_start to beta_end.
    """
    if beta_start is None:
        beta_start = BETA_START
    if beta_end is None:
        beta_end = BETA_END
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Built on CPU and then moved, exactly like the schedule used during training,
    # so training and sampling see bit-identical betas.
    return torch.linspace(beta_start, beta_end, num_timesteps).to(device)

In [None]:
def generate_sample(model, num_timesteps, device='cuda', shape=(1, 128, 8, 8),
                    init_noise_scale=0.25, noise_scale=0.3, clip_value=2.0):
    """
    Generate a new latent sample by running the reverse diffusion process

    The defaults deliberately deviate from textbook DDPM sampling (they were
    reduced for the VQVAE latent space). x_T is drawn with a reduced scale, the
    per-step noise uses a reduced sigma_t, and the latent is clamped after every
    step. Passing ``init_noise_scale=1``, ``noise_scale=1`` and
    ``clip_value=None`` gives the standard sampler with ``sigma_t = sqrt(beta_t)``.

    Args:
        model: Trained UNet noise predictor, called as ``model(x_t, t)``
        num_timesteps: Number of diffusion steps
        device: Computing device for the latent and the noise schedule
        shape: Shape of the latent to generate; the first dimension is the batch size
        init_noise_scale: Multiplier applied to the initial Gaussian noise x_T
        noise_scale: Multiplier applied to sigma_t = sqrt(beta_t) for the per-step noise
        clip_value: Clamp the latent to ``[-clip_value, clip_value]`` after every
            step; ``None`` disables clamping

    Returns:
        torch.Tensor: Generated latent sample of shape ``shape`` on ``device``
    """
    # Sample x_T with the CPU generator and then move it, so a fixed torch.manual_seed
    # reproduces the same starting noise on any device
    x_t = torch.randn(shape).to(device) * init_noise_scale

    # Move the schedule onto `device` so it always matches x_t
    betas = get_beta_schedule(num_timesteps).to(device)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha_bar_t
    batch_size = x_t.shape[0]

    model.eval()
    with torch.no_grad():
        for t in reversed(range(num_timesteps)):
            # One (long) timestep index per batch element
            t_tensor = torch.full((batch_size,), t, dtype=torch.long, device=device)
            predicted_noise = model(x_t, t_tensor)

            alpha_t = alphas[t]
            alpha_t_bar = alphas_cumprod[t]
            sigma_t = torch.sqrt(betas[t]) * noise_scale  # Reduced when noise_scale < 1

            # Posterior mean: (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
            mean = (1 / torch.sqrt(alpha_t)) * (
                x_t - ((1 - alpha_t) / torch.sqrt(1 - alpha_t_bar)) * predicted_noise
            )

            if t > 0:
                x_t = mean + torch.randn_like(x_t) * sigma_t
            else:
                x_t = mean  # No noise is added at the final step

            if clip_value is not None:
                x_t = torch.clamp(x_t, -clip_value, clip_value)  # Keep values in a bounded range

    return x_t

In [None]:
# Load the trained UNet noise predictor from its checkpoint
model_path = r"D:\Users\VICTOR\Desktop\ADRL\Assignment 3\unet_model_epoch_50.pth"  # Path to saved model checkpoint
loaded_unet_model = UNet().to(device)  # Initialize model on the notebook's device

# weights_only=True restricts unpickling to tensors and plain containers;
# map_location lets a checkpoint saved on GPU load on the current device
checkpoint = torch.load(model_path, map_location=device, weights_only=True)

# The training loop saves a dict holding 'model_state_dict'; a bare state dict is also accepted
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
loaded_unet_model.load_state_dict(state_dict)

# Evaluation mode for inference
loaded_unet_model.eval()

In [None]:
# Generate latent samples with the diffusion model and decode them with the VQVAE
num_timesteps = 1000  # Number of timesteps for diffusion (should match the training schedule)

# Load the saved VQVAE model
vqvae_model = VQVAE(embedding_dim=embedding_dim, num_embeddings=num_embeddings).to(device)  # Initialize model
# map_location lets a GPU-saved checkpoint load on the current device
vqvae_checkpoint = torch.load('vqvae_model_epoch_30.pth', map_location=device, weights_only=True)
vqvae_model.load_state_dict(vqvae_checkpoint)  # Checkpoint is loaded as a plain state dict
vqvae_model.eval()  # Set to evaluation mode

num_samples = 100  # Number of samples to generate

# NOTE(review): images are written to the working directory, while the FID cell globs
# ...\generated_samples\*.png — keep the two locations in sync before computing FID.
for i in range(num_samples):
    generated_sample = generate_sample(loaded_unet_model, num_timesteps, device=device)  # Latent on `device`
    print(f"Generated sample {i+1} dimensions: {generated_sample.shape}")  # Display tensor dimensions for current sample

    with torch.no_grad():
        # Quantize the generated latent with the VQVAE quantizer (`quantized` is its
        # second return value), then decode the quantized representation only.
        _, quantized, _, _ = vqvae_model.vq_layer(generated_sample)
        decoded_image = vqvae_model.decoder(quantized)

    # save_image clips values outside [0, 1]; assumes the decoder output is in that range — TODO confirm
    save_path = f"generated_diffusion_sample_{i+1}.png"  # Define save path with sample number
    torchvision.utils.save_image(decoded_image[0], save_path)  # Save image to file
    print(f"Image {i+1} saved to {save_path}")  # Confirm save location for current sample

## FID

In [None]:
# Inception v3 preprocessing for image *tensors*. Images are resized to the 299x299
# input size and normalized with the ImageNet channel statistics the pretrained
# weights expect. There is no ToTensor step here, so inputs must already be tensors
# (presumably with values in [0, 1]).
preprocess = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def prepare_inception_input(images):
    """Resize and ImageNet-normalize ``images`` for Inception v3.

    Args:
        images: Image tensor(s) accepted by torchvision's ``Resize`` and
            ``Normalize`` transforms.

    Returns:
        The tensor produced by running ``preprocess`` on ``images``.
    """
    inception_ready = preprocess(images)
    return inception_ready

In [None]:
from torchvision.models import inception_v3  # Import inception_v3 model from torchvision

def load_inception_model():
    """
    Load a pre-trained Inception v3 model prepared for feature extraction

    The final fully connected layer is replaced by ``Identity``, so the model
    returns the features that previously fed the classifier instead of class logits.

    Returns:
        model: Modified Inception v3 in evaluation mode, on ``device``
    """
    from torchvision.models import Inception_V3_Weights  # Only needed here

    # `weights=` replaces the deprecated `pretrained=True` and selects the same ImageNet-1k checkpoint.
    # transform_input stays explicitly False because inputs are already ImageNet-normalized
    # by the transforms used in this notebook.
    model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1, transform_input=False)
    model.fc = torch.nn.Identity()  # Remove the classification head
    model.eval()  # Set the model to evaluation mode
    return model.to(device)  # Same device the feature-extraction inputs are sent to


inception_model = load_inception_model()  # Load and prepare the modified Inception v3 model

In [None]:
import glob  # Import glob module for file path operations

def extract_inception_features(image_paths, batch_size=32, model=None):
    """
    Extract features from image files using an Inception v3 model

    Args:
        image_paths: Iterable of paths to image files
        batch_size: Number of images forwarded through the model at once
        model: Feature extractor to use; defaults to the notebook-level ``inception_model``

    Returns:
        features: Tensor with one row of features per image, in ``image_paths`` order

    Raises:
        ValueError: If ``image_paths`` is empty or ``batch_size`` is not positive
    """
    if model is None:
        model = inception_model
    image_paths = list(image_paths)
    if not image_paths:
        raise ValueError("extract_inception_features() received no image paths")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    transform = transforms.Compose([
        transforms.ToTensor(),  # Convert PIL image to a float tensor
        transforms.Resize((299, 299)),  # Resize to inception input size
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize with ImageNet stats
    ])

    all_features = []
    with torch.no_grad():  # Inference only
        for start in range(0, len(image_paths), batch_size):
            batch = []
            for img_path in image_paths[start:start + batch_size]:
                # The context manager closes the file handle once the pixels are read
                with Image.open(img_path) as img:
                    batch.append(transform(img.convert('RGB')))
            # All tensors are 299x299 after Resize, so they can be stacked into one batch
            batch_tensor = torch.stack(batch).to(device)
            all_features.append(model(batch_tensor))

    return torch.cat(all_features, dim=0)  # Concatenate batch features into a single tensor

In [None]:
# Get generated image paths, sorted so the processing order is reproducible across filesystems
generated_image_paths = sorted(glob.glob(r"D:\Users\VICTOR\Desktop\ADRL\Assignment 3\generated_samples\*.png"))
if not generated_image_paths:  # Fail early with a clear message, as the real-image cell does
    raise ValueError("No image files found in the specified directory")

# Extract features from generated images
generated_features = extract_inception_features(generated_image_paths)  # One feature row per image
print(f"Extracted features from {len(generated_image_paths)} generated images")  # Print number of processed images
print(f"Feature shape: {generated_features.shape}")  # Print shape of extracted features

# Statistics of generated features
generated_mean = torch.mean(generated_features, dim=0)  # Mean of each feature dimension over samples
# torch.cov treats rows as variables and columns as observations, hence the transpose
generated_cov = torch.cov(generated_features.T)
# NOTE(review): with fewer images than feature dimensions the covariance is rank-deficient,
# which makes the FID estimate noisy and sample-size dependent — use more images if possible.

print(f"Generated features mean shape: {generated_mean.shape}")  # Print shape of mean vector
print(f"Generated features covariance shape: {generated_cov.shape}")  # Print shape of covariance matrix

# Save statistics for later comparison
torch.save(generated_mean, 'generated_mean.pt')  # Save mean vector to disk for future use
torch.save(generated_cov, 'generated_cov.pt')  # Save covariance matrix to disk for future use


In [None]:
# Get real image paths in a deterministic (sorted) order and keep the first 100
real_image_paths = sorted(glob.glob(r"D:\Users\VICTOR\Desktop\ADRL\Assignment 3\Butterfly dataset\*.jpg"))[:100]
if not real_image_paths:  # Check if any image paths were found
    raise ValueError("No image files found in the specified directory")

# Extract features from real images. On failure, print a hint and re-raise, so later cells
# cannot silently compute FID from undefined or stale real_mean / real_cov.
try:
    real_features = extract_inception_features(real_image_paths)
except (RuntimeError, OSError) as e:  # OSError covers unreadable or corrupt image files
    print(f"Error processing images: {str(e)}")  # Print error message if feature extraction fails
    print("Please verify that all images are valid and accessible")  # Provide troubleshooting hint
    raise

print(f"Extracted features from {len(real_image_paths)} real images")  # Print number of processed images
print(f"Feature shape: {real_features.shape}")  # Print shape of extracted features

# Statistics of real features
real_mean = torch.mean(real_features, dim=0)  # Mean of each feature dimension over samples
# torch.cov treats rows as variables and columns as observations, hence the transpose
real_cov = torch.cov(real_features.T)

print(f"Real features mean shape: {real_mean.shape}")  # Print shape of mean vector
print(f"Real features covariance shape: {real_cov.shape}")  # Print shape of covariance matrix

# Save statistics for later comparison
torch.save(real_mean, 'real_mean.pt')  # Save mean vector to disk for future use
torch.save(real_cov, 'real_cov.pt')  # Save covariance matrix to disk for future use

In [None]:
def calculate_frechet_inception_distance(real_mean, real_cov, generated_mean, generated_cov, eps=1e-6):
    """
    Calculate the Fréchet Inception Distance (FID) between real and generated image features.

    FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2))

    Args:
    real_mean (torch.Tensor): Mean of real image features.
    real_cov (torch.Tensor): Covariance matrix of real image features.
    generated_mean (torch.Tensor): Mean of generated image features.
    generated_cov (torch.Tensor): Covariance matrix of generated image features.
    eps (float): Diagonal offset added to both covariances when the matrix square
        root is not finite (e.g. singular covariances estimated from few samples).

    Returns:
    float: The calculated FID score.
    """
    # Import the submodule explicitly; a bare `import scipy` does not guarantee scipy.linalg is loaded
    from scipy import linalg

    # Work in float64: FID subtracts large, nearly equal traces, so float32 loses precision
    mu_r = real_mean.detach().cpu().numpy().astype(np.float64)
    sigma_r = real_cov.detach().cpu().numpy().astype(np.float64)
    mu_g = generated_mean.detach().cpu().numpy().astype(np.float64)
    sigma_g = generated_cov.detach().cpu().numpy().astype(np.float64)

    # Squared L2 distance between the means
    diff = mu_r - mu_g
    mean_term = diff.dot(diff)

    # Matrix square root of the covariance product
    covmean = linalg.sqrtm(sigma_r.dot(sigma_g))
    if not np.isfinite(covmean).all():
        # Regularize near-singular covariances and retry
        offset = np.eye(sigma_r.shape[0]) * eps
        covmean = linalg.sqrtm((sigma_r + offset).dot(sigma_g + offset))

    # Numerical error can introduce a tiny imaginary component; keep the real part
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    trace_term = np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * np.trace(covmean)

    return float(mean_term + trace_term)  # Plain Python float

In [None]:
import scipy.linalg  # Import the submodule explicitly; `import scipy` alone may not expose scipy.linalg

# Move statistics to the CPU; the FID computation works in NumPy
generated_mean_cpu = generated_mean.cpu()  # Move generated mean tensor to CPU
generated_cov_cpu = generated_cov.cpu()  # Move generated covariance tensor to CPU
real_mean_cpu = real_mean.cpu()  # Move real mean tensor to CPU
real_cov_cpu = real_cov.cpu()  # Move real covariance tensor to CPU

# Calculate FID score using CPU tensors
fid_score = calculate_frechet_inception_distance(real_mean_cpu, real_cov_cpu, generated_mean_cpu, generated_cov_cpu)
print(f"Fréchet Inception Distance: {fid_score:.4f}")  # Print calculated FID score with 4 decimal places