## Denoising Diffusion Probabilistic Models

In [17]:
!gpustat

'gpustat' is not recognized as an internal or external command,
operable program or batch file.


In [18]:
# Automatically re-import edited project modules (e.g. the local `models`,
# `diffusion`, `datasets` packages imported below) before executing code, so edits
# take effect without restarting the kernel. Re-running this cell only prints an
# "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
import os
# Expose only GPU 0 to CUDA. This is set before the imports that follow so it is in
# place before CUDA gets initialized.
# NOTE(review): the Trainer further down uses accelerator='cpu', so this pin does
# not currently affect training.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
from common_utils.notebook_utils import *

In [14]:
import glob

import imageio
import pytorch_lightning as pl
import torch
import torchvision
from torch.utils.data import DataLoader

from models.unet import Unet
from diffusion.diffusion import GaussianDiffusion
from datasets.torchset import TorchSet

### Set training hyperparameters

In [15]:
# Training hyperparameters
diffusion_timesteps: int = 1000  # number of diffusion timesteps given to GaussianDiffusion
training_steps: int = 10         # Trainer max_steps; very small, presumably a smoke-test value
batch_size: int = 128            # images per training batch

# TODO: when moving to single-image training, re-enable:
# image_name = 'balloons.png'
# tb_logger = pl.loggers.TensorBoardLogger("lightning_logs/", name=image_name)

### Load dataset and train model

In [8]:
# Create datasets and data loaders.
# CIFAR is downloaded into ./data on first use. An alternative dataset root that was
# used before: '/net/mraid11/export/vision/datasets/'.
train_dataset = TorchSet(train=True, root='./data', dataset='CIFAR')
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Denoising network, wrapped by the Gaussian diffusion process that `trainer.fit`
# trains further down.
model = Unet(dim=128, dim_mults=(1, 2, 2, 2))
diffusion = GaussianDiffusion(
    model,
    channels=3,
    timesteps=diffusion_timesteps,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data\cifar-10-python.tar.gz to ./data


In [20]:
# Create the TensorBoard logger and the Lightning Trainer.
tb_logger = pl.loggers.TensorBoardLogger(
    "lightning_logs/",
    name='CIFAR',
)

# Training runs on the CPU (the saved output reports no GPU available). To train on
# the GPU exposed via CUDA_VISIBLE_DEVICES instead, pass accelerator='gpu', devices=1.
# `auto_select_gpus` is intentionally not passed: it only matters when GPUs are
# requested, and it is deprecated in Lightning 1.9 and removed in 2.0.
trainer = pl.Trainer(
    max_steps=training_steps,
    log_every_n_steps=10,
    accelerator='cpu',
    logger=tb_logger,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train the diffusion wrapper (not the bare Unet) on the training loader; the
# Trainer stops after `training_steps` steps (its max_steps).
trainer.fit(model=diffusion, train_dataloaders=train_loader)

Missing logger folder: lightning_logs/CIFAR

  | Name  | Type | Params
-------------------------------
0 | model | Unet | 40.5 M
-------------------------------
40.5 M    Trainable params
0         Non-trainable params
40.5 M    Total params
162.186   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

### Sample from model

In [None]:
# GIF layout: each frame shows a rows x cols grid of samples.
gif_shape = [3, 3]
sample_batch_size = gif_shape[0] * gif_shape[1]
n_hold_final = 10    # extra copies of the last frame so the GIF pauses at the end
snapshot_every = 50  # keep one frame every this many denoising steps

# NOTE(review): `model` is the bare Unet built above, while training ran on
# `diffusion` (the GaussianDiffusion wrapper). Nothing in this notebook defines
# `t_range` or `denoise_sample`; confirm they exist on the object used here, or
# switch to the sampling API of `diffusion`.
model.eval()  # inference mode for any dropout / normalization layers

# Generate samples with the reverse process, starting from Gaussian noise.
# Gradients are disabled: the parameters are trainable, so without no_grad each
# step would keep the autograd graph of every earlier step alive (unless
# denoise_sample already disables gradients internally).
gen_samples = []
with torch.no_grad():
    x = torch.randn((sample_batch_size, train_dataset.depth, train_dataset.size, train_dataset.size))
    sample_steps = torch.arange(model.t_range - 1, 0, -1)
    for t in sample_steps:
        x = model.denoise_sample(x, t)
        if t % snapshot_every == 0:
            gen_samples.append(x)
gen_samples.extend([x] * n_hold_final)

# Assuming denoise_sample preserves x's shape:
# (frames, batch, D, H, W) -> (frames, batch, H, W, D); D is dropped when it is 1.
gen_samples = torch.stack(gen_samples, dim=0).moveaxis(2, 4).squeeze(-1)
# Map values from [-1, 1] to [0, 1].
gen_samples = (gen_samples.clamp(-1, 1) + 1) / 2

In [None]:
# Process samples and save as gif
def stack_samples(gen_samples, stack_dim):
    """Fold dimension 1 of ``gen_samples`` into ``stack_dim``.

    The tensor is split into its slices along dim 1, and the slices are
    concatenated along ``stack_dim`` (counted after dim 1 is removed). For example,
    ``(T, A, B, H, W, D)`` with ``stack_dim=2`` becomes ``(T, B, A*H, W, D)``.

    Args:
        gen_samples: tensor with at least two dimensions.
        stack_dim: dimension of the reduced tensor to concatenate along.

    Returns:
        A tensor with one fewer dimension than ``gen_samples``.
    """
    return torch.cat(gen_samples.unbind(dim=1), dim=stack_dim)


# Convert [0, 1] floats to uint8 pixels, rounding to the nearest level rather than
# truncating. Results go into new names so that re-running this cell does not
# re-scale or re-tile `gen_samples`.
rows, cols = gif_shape
img_size, img_depth = train_dataset.size, train_dataset.depth
frames_uint8 = (gen_samples * 255).round().to(torch.uint8)
# Presumably (frames, rows*cols, H, W, D) -> (frames, rows, cols, H, W, D)
frames_grid = frames_uint8.reshape(-1, rows, cols, img_size, img_size, img_depth)
# Tile grid rows vertically, then columns horizontally -> (frames, rows*H, cols*W, D)
frames_tiled = stack_samples(stack_samples(frames_grid, 2), 2)

log_dir = trainer.logger.log_dir
os.makedirs(log_dir, exist_ok=True)
imageio.mimsave(
    f"{log_dir}/pred.gif",
    list(frames_tiled.cpu().numpy()),
    fps=5,
)