# Denoising Diffusion Probabilistic Models (DDPM) on MNIST

In [1]:
# Python imports
import random
import math
import abc
import os

# PyTorch imports
import torch
import torchvision

# Third-party imports
import matplotlib.pyplot as plt
from tabulate import tabulate

# Own imports
import utils

## 0. Constants

In [2]:
# Hyper-parameters.
# NOTE(review): the meanings noted below are inferred from the names unless stated otherwise;
# confirm them against the model/training cells.
LEARNING_RATE_START = 5e-4  # presumably the initial learning rate of the optimizer
EPOCHS = 10  # presumably the number of passes over the training set
BATCH_SIZE = 64  # number of samples per batch; passed to the DataLoader as batch_size
GROUP_NORM_GROUPS = 32  # presumably the group count used by group-normalization layers

# Diffusion model parameters.
# NOTE(review): the name says "timestamps", but this looks like the number of diffusion
# timesteps — TODO confirm against the diffusion code.
TIMESTAMPS = 500

# Dataset parameters.
# According to the comments in the dataset cell, each image is a tensor of shape
# (CHANNELS, HEIGHT, WIDTH).
WIDTH = 28
HEIGHT = 28
CHANNELS = 1

# Others
# NOTE(review): looks like a base name for saved artifacts — confirm where it is used.
SAVED_FILENAME = 'MNIST'

### 0.1. Device

In [4]:
# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
# The MPS check is only evaluated when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

# Report the device the model will run on.
print(f"Use device [{DEVICE}].")

Use device [cuda].


## 1. Dataset

In [9]:
# Per-image preprocessing: convert the image to a tensor, then normalize with
# mean 0.5 and std 0.5, i.e. x -> (x - 0.5) / 0.5.
# NOTE(review): for 8-bit inputs ToTensor yields values in [0, 1], so the result
# should lie in [-1, 1] — confirm against what utils.MNIST feeds the transform.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # The resize step below is left disabled.
    # torchvision.transforms.Resize((HEIGHT, WIDTH), antialias=True),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Training split, loaded through the custom utils.MNIST wrapper, which loads and
# caches the dataset so that training does not read from the hard drive.
# Each item is a (torch.Tensor, int) tuple: the tensor is the image, of shape
# (CHANNELS, HEIGHT, WIDTH) and dtype torch.float32.
dataset = utils.MNIST('./data', train=True, transform=mnist_transform)

# Shuffled mini-batches over the dataset. For `samples = next(iter(dataloader))`:
#   - samples is a Python list of length 2;
#   - samples[0] is a tensor of shape (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH);
#   - samples[1] is a tensor of shape (BATCH_SIZE,).
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Inspect basic statistics of the dataset (currently just the number of samples).

In [10]:
# Number of samples in the dataset.
dataset_size = len(dataset)

# One [property, value] row per statistic, plus the matching column titles.
rows = [['Size', dataset_size]]
headers = ['Property', 'Value']

# Render the statistics as a GitHub-flavoured markdown table.
print(tabulate(rows, headers=headers, tablefmt='github'))

| Property   |   Value |
|------------|---------|
| Size       |   60000 |
