# Imports

Make sure you have the packages installed. You might need to `pip install` some of them; the console will tell you which packages are missing when you try to run the code.

In [None]:
# standard imports
import random
import math
import numpy as np
import matplotlib.pyplot as plt

# allows flexible tensor manipulation
import einops

# progress bars during training and sampling
from tqdm.auto import tqdm

# PyTorch for general machine learning
import torch
import torch.nn as nn
from torch.optim import Adam

# help in displaying images
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image

# libraries for data and dataset processing
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.datasets.mnist import FashionMNIST

# libraries for quantitative metrics calculations
from torchvision.utils import save_image
from torch_fidelity import calculate_metrics
import os

# Helper function to display images

In [None]:
def display_images(images, labels=None, title="", num_samples=20, cols=4):
    """Show a batch of images as a single grid figure.

    Args:
        images: batch of images; only the first ``num_samples`` are shown.
        labels: optional sequence of captions, one per image, drawn on top
            of the grid (note: labels are not positioned very precisely).
        title: figure title.
        num_samples: maximum number of images to display.
        cols: number of images per grid row.
    """
    # never take more images than the batch actually holds
    shown = images[:min(num_samples, len(images))]

    # tile the images (each one normalized on its own) and convert to PIL
    picture = to_pil_image(
        make_grid(shown, nrow=cols, normalize=True, scale_each=True)
    )

    plt.figure(figsize=(12, 12))  # adjust the figure size as needed
    plt.imshow(picture, cmap="gray")
    plt.title(title, fontsize=20)
    plt.axis('off')

    # optional captions, one per displayed image
    if labels is not None:
        count = len(shown)
        n_rows = -(-count // cols)  # ceiling division: rows in the grid
        for position, caption in enumerate(labels[:count]):
            row, col = divmod(position, cols)
            plt.text(
                col * picture.width / cols,
                (row + 1) * picture.height / n_rows - 10,  # near the cell bottom
                caption,
                horizontalalignment='center',
                fontsize=10,
                color='white',
                weight='bold'
            )

    plt.show()

# Pre-process data (and display FashionMNIST images)

Here we look at some of the real FashionMNIST images to get a sense of the kind of images we will be generating.

In [None]:
# transform to tensor, then rescale every value x to (x - 0.5) * 2
# (inputs in [0, 1] are mapped onto [-1, 1])
transform = Compose([
    ToTensor(),
    Lambda(lambda x: (x - 0.5) * 2)
])

# we utilize a batch_size of 128
batch_size = 128

# load the FashionMNIST training split from torchvision (root "../data", download=True)
# and apply the transform above to every image; the loader shuffles the data
train_dataset = FashionMNIST("../data", download=True, train=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# class names for display, indexed by class label
label_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# get the next batch of images
images, labels = next(iter(train_loader))

# convert labels to their corresponding names for display
# NOTE: label_texts is not used below, because display_images is called with labels=None
label_texts = [label_names[label] for label in labels]

# display up to 50 images of the batch in a 10-column grid, without labels or title
display_images(images, labels=None, title=None, num_samples=50, cols=10)

# Unconditional DDPM implementation

In [None]:
# get the device (GPU if available)
# this is essential as we want to train on our GPU!
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# report the chosen device; with CUDA the name of GPU 0 is appended,
# otherwise the literal "CPU" is appended (so the CPU case prints "Device Type: cpu CPU")
print(f"Device Type: {device} " + (f"| Name: {torch.cuda.get_device_name(0)}" if torch.cuda.is_available() else "CPU"))

In [None]:
# variance scheduler class to pre-compute noise values
class VarianceScheduler:
    """Pre-compute the noise schedule of a diffusion process.

    The schedule is stored in ``self.schedule`` as a dictionary with the keys
    ``"betas"``, ``"alphas"`` (``1 - betas``) and ``"alpha_bars"`` (cumulative
    product of the alphas). Every entry is a float32 tensor of length ``T``
    on ``device``; index ``i`` holds the value of diffusion step ``i + 1``.

    Args:
        beta1: first beta of the linear schedule.
        beta2: last beta of the linear schedule.
        T: total number of diffusion steps.
        device: device the schedule tensors are moved to.
        schedule_type: ``"linear"`` (Ho et al., DDPM) or ``"cosine"``
            (Nichol & Dhariwal, Improved DDPM). The cosine schedule does not
            use ``beta1``/``beta2``, but they are still validated.
        s: small offset of the cosine schedule.
        beta_max: upper clipping bound for the betas of the cosine schedule.

    Raises:
        AssertionError: if not ``0 < beta1 < beta2 < 1``.
        ValueError: if ``schedule_type`` is unknown.
    """

    def __init__(self, beta1, beta2, T, device, schedule_type="linear", s=0.008, beta_max=0.999):
        # make sure beta values are in defined bounds
        assert 0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
        self.device = device
        self.schedule_type = schedule_type
        self.beta1 = beta1  # start beta
        self.beta2 = beta2  # end beta
        self.T = T  # total timesteps in diffusion process
        self.s = s  # smoothing constant for cosine schedule
        self.beta_max = beta_max  # capped beta value for cosine schedule
        self.schedule = self.compute_schedule()

    def compute_schedule(self):
        """Return the dictionary of ``betas``, ``alphas`` and ``alpha_bars``."""
        if self.schedule_type == 'linear':
            # linear schedule from Ho et al.'s DDPM paper
            betas = torch.linspace(self.beta1, self.beta2, self.T)
        elif self.schedule_type == 'cosine':
            # cosine schedule from Nichol et al.'s Improved DDPM paper:
            #   alpha_bar_t = f(t) / f(0),  beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
            # f is evaluated at t = 0..T (T + 1 points) so that the T betas
            # belong to steps 1..T; evaluating only t = 0..T-1 would make the
            # first beta exactly 0 and shift the whole schedule by one step.
            # float64 limits the cancellation error in 1 - ratio for early steps.
            steps = torch.arange(self.T + 1, dtype=torch.float64)
            f_t = torch.cos((steps / self.T + self.s) / (1 + self.s) * (torch.pi / 2)).pow(2)
            target_alpha_bars = f_t / f_t[0]
            betas = 1 - target_alpha_bars[1:] / target_alpha_bars[:-1]
            betas = torch.clip(betas, 0, self.beta_max).to(torch.float32)
        else:
            raise ValueError("Unknown schedule type: {}".format(self.schedule_type))  # error for unknown schedule type

        betas = betas.to(self.device)

        # alphas and alpha_bars as defined in the DDPM paper; cumprod yields
        # every running product in a single pass
        alphas = 1 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        # return dictionary of betas, alphas, and alpha_bars
        return {
            "betas": betas,
            "alphas": alphas,
            "alpha_bars": alpha_bars,
        }

In [None]:
# plot the alpha bars of the cosine and linear schedulers to check if they work correctly
def plot_alpha_bars(schedulers):
    """Plot the ``alpha_bars`` curve of every scheduler in one figure.

    Args:
        schedulers: iterable of objects exposing ``T``, ``schedule_type`` and
            a ``schedule`` dictionary with an ``'alpha_bars'`` tensor.
    """
    plt.figure(figsize=(10, 5))

    # one curve per scheduler, labelled with its capitalized schedule type
    for sched in schedulers:
        x_axis = torch.arange(0, sched.T, dtype=torch.float32).cpu().numpy()
        curve = sched.schedule['alpha_bars'].cpu().numpy()
        legend_name = f'{sched.schedule_type.capitalize()} Scheduler'
        plt.plot(x_axis, curve, label=legend_name)

    plt.xlabel('Diffusion Timestep')
    plt.ylabel('$\\alpha_{\\bar{t}}$')
    plt.title('Comparison of $\\alpha_{\\bar{t}}$ over Time for Different Schedulers')
    plt.grid(True)
    plt.legend()
    plt.show()

# initialize schedulers with values from relevant literature
# (beta1=0.0001, beta2=0.02, T=1000); both are created on `device`
scheduler_linear = VarianceScheduler(0.0001, 0.02, 1000, device, 'linear')
scheduler_cosine = VarianceScheduler(0.0001, 0.02, 1000, device, 'cosine')

# draw both alpha_bars curves in one figure for a visual sanity check
plot_alpha_bars([scheduler_linear, scheduler_cosine]) # plot

In [None]:
# DDPM class (contains our noising, and sampling functions)
class DDPM(nn.Module):
    """Denoising diffusion model wrapping a noise-prediction network.

    The noise schedule is computed by ``VarianceScheduler`` and registered as
    the buffers ``betas``, ``alphas`` and ``alpha_bars`` (so they are part of
    the state dict). Timesteps are 0-based indices into these buffers.

    Args:
        network: module called as ``network(x, t)`` that returns the predicted noise.
        T: number of diffusion steps.
        beta1: start beta passed to the scheduler.
        beta2: end beta passed to the scheduler.
        schedule_type: schedule type passed to the scheduler.
        device: device for the network, the schedule and sampled tensors.
    """

    def __init__(self, network, T=1000, beta1=10 ** -4, beta2=0.02, schedule_type='linear', device=None):
        super().__init__()
        self.device = device
        self.T = T  # timesteps T
        self.variance_scheduler = VarianceScheduler(beta1, beta2, T, device, schedule_type)  # initialize variance scheduler
        self.network = network.to(device)  # move the unet to the target device

        # use register_buffer to store each schedule component
        for k, v in self.variance_scheduler.schedule.items():
            self.register_buffer(k, v)

    # noising function (forward diffusion)
    def forward(self, x0, t, noise=None):
        """Noise ``x0`` directly to the timesteps ``t`` (closed-form forward process).

        Args:
            x0: clean images, indexed by sample along the first dimension.
            t: integer tensor with one timestep index per image.
            noise: optional noise of the same shape as ``x0``; standard normal
                noise is drawn when omitted.

        Returns:
            ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.
        """
        n = len(x0)
        a_bar = self.alpha_bars[t]

        # draw the noise directly with x0's shape, dtype and device
        if noise is None:
            noise = torch.randn_like(x0)

        # noising function
        return a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * noise

    # run each image through the unet for each timestep t in vector t.
    def backward(self, x, t):
        """Return the network's prediction of the noise contained in ``x`` at timesteps ``t``."""
        return self.network(x, t)

    # sampling function that returns n_samples images
    def sample_images(self, n_samples=16, image_shape=(1, 28, 28)):
        """Generate images by running the full reverse diffusion process.

        Args:
            n_samples: number of images to generate.
            image_shape: shape of a single image, without the batch dimension.

        Returns:
            Tensor of shape ``(n_samples, *image_shape)`` with the samples.
        """
        # remember the current mode so that sampling from a model that was
        # put into eval mode does not silently switch it back to training
        was_training = self.network.training
        self.network.eval()  # switch unet to eval mode (stops training behaviors)

        try:
            # gradients are not needed when sampling; disabling them saves memory and time
            with torch.no_grad():
                # start from pure noise
                x = torch.randn(n_samples, *image_shape, device=self.device)

                # walk backward over EVERY timestep index (T-1 -> 0); index 0 is
                # drawn during training as well, so it needs its reverse step too
                for i in tqdm(reversed(range(self.T)), total=self.T, position=0):
                    # tensor filled with the current timestep, one entry per sample
                    t = torch.full((n_samples,), i, dtype=torch.long, device=self.device)

                    # predict noise added
                    predicted_noise = self.backward(x, t)

                    # schedule values of the current timestep
                    alpha_t = self.alphas[i]
                    alpha_t_bar = self.alpha_bars[i]
                    beta_t = self.betas[i]

                    # fresh noise on every step except the last one (index 0),
                    # so the returned sample carries no extra randomness
                    noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)

                    # the clamp guards against a schedule whose alpha_bar is
                    # exactly 1 at some step, which would otherwise give 0 / 0
                    noise_coeff = (1 - alpha_t) / torch.sqrt(1 - alpha_t_bar).clamp_min(1e-8)

                    # update image tensor using reverse diffusion formula
                    x = 1 / torch.sqrt(alpha_t) * (x - noise_coeff * predicted_noise) + torch.sqrt(beta_t) * noise
        finally:
            # restore the mode even if sampling is interrupted
            self.network.train(was_training)

        # return tensor containing image samples
        return x

# U-Net
The code for our unconditional U-Net is heavily inspired by an existing GitHub implementation (cited in the dissertation report).

In [None]:
# U-Net helpers

def sinusoidal_embedding(n, dim):
    """Return the standard sinusoidal positional embedding table.

    Row ``p`` holds ``sin(p * w_k)`` in the even columns and ``cos(p * w_k)``
    in the odd columns, with ``w_k = 10000 ** (-2k / dim)``.

    Args:
        n: number of positions (rows); positive integer.
        dim: embedding size (columns); positive integer, even or odd.

    Returns:
        Float tensor of shape ``(n, dim)``.

    Raises:
        ValueError: if ``n`` or ``dim`` is not a positive integer.
    """
    if not isinstance(n, int) or not isinstance(dim, int) or n < 1 or dim < 1:
        raise ValueError("both 'n' and 'dim' must be positive integers!")

    # vectorized calculation of the frequency terms, shape (1, ceil(dim / 2))
    wk = torch.pow(10_000, -torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    wk = wk.reshape(1, -1)
    # positions as a column vector, shape (n, 1)
    t = torch.arange(n, dtype=torch.float32).reshape(n, 1)
    angles = t * wk

    # create the embedding matrix
    embedding = torch.zeros(n, dim)
    embedding[:, 0::2] = torch.sin(angles)
    # for an odd dim there is one odd column fewer than even columns, so the
    # last frequency is dropped for the cosine part
    embedding[:, 1::2] = torch.cos(angles[:, :dim // 2])

    return embedding

# runs the sinusoidal time embedding through a MLP altering its dimensions to fit the respective unet layer
def make_time_embedding(time_emb_dim, out_features):
    """Build the small MLP that projects a time embedding to ``out_features`` values.

    Args:
        time_emb_dim: size of the incoming time embedding.
        out_features: size of the projected embedding.

    Returns:
        ``nn.Sequential`` of Linear -> SiLU -> Linear.
    """
    stages = [
        nn.Linear(time_emb_dim, out_features),
        nn.SiLU(),
        nn.Linear(out_features, out_features),
    ]
    return nn.Sequential(*stages)

# a small convolutional block that applies 2 convolutional operations with options for normalization and specifying activation
class ConvBlock(nn.Module):
    """Optional LayerNorm followed by two convolutions, each followed by an activation.

    Args:
        shape: normalized shape handed to ``nn.LayerNorm``.
        in_channels: channels of the input.
        out_channels: channels produced by both convolutions.
        kernel: kernel size of both convolutions.
        stride: stride of both convolutions.
        padding: padding of both convolutions.
        activation: activation module; ``nn.SiLU()`` when ``None``.
        normalize: whether the LayerNorm is applied in ``forward``.
    """

    def __init__(self, shape, in_channels, out_channels, kernel=3, stride=1, padding=1, activation=None, normalize=True):
        super().__init__()
        # the LayerNorm is always created, even if `normalize` is False
        self.layernorm = nn.LayerNorm(shape)
        self.convolution1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding)
        self.convolution2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding)
        self.activation = activation if activation is not None else nn.SiLU()
        self.normalize = normalize

    def forward(self, x):
        """Apply (norm ->) conv1 -> activation -> conv2 -> activation to ``x``."""
        if self.normalize:
            x = self.layernorm(x)
        for conv in (self.convolution1, self.convolution2):
            x = self.activation(conv(x))
        return x

In [None]:
class UNet(nn.Module):
    """U-Net that predicts the noise contained in a noisy image at a given timestep.

    The LayerNorm shapes below fix the expected input to 1 x 28 x 28 images.
    The timestep is turned into a frozen sinusoidal embedding, projected by a
    separate small MLP per stage, and added channel-wise to that stage's input.

    Args:
        T: number of timesteps (rows of the time-embedding table).
        time_emb_dim: size of the sinusoidal time embedding.
    """

    def __init__(self, T=1000, time_emb_dim=100):
        super(UNet, self).__init__()

        # initialize sinusoidal time embedding
        # the table is overwritten with the sinusoidal values and frozen (no gradients)
        self.time_embed = nn.Embedding(T, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(T, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # make time embeddings for every layer
        # (output size = number of input channels of the stage it is added to)
        self.time_embedding1 = make_time_embedding(time_emb_dim, 1)
        self.time_embedding2 = make_time_embedding(time_emb_dim, 10)
        self.time_embedding3 = make_time_embedding(time_emb_dim, 20)
        self.time_embedding_mid = make_time_embedding(time_emb_dim, 40)
        self.time_embedding4 = make_time_embedding(time_emb_dim, 80)
        self.time_embedding5 = make_time_embedding(time_emb_dim, 40)
        self.time_embedding_out = make_time_embedding(time_emb_dim, 20)

        # first "residual" stage, made up of 3 ConvBlocks applied in sequence
        # increases our feature depth from 1 to 10
        self.residual1 = nn.Sequential(
            ConvBlock((1, 28, 28), 1, 10),
            ConvBlock((10, 28, 28), 10, 10),
            ConvBlock((10, 28, 28), 10, 10)
        )
        
        # first downsampling, decreases image dimensions to (14x14)
        self.downsample1 = nn.Conv2d(10, 10, 4, 2, 1)

        # second stage, increases feature depth to 20
        self.residual2 = nn.Sequential(
            ConvBlock((10, 14, 14), 10, 20),
            ConvBlock((20, 14, 14), 20, 20),
            ConvBlock((20, 14, 14), 20, 20)
        )

        # second downsampling, decreases image dimensions to 7x7
        self.downsample2 = nn.Conv2d(20, 20, 4, 2, 1)

        # third stage, increases feature depth to 40
        self.residual3 = nn.Sequential(
            ConvBlock((20, 7, 7), 20, 40),
            ConvBlock((40, 7, 7), 40, 40),
            ConvBlock((40, 7, 7), 40, 40)
        )

        # third downsampling decreases spatial dimensions to 3x3
        # (7x7 -> 6x6 with the 2x2 convolution, then 6x6 -> 3x3)
        self.downsample3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1),
            nn.SiLU(),
            nn.Conv2d(40, 40, 4, 2, 1)
        )

        # bottleneck, decreases feature depth to 20 and then increases back to 40
        self.bottleneck = nn.Sequential(
            ConvBlock((40, 3, 3), 40, 20),
            ConvBlock((20, 3, 3), 20, 20),
            ConvBlock((20, 3, 3), 20, 40)
        )

        # first upsampling, increase the spatial dimensions to 7x7
        # (3x3 -> 6x6, then 6x6 -> 7x7 with the 2x2 transposed convolution)
        self.upsample1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1)
        )

        # fourth stage, decreases feature depth from 80 to 20
        # the feature depth is 80 as we concatenate the corresponding downsampling stage's output
        # to this stage's input via skip connection
        self.residual4 = nn.Sequential(
            ConvBlock((80, 7, 7), 80, 40),
            ConvBlock((40, 7, 7), 40, 20),
            ConvBlock((20, 7, 7), 20, 20)
        )

        # second upsampling, increase the spatial dimensions to 14x14
        self.upsample2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)

        # fifth stage, decreases feature depth from 40 to 10
        # the feature depth is 40 due to concatenation from the downsampling stage
        self.residual5 = nn.Sequential(
            ConvBlock((40, 14, 14), 40, 20),
            ConvBlock((20, 14, 14), 20, 10),
            ConvBlock((10, 14, 14), 10, 10)
        )

        # third upsampling, increase spatial dimensions to 28x28
        self.upsample3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
        # output stage, decreases feature depth from 20 (after concatenation) to 10;
        # its last ConvBlock skips the LayerNorm
        self.residual_out = nn.Sequential(
            ConvBlock((20, 28, 28), 20, 10),
            ConvBlock((10, 28, 28), 10, 10),
            ConvBlock((10, 28, 28), 10, 10, normalize=False)
        )

        # final convolution for (10,28,28) -> (1,28,28)
        self.final_convolution = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x, t):
        """Predict the noise in ``x`` for the timestep indices ``t``.

        Args:
            x: batch of noisy images; the layer shapes expect (n, 1, 28, 28).
            t: integer timestep indices into the embedding table, one per image.

        Returns:
            Tensor with one output channel at the input's spatial size.
        """
        time_embedding = self.time_embed(t) # look up the sinusoidal time embedding
        n = len(x) # number of images
        # note how we integrate the time embedding into the flow:
        # each projected embedding is reshaped to (n, channels, 1, 1) and broadcast-added
        out1 = self.residual1(x + self.time_embedding1(time_embedding).reshape(n, -1, 1, 1))  # (10, 28, 28)
        out2 = self.residual2(self.downsample1(out1) + self.time_embedding2(time_embedding).reshape(n, -1, 1, 1))  # (20, 14, 14)
        out3 = self.residual3(self.downsample2(out2) + self.time_embedding3(time_embedding).reshape(n, -1, 1, 1))  # (40, 7, 7)

        out_mid = self.bottleneck(self.downsample3(out3) + self.time_embedding_mid(time_embedding).reshape(n, -1, 1, 1))  # (40, 3, 3)

        # skip connection
        out4 = torch.cat((out3, self.upsample1(out_mid)), dim=1)  # (80, 7, 7)
        out4 = self.residual4(out4 + self.time_embedding4(time_embedding).reshape(n, -1, 1, 1))  # (20, 7, 7)

        # skip connection
        out5 = torch.cat((out2, self.upsample2(out4)), dim=1)  # (40, 14, 14)
        out5 = self.residual5(out5 + self.time_embedding5(time_embedding).reshape(n, -1, 1, 1))  # (10, 14, 14)

        # skip connection
        out = torch.cat((out1, self.upsample3(out5)), dim=1)  # (20, 28, 28)
        out = self.residual_out(out + self.time_embedding_out(time_embedding).reshape(n, -1, 1, 1))  # (10, 28, 28)

        out = self.final_convolution(out)  # (1, 28, 28)

        return out

In [None]:
# smoke test: run one random image and one random timestep through a freshly
# initialized U-Net (default T and embedding size, on the CPU) and print the output shape
test_unet = UNet()
sample_input = torch.randn(1, 1, 28, 28)  # example input tensor
time_step = torch.randint(0, 1000, (1,))  # random time step in [0, 1000)
output = test_unet(sample_input, time_step)
print("Output shape:", output.shape)  # check output shape
# display_images(output, labels=None, title="", num_samples=1, cols=1)

# Training Loop

In [None]:
def training_loop(ddpm, dataloader, n_epochs, optim, device, sample_images=True, store="ddpm.pt"):
    """Train ``ddpm`` to predict the noise added by its forward process.

    After every epoch the dataset-averaged loss is printed, the state dict is
    saved to ``store`` whenever that loss is a new minimum, and (optionally)
    10 generated images are displayed. A loss curve is plotted at the end.

    Args:
        ddpm: model exposing ``T``, ``backward``, ``sample_images`` and a
            forward call ``ddpm(x0, t, noise)``.
        dataloader: loader whose batches carry the images at index 0.
        n_epochs: number of passes over the data.
        optim: optimizer over the model parameters.
        device: device the batches are moved to.
        sample_images: whether to sample and display images after each epoch.
        store: path the best-loss state dict is written to.
    """
    criterion = nn.MSELoss()  # loss between true and predicted noise
    n_timesteps = ddpm.T
    lowest_loss = float("inf")  # best epoch loss seen so far
    history = []  # per-epoch losses for the final plot

    for epoch in tqdm(range(n_epochs), desc="Training progress"):
        epoch_number = epoch + 1
        running = 0.0

        for batch in tqdm(dataloader, leave=False, desc=f"Epoch {epoch_number}/{n_epochs}"):
            clean = batch[0].to(device)
            batch_len = len(clean)

            # one noise mask and one random timestep per image
            true_noise = torch.randn_like(clean).to(device)
            steps = torch.randint(0, n_timesteps, (batch_len,)).to(device)

            # forward process: noise the images to their timesteps
            noised = ddpm(clean, steps, true_noise)

            # the model's estimate of the noise that was added
            predicted = ddpm.backward(noised, steps.reshape(batch_len, -1))

            # optimize the MSE between the actual and the predicted noise
            loss = criterion(predicted, true_noise)
            optim.zero_grad()
            loss.backward()
            optim.step()

            # weight each batch by its share of the dataset
            running += loss.item() * batch_len / len(dataloader.dataset)

        history.append(running)

        # optionally show what the model generates after this epoch
        if sample_images:
            display_images(
                ddpm.sample_images(n_samples=10),
                labels=None,
                title=f"Images generated at epoch {epoch_number}",
                num_samples=10,
                cols=10,
            )

        message = f"Loss at epoch {epoch_number}: {running:.3f}"

        # keep the weights whenever a new lowest loss is reached
        if running < lowest_loss:
            lowest_loss = running
            torch.save(ddpm.state_dict(), store)
            message += " model stored"

        print(message)  # epoch loss and whether the model was stored

    # loss curve over all epochs
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, n_epochs + 1), history, marker='o')
    plt.title("Training Loss Per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.show()

In [None]:
# defining model
# values defined as originally suggested by Ho et al. in DDPM paper
T = 1000  # number of diffusion timesteps
beta1 = 10 ** -4  # first beta of the schedule
beta2 = 0.02  # last beta of the schedule
schedule_type = "linear"

# initialize our ddpm for training; the U-Net gets the same T so that its
# time-embedding table has one row per timestep
ddpm = DDPM(UNet(T), T=T, beta1=beta1, beta2=beta2, schedule_type=schedule_type, device=device)

In [None]:
# training loop
# the lowest-loss checkpoint is written to best_loss_ddpm.pt (via `store`);
# the weights after the final epoch (300 epochs) are written to ddpm.pt
store = "best_loss_ddpm.pt"
epochs = 300
lr = 0.001

training_loop(ddpm, train_loader, epochs, optim=Adam(ddpm.parameters(), lr), device=device, sample_images=True, store=store)
torch.save(ddpm.state_dict(), 'ddpm.pt')

# Sampling (Optionally Loading pretrained model weights)

In [None]:
# Optionally load a pre-trained model here.
# NOTE: the sampling call below uses the name `ddpm`; a model loaded through this
# template is bound to `model`, so rename one of the two for it to take effect.
# T = 1000
# beta1 = 0.0001
# beta2 = 0.02
# schedule_type = "linear"
# model = DDPM(UNet(T), T=T, beta1=beta1, beta2=beta2, schedule_type=schedule_type,  device=device)
# model.load_state_dict(torch.load(model_path))
# model.eval()

# draw 100 new samples and show them in a 10-column grid
newly_generated_images = ddpm.sample_images(n_samples=100)
display_images(newly_generated_images, labels=None, title="newly generated images", num_samples=100, cols=10)

# Quantitative Metrics Calculation (FID)

In [None]:
def save_images(images, folder="saved_images"):
    """Write every image of a tensor batch to ``folder`` as ``image_<i>.png``.

    Args:
        images: PyTorch tensor, iterated along its first dimension.
        folder: target directory; created if it does not exist.

    Raises:
        ValueError: if ``images`` is not a PyTorch tensor.
    """
    if not isinstance(images, torch.Tensor):
        raise ValueError("Images should be a PyTorch tensor")

    # detach from any graph and bring to host memory before converting to numpy
    cpu_images = images.detach().cpu()

    os.makedirs(folder, exist_ok=True)
    # NOTE(review): no explicit normalization is done here; any value scaling
    # is whatever plt.imsave applies with cmap='gray' - confirm this is intended
    for index, image in enumerate(cpu_images):
        pixels = image.squeeze().numpy()  # drop size-1 dimensions (e.g. the channel)
        target = os.path.join(folder, f'image_{index}.png')
        plt.imsave(target, pixels, cmap='gray')

def compute_fid(real_images_path, fake_images_path):
    """Compute the FID score between two folders of images.

    Args:
        real_images_path: folder holding the reference images.
        fake_images_path: folder holding the generated images.

    Returns:
        The ``'frechet_inception_distance'`` entry of the computed metrics.
    """
    # isc and kid are disabled as we only want the FID score;
    # samples_find_deep=True because images are searched recursively in the folders;
    # CUDA is only requested when it is available, matching the CPU fallback
    # used for `device` elsewhere in this file
    metrics = calculate_metrics(
        input1=real_images_path,
        input2=fake_images_path,
        cuda=torch.cuda.is_available(),
        isc=False,
        fid=True,
        kid=False,
        samples_find_deep=True,
    )
    return metrics['frechet_inception_distance']

In [None]:
# Optionally load a pre-trained model here.
# NOTE: the generation loop below uses the name `ddpm`; a model loaded through this
# template is bound to `model`, so rename one of the two for it to take effect.
# T = 1000
# beta1 = 0.0001
# beta2 = 0.02
# schedule_type = "linear"
# model = DDPM(UNet(T), T=T, beta1=beta1, beta2=beta2, schedule_type=schedule_type, device=device)
# model.load_state_dict(torch.load(model_path))
# model.eval()

# code to iteratively generate 50k images.
# loop 5 times generating 10k each time storing in separate subfolders
n_samples = 10000
n_loops = 5
folder_name = "generated_images"

# generate (n_loops * n_samples) images
for loop in range(n_loops):
    # generate n_samples images using your model
    generated_images = ddpm.sample_images(n_samples=n_samples)

    # save them in a loop-specific subfolder (set_1, set_2, ...) of folder_name
    loop_folder = os.path.join(folder_name, f'set_{loop + 1}')
    save_images(generated_images, loop_folder)

In [None]:
# load the first (n_loops * n_samples) FashionMNIST training images and save them to
# the folder real_images, so that there are as many real as generated images
# NOTE: this dataset uses plain ToTensor(), without the (x - 0.5) * 2 rescaling used for training
dataset = FashionMNIST(root="../data", train=True, transform=ToTensor(), download=True)
real_images = torch.stack([dataset[i][0] for i in range(n_loops * n_samples)]) # stack real images to create a single tensor
save_images(real_images, "real_images")

In [None]:
# compute and print the FID score of synthetic vs authentic images
# (folders "real_images" and "generated_images", as written by the cells above)
fid_score = compute_fid("real_images", "generated_images")
print(f"FID Score: {fid_score}")