In [2]:
# Import necessary PyTorch libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms


# Additional libraries for visualization and utilities
import matplotlib.pyplot as plt
import numpy as np
from unet_decoder import UNetDecoder

In [3]:
def get_device():
    """Pick the preferred PyTorch device that is available on this machine.

    The preference order is CUDA first, then Apple MPS, then CPU.

    Returns:
        torch.device: The first available device in that order.
    """
    # Resolve the backend name first, then build the device object once.
    if torch.cuda.is_available():
        backend = 'cuda'
    elif torch.backends.mps.is_available():
        backend = 'mps'
    else:
        backend = 'cpu'
    return torch.device(backend)


# Resolved once here; the cells below read this module-level name.
device = get_device()
print(f"using device: {device}")

using device: mps


In [4]:
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Compose, Normalize, ToTensor,Resize

from torch.utils.data import DataLoader, random_split

# Seed for the train/validation partition. Without a fixed seed every kernel
# restart (e.g. when resuming from a checkpoint) draws a different split, so
# images used for validation in one session are trained on in the next.
SPLIT_SEED = 42
BATCH_SIZE = 128

# Convert images to tensors and standardize them with the MNIST mean and std.
# No resizing is applied here.
transform = transforms.Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,))  # Normalize with MNIST mean and std
])

# Load the MNIST dataset (downloaded into ./data if it is not there yet)
dataset = datasets.MNIST(root='./data', download=True, transform=transform)

# Print the total number of images in the dataset
print(f"Total number of images in the dataset: {len(dataset)}")

# Split 80% / 20% into training and validation sets, reproducibly
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Print the number of images in the train and validation sets
print(f"Number of images in the training set: {len(train_dataset)}")
print(f"Number of images in the validation set: {len(val_dataset)}")

# Create DataLoader instances; only the training loader is shuffled
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

Total number of images in the dataset: 60000
Number of images in the training set: 48000
Number of images in the validation set: 12000


In [5]:
class DiffusionModel(nn.Module):
    """Wraps a decoder with the forward (noising) step of a diffusion model.

    Each forward pass draws Gaussian noise, mixes it into the input at a
    randomly chosen timestep, and asks the decoder for an estimate of that
    noise from the noised input and the timestep.

    Attributes:
        input_shape: Shape tuple passed at construction; its first entry is
            given to the decoder as ``n_channels``.
        T: Number of diffusion timesteps.
        decoder: ``UNetDecoder`` instance called as ``decoder(x_t, t)``.
        alpha: 1-D tensor of length ``T`` with the cumulative product of
            ``1 - beta``.
    """

    def __init__(self, input_shape, T=1000):
        """Builds the decoder and the noise schedule.

        Args:
            input_shape: Shape of one input sample; ``input_shape[0]`` is used
                as the channel count of the decoder.
            T: Number of diffusion timesteps.
        """
        super(DiffusionModel, self).__init__()
        self.input_shape = input_shape
        self.T = T
        self.decoder = UNetDecoder(n_channels=input_shape[0])

        # Registered as a buffer so the schedule follows the module through
        # .to(device). persistent=False keeps it out of state_dict, so
        # checkpoints written before this change still load unchanged.
        self.register_buffer("alpha", self.create_noise_schedule(T), persistent=False)

    def create_noise_schedule(self, T):
        """Returns the cumulative product of ``1 - beta`` for a linear beta schedule.

        Args:
            T: Number of timesteps.

        Returns:
            torch.Tensor: 1-D tensor of length ``T``.
        """
        beta_start = 0.0001
        beta_end = 0.02
        betas = torch.linspace(beta_start, beta_end, T)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        return alphas_cumprod

    def forward(self, x):
        """Noises ``x`` at random timesteps and runs the decoder on the result.

        Args:
            x: Batch of inputs; the first dimension is the batch dimension.

        Returns:
            tuple: ``(epsilon, estimated_epsilon)`` where ``epsilon`` is the
            noise that was mixed into ``x`` (same shape, dtype and device as
            ``x``) and ``estimated_epsilon`` is the decoder output.
        """
        batch_size = x.shape[0]

        # Noise with the same shape, dtype and device as the input, instead of
        # a hard-coded (B, 1, 28, 28) tensor moved via a global device.
        epsilon = torch.randn_like(x)

        # One timestep per example, as in the DDPM training procedure; a single
        # timestep shared by the whole batch gives a much noisier loss.
        t = torch.randint(0, self.T, (batch_size,), device=x.device, dtype=torch.long)

        # Look up the cumulative alpha for each example and reshape it so that
        # it broadcasts over all non-batch dimensions.
        schedule = self.alpha.to(device=x.device, dtype=x.dtype)
        alpha_t = schedule[t].view(batch_size, *([1] * (x.dim() - 1)))

        # Weighted sum of the clean input and the noise
        x_t = torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * epsilon

        # The decoder receives the noised input and the per-example timesteps
        estimated_epsilon = self.decoder(x_t, t)
        return epsilon, estimated_epsilon

In [6]:
import os
import torch

def save_checkpoint(epoch, model, optimizer, filename="checkpoint.pth"):
    """Saves the model and optimizer state at the specified path.

    When ``filename`` is a path, the checkpoint is first written to a sibling
    ``<filename>.tmp`` file and then moved into place with ``os.replace``, so
    a crash or error during saving cannot destroy the previous checkpoint.

    Args:
        epoch: Epoch index stored under the ``'epoch'`` key.
        model: Object whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        filename: Destination path, or any other target accepted by ``torch.save``.

    Raises:
        Exception: Whatever ``torch.save`` or ``os.replace`` raises; the
            temporary file is removed before the error propagates.
    """
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }

    if isinstance(filename, (str, os.PathLike)):
        target = os.fspath(filename)
        tmp_path = target + ".tmp"
        try:
            torch.save(state, tmp_path)
            # Swap in the finished file only after it has been fully written.
            os.replace(tmp_path, target)
        except BaseException:
            # Do not leave a half-written temporary file behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    else:
        # File-like targets cannot be renamed; write to them directly.
        torch.save(state, filename)

    print(f"Checkpoint saved at epoch {epoch} to {filename}")

def load_checkpoint(model, optimizer, filename="checkpoint.pth", device='cpu'):
    """Restores model and optimizer state from a checkpoint file, if one exists.

    Args:
        model: Object whose ``load_state_dict`` receives the stored
            ``'model_state_dict'``.
        optimizer: Object whose ``load_state_dict`` receives the stored
            ``'optimizer_state_dict'``.
        filename: Path of the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The value stored under ``'epoch'`` in the checkpoint, or ``-1`` when
        ``filename`` is not an existing file.
    """
    # Guard clause: nothing to restore.
    if not os.path.isfile(filename):
        print(f"No checkpoint found at {filename}, starting from scratch.")
        return -1

    state = torch.load(filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    last_epoch = state['epoch']
    print(f"Checkpoint loaded from {filename}, resuming training from epoch {last_epoch}")
    return last_epoch

In [7]:
import time  # Importing time to log the duration
from tqdm import tqdm


def validate(model, val_loader, device):
    """Computes the average per-batch MSE between true and estimated noise.

    The model is switched to evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards, so calling this in
    the middle of training does not leave the model in evaluation mode.

    Args:
        model: Module whose call returns ``(epsilon, estimated_epsilon)``.
        val_loader: Sized iterable of ``(data, label)`` batches; labels are ignored.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Sum of the per-batch MSE losses divided by ``len(val_loader)``.

    Raises:
        ZeroDivisionError: If ``val_loader`` has length zero.
    """
    was_training = model.training
    model.eval()  # Set the model to evaluation mode
    total_val_loss = 0.0
    try:
        with torch.no_grad():  # Disable gradient computation during validation
            for data, _ in val_loader:
                data = data.to(device)
                epsilon, estimated_epsilon = model(data)
                reconstruction_loss = nn.functional.mse_loss(estimated_epsilon, epsilon)
                total_val_loss += reconstruction_loss.item()  # Accumulate the validation loss
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    avg_val_loss = total_val_loss / len(val_loader)  # Calculate average loss
    return avg_val_loss

def train(model, optimizer, train_loader, device,start_epoch, num_epochs, checkpoint_path="mnist_gaussian.pth"):
    """Trains the model, checkpointing and validating after every epoch.

    Epochs ``start_epoch + 1`` up to ``num_epochs - 1`` are run. After each
    epoch a checkpoint is written with ``save_checkpoint`` and the validation
    loss is computed with ``validate`` on the module-level ``val_loader``.

    Args:
        model: Module whose call returns ``(epsilon, estimated_epsilon)``.
        optimizer: Optimizer updating the parameters of ``model``.
        train_loader: Sized iterable of ``(data, label)`` batches; labels are ignored.
        device: Device the batches are moved to.
        start_epoch: Index of the last completed epoch (``-1`` to start fresh).
        num_epochs: Exclusive upper bound on the epoch index.
        checkpoint_path: File the per-epoch checkpoint is written to.

    Returns:
        The trained ``model``.
    """
    for epoch in range(start_epoch+1, num_epochs):
        # Re-enter training mode every epoch: the validation pass at the end
        # of the previous epoch may have left the model in evaluation mode.
        model.train()

        epoch_loss = 0.0
        finite_batches = 0
        epoch_start_time = time.time()  # Time tracking for the epoch

        print(f"Starting epoch {epoch+1}/{num_epochs}")
        for batch_idx, (data, _) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)):
            data = data.to(device)

            optimizer.zero_grad()

            # Forward pass
            epsilon, estimated_epsilon = model(data)
            total_loss = nn.functional.mse_loss(estimated_epsilon, epsilon)

            # A NaN/inf loss gives no usable gradient. Skip the whole update,
            # not only backward(): stepping without fresh gradients can still
            # move the parameters.
            if not torch.isfinite(total_loss).all():
                print(f"Warning: non-finite total_loss at batch {batch_idx+1}, skipping this update.")
                continue

            # Backward pass and parameter update
            total_loss.backward()
            optimizer.step()

            epoch_loss += total_loss.item()
            finite_batches += 1

        # Save the model checkpoint
        save_checkpoint(epoch, model, optimizer, filename=checkpoint_path)

        # Average only over batches that contributed a loss
        avg_loss = epoch_loss / finite_batches if finite_batches else float('nan')
        print(f"Epoch [{epoch+1}/{num_epochs}] completed in {time.time() - epoch_start_time:.2f} seconds, Avg Loss: {avg_loss}")

        # Validation phase (uses the module-level val_loader)
        avg_val_loss = validate(model, val_loader, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] validation completed, Avg Validation Loss: {avg_val_loss}")

    return model


In [8]:
# Settings for this cell, gathered in one place
input_shape = (1, 28, 28)
LEARNING_RATE = 1e-3
CHECKPOINT_FILE = "mnist_gaussian.pth"

# Build the Gaussian diffusion model, then move it to the selected device
model = DiffusionModel(input_shape)
model = model.to(device)

# Adam over all model parameters
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Restore the latest training checkpoint; load_checkpoint returns -1 when the
# file does not exist
start_epoch = load_checkpoint(model, optimizer, filename=CHECKPOINT_FILE)
print(f"The training ended in epoch number: {start_epoch}")

Checkpoint loaded from mnist_gaussian.pth, resuming training from epoch 199
The training ended in epoch number: 199


## Sampling according to DDPM

In [None]:
import torch
import torch.nn.functional as F

# Rate/distortion evaluation along the DDPM reverse chain.
#
# For every validation batch: noise the images to the last timestep, walk the
# reverse chain with the trained decoder, and at each step accumulate the
# per-dimension KL divergence (in bits) between the true posterior
# q(x_{s-1} | x_s, x_0) and the model's reverse step. "Distortion" is the MSE
# between the original images and the final samples.

# Noise schedule (same linear schedule as DiffusionModel.create_noise_schedule),
# kept on the compute device so every operand lives in one place.
T = 1000
beta_start = 0.0001
beta_end = 0.02
betas = torch.linspace(beta_start, beta_end, T).to(device)
alphas = 1 - betas
alpha_bars = torch.cumprod(alphas, dim=0)
alpha_bar_T = alpha_bars[-1]

model.eval()

total_val_loss = 0.0
total_rate = 0.0
n_dim = 28 * 28

# Final samples keyed by running sample index; the cells below save this
# object and read entries as sampled_data[index]['sampled'].
sampled_data = {}

with torch.no_grad():
    for batch_idx, (data, _) in enumerate(val_loader):
        data = data.to(device)

        # Start from x at the last timestep, drawn from q(x_T | x_0 = data)
        x = torch.sqrt(alpha_bar_T) * data + torch.sqrt(1 - alpha_bar_T) * torch.randn_like(data)
        rate = 0.0

        for s in range(T - 1, -1, -1):
            t = torch.full((x.size(0),), s, dtype=torch.long, device=device)

            # Mean of the model's reverse step, computed from the predicted noise
            estimated_noise = model.decoder(x, t)
            model_reverse_mean = (
                x - ((1 - alphas[s]) / torch.sqrt(1 - alpha_bars[s])) * estimated_noise
            ) / torch.sqrt(alphas[s])

            if s == 0:
                # Last reverse step: no noise is added and there is no
                # previous timestep, hence no posterior KL term.
                x = model_reverse_mean
                break

            # Mean of the true posterior q(x_{s-1} | x_s, x_0). The x_0
            # coefficient is sqrt(alpha_bar_{s-1}) * beta_s / (1 - alpha_bar_s):
            # the square root covers alpha_bar_{s-1} only, not beta_s.
            true_reverse_mean = (
                (torch.sqrt(alpha_bars[s - 1]) * betas[s] / (1 - alpha_bars[s])) * data
                + ((1 - alpha_bars[s - 1]) / (1 - alpha_bars[s])) * torch.sqrt(alphas[s]) * x
            )

            # Posterior variance over model variance
            beta_t = betas[s]
            beta_bar_t = ((1 - alpha_bars[s - 1]) * beta_t) / (1 - alpha_bars[s])
            beta_ratio = beta_bar_t / beta_t

            # Squared distance between the two means, averaged over the batch
            mean_difference = true_reverse_mean - model_reverse_mean
            mean_frobenius_squared = mean_difference.flatten(1).pow(2).sum(dim=1).mean()

            # Per-dimension KL between two isotropic Gaussians, nats -> bits
            D_KL = 0.5 * (beta_ratio + (mean_frobenius_squared / (beta_t * n_dim)) - 1.0 - torch.log(beta_ratio))
            D_KL = D_KL / math.log(2.0)
            rate += D_KL

            # Sample the next state from the model's reverse step
            x = model_reverse_mean + torch.sqrt(beta_t) * torch.randn_like(x)

        # Convert to Python floats so the running totals and logs are plain numbers
        rate = float(rate)
        reconstruction_loss = torch.nn.functional.mse_loss(data, x).item()
        print(f"Rate for the batch: {rate} bits/dimensions")
        print(f"Distortion for the batch: {reconstruction_loss}")
        total_val_loss += reconstruction_loss
        total_rate += rate

        # Keep each original image and its final sample on the CPU
        for original, sample in zip(data.cpu(), x.cpu()):
            sampled_data[len(sampled_data)] = {'original': original, 'sampled': sample}

avg_val_loss = total_val_loss / len(val_loader)
avg_rate = total_rate / len(val_loader)

print(f"Average distortion: {avg_val_loss}")
print(f"Average rate: {avg_rate}")

In [None]:

# Save the sampled data to a .pt file. `sampled_data` must already exist in the
# kernel. The path is defined once so the log message always names the file
# that was actually written.
SAMPLED_DATA_PATH = 'sampled_gaussian_ddpm.pt'
torch.save(sampled_data, SAMPLED_DATA_PATH)

print(f"Sampled data saved to '{SAMPLED_DATA_PATH}'")


In [None]:
import torch
import matplotlib.pyplot as plt

# Read the saved samples back from disk
sampled_data = torch.load('sampled_gaussian_ddpm.pt')

# Entry to inspect; change the index to look at another one
index = 35
sampled_image = sampled_data[index]['sampled']

# Move the first axis last (the layout imshow expects), then convert to numpy
sampled_image = sampled_image.permute(1, 2, 0).cpu().numpy()

# Draw the image on its own 5x5 inch figure, without axes
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(sampled_image, cmap='gray')
ax.set_title("Sampled Image")
ax.axis('off')
plt.show()
