## Denoising Diffusion Probabilistic Models

In [1]:
import glob
import os

import imageio
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from data import DiffSet
from model import DiffusionModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Install dependencies if they are missing: %pip install pytorch_lightning imageio

### Set model parameters

In [3]:
# Training hyperparameters
# Passed to DiffusionModel as t_range (the number of diffusion timesteps).
diffusion_steps = 1000
# Dataset name handed to DiffSet; it also names the TensorBoard logger folder
# and appears in the checkpoint path used when resuming.
dataset_choice = "CIFAR"
#dataset_choice = "MNIST"
#dataset_choice = "FashionMNIST"
max_epoch = 10  # passed to the Trainer as max_epochs
batch_size = 128  # used for both the training and the validation DataLoader

# Loading parameters
# When load_model is True, the run stored under version_{load_version_num}
# is resumed from one of its saved checkpoints; otherwise training starts fresh.
load_model = False
load_version_num = 1

### Load dataset and train model

In [4]:
# Code for optionally loading model
# pass_version pins the logger to the run being resumed and last_checkpoint is
# the .ckpt file to restore; both stay None when training from scratch.
pass_version = None
last_checkpoint = None

if load_model:
    pass_version = load_version_num
    checkpoint_pattern = (
        f"./lightning_logs/{dataset_choice}/version_{load_version_num}/checkpoints/*.ckpt"
    )
    checkpoint_paths = glob.glob(checkpoint_pattern)
    if not checkpoint_paths:
        # Fail with an actionable message instead of an IndexError on an empty list.
        raise FileNotFoundError(
            f"No checkpoint found matching {checkpoint_pattern!r}"
        )
    # glob returns matches in arbitrary order, so select the most recently
    # written checkpoint explicitly rather than relying on list position.
    last_checkpoint = max(checkpoint_paths, key=os.path.getmtime)

In [5]:
# Create datasets and data loaders
# NOTE(review): the first DiffSet argument presumably selects the training
# split (True) versus the held-out split (False) -- confirm in data.DiffSet.
train_dataset = DiffSet(True, dataset_choice)
val_dataset = DiffSet(False, dataset_choice)

train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4, shuffle=True)
# Validation batches are not shuffled: shuffling adds nothing to evaluation and
# Lightning warns when a validation dataloader has shuffling enabled.
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4, shuffle=False)

# Create model and trainer
# The model's input size is the number of pixels per channel (size * size).
in_size = train_dataset.size * train_dataset.size
if load_model:
    # Restore weights from the checkpoint chosen in the optional-loading cell.
    model = DiffusionModel.load_from_checkpoint(
        last_checkpoint,
        in_size=in_size,
        t_range=diffusion_steps,
        img_depth=train_dataset.depth,
    )
else:
    model = DiffusionModel(in_size, diffusion_steps, train_dataset.depth)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Load Trainer model
# Logs are written under lightning_logs/<dataset_choice>; pass_version is None
# for a fresh run and the resumed version number otherwise.
tb_logger = pl.loggers.TensorBoardLogger("lightning_logs/", name=dataset_choice, version=pass_version)

# NOTE(review): the recorded output of this cell contains a deprecation warning;
# presumably one of gpus / auto_select_gpus / resume_from_checkpoint triggers it
# -- confirm against the installed pytorch_lightning version before upgrading.
trainer_kwargs = dict(
    max_epochs=max_epoch,
    log_every_n_steps=10,
    gpus=1,
    auto_select_gpus=True,
    resume_from_checkpoint=last_checkpoint,
    logger=tb_logger,
)
trainer = pl.Trainer(**trainer_kwargs)

  rank_zero_deprecation(
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train model
# Fit on the training loader and evaluate on the validation loader.
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

Missing logger folder: lightning_logs/CIFAR
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name  | Type       | Params
--------------------------------------
0  | inc   | DoubleConv | 38.8 K
1  | down1 | Down       | 295 K 
2  | down2 | Down       | 1.2 M 
3  | down3 | Down       | 2.4 M 
4  | up1   | Up         | 6.2 M 
5  | up2   | Up         | 1.5 M 
6  | up3   | Up         | 406 K 
7  | outc  | OutConv    | 195   
8  | sa1   | SAWrapper  | 395 K 
9  | sa2   | SAWrapper  | 395 K 
10 | sa3   | SAWrapper  | 99.6 K
--------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.681    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Epoch 0:  83%|████████▎ | 391/470 [01:54<00:23,  3.43it/s, loss=0.0732, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/79 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/79 [00:00<?, ?it/s][A
Epoch 0:  83%|████████▎ | 392/470 [01:54<00:22,  3.43it/s, loss=0.0732, v_num=0]
Epoch 0:  84%|████████▎ | 393/470 [01:54<00:22,  3.43it/s, loss=0.0732, v_num=0]
Epoch 0:  84%|████████▍ | 394/470 [01:54<00:22,  3.44it/s, loss=0.0732, v_num=0]
Epoch 0:  84%|████████▍ | 395/470 [01:54<00:21,  3.44it/s, loss=0.0732, v_num=0]
Epoch 0:  84%|████████▍ | 396/470 [01:54<00:21,  3.45it/s, loss=0.0732, v_num=0]
Epoch 0:  84%|████████▍ | 397/470 [01:54<00:21,  3.46it/s, loss=0.0732, v_num=0]
Epoch 0:  85%|████████▍ | 398/470 [01:54<00:20,  3.46it/s, loss=0.0732, v_num=0]
Epoch 0:  85%|████████▍ | 399/470 [01:55<00:20,  3.47it/s, loss=0.0732, v_num=0]
Epoch 0:  85%|████████▌ | 400/470 [01:55<00:20,  3.47it/s, loss=0.0732, v_num=0]
Epoch 0:  85%|████████▌ | 401/470 [

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 470/470 [02:10<00:00,  3.60it/s, loss=0.0409, v_num=0]


### Sample from model

In [9]:
# Layout of the animated grid (rows, cols); one sample is generated per grid cell.
gif_shape = [3, 3]
sample_batch_size = gif_shape[0] * gif_shape[1]
n_hold_final = 10  # extra copies of the final frame appended at the end
frame_interval = 50  # a frame is kept whenever the timestep is a multiple of this

# Generate samples from denoising process
gen_samples = []
# Start from Gaussian noise of shape (batch, depth, size, size).
x = torch.randn((sample_batch_size, train_dataset.depth, train_dataset.size, train_dataset.size))
# Timesteps are visited from t_range - 1 down to 1.
sample_steps = torch.arange(model.t_range - 1, 0, -1)
# Sampling is inference only. Without no_grad, autograd may record every
# denoising step and keep the whole chain of activations alive in memory.
with torch.no_grad():
    for t in sample_steps:
        x = model.denoise_sample(x, t)
        if t % frame_interval == 0:
            gen_samples.append(x)
# Repeat the final result so the animation ends with n_hold_final identical frames.
gen_samples.extend([x] * n_hold_final)
# Stack frames on a new leading axis and move the channel axis last; assuming
# denoise_sample preserves x's shape this gives (frames, batch, size, size, depth).
# The squeeze drops the channel axis when it has size 1.
gen_samples = torch.stack(gen_samples, dim=0).moveaxis(2, 4).squeeze(-1)
# Clamp to [-1, 1] and rescale to [0, 1].
gen_samples = (gen_samples.clamp(-1, 1) + 1) / 2

In [10]:
# Process samples and save as gif
# Scale the [0, 1] floats to 0..255 and cast to 8-bit integers.
gen_samples = (255 * gen_samples).to(dtype=torch.uint8)
# Split the batch axis into the (rows, cols) grid given by gif_shape:
# (frames, rows, cols, size, size, depth).
grid_rows, grid_cols = gif_shape
gen_samples = gen_samples.reshape(
    -1,
    grid_rows,
    grid_cols,
    train_dataset.size,
    train_dataset.size,
    train_dataset.depth,
)

def stack_samples(gen_samples, stack_dim):
    """Fold axis 1 of ``gen_samples`` into another axis.

    The tensor is cut into size-1 pieces along axis 1, that axis is squeezed
    out of each piece, and the pieces are concatenated in order along
    ``stack_dim`` (an axis index of the squeezed pieces).

    Returns the concatenated tensor, which has one axis fewer than the input.
    """
    pieces = [chunk.squeeze(1) for chunk in torch.split(gen_samples, 1, dim=1)]
    return torch.cat(pieces, dim=stack_dim)

# Fold the two grid axes into the image axes, one axis per pass, so that each
# frame becomes a single tiled image.
for _ in range(2):
    gen_samples = stack_samples(gen_samples, 2)

# Write one GIF frame per entry along the first axis, next to the run's logs.
gif_path = f"{trainer.logger.log_dir}/pred.gif"
gif_frames = list(gen_samples)
imageio.mimsave(gif_path, gif_frames, fps=5)