In [1]:
# adapted from https://github.com/Michedev/DDPMs-Pytorch/blob/master/train.py

In [2]:
# Automatically re-import edited project modules (e.g. ddpm_tutorial) before each cell runs,
# so changes to the package source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [57]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from hydra import compose, initialize
from IPython.display import HTML
from matplotlib import animation
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import Dataset, DataLoader

from ddpm_tutorial.ddpm.callbacks.ema import EMA
from ddpm_tutorial.ddpm.callbacks.logger import LoggerCallback
from ddpm_tutorial.ddpm.callbacks.logger import LoggerCallback

In [None]:
# Compose the "train" configuration from the project's Hydra config directory.
# version_base=None selects the defaults of the installed Hydra version.
CONFIG_DIR = "../../src/ddpm_tutorial/ddpm/config"
with initialize(config_path=CONFIG_DIR, version_base=None):
    config = compose(config_name="train")

# Seed every RNG Lightning knows about so training runs are reproducible.
pl.seed_everything(config.seed)

Seed set to 1337


Path('/home/doughet/diffusion/ddpm-tutorial/docs/notebooks/model')

In [None]:
# Build each component from its Hydra config node. The model receives the
# already-built variance scheduler and optimizer as keyword arguments.
scheduler = hydra.utils.instantiate(config.scheduler)
opt = hydra.utils.instantiate(config.optimizer)
model: pl.LightningModule = hydra.utils.instantiate(
    config.model,
    variance_scheduler=scheduler,
    opt=opt,
)

# Training / validation splits, each described by its own config node.
train_dataset: Dataset = hydra.utils.instantiate(config.dataset.train)
val_dataset: Dataset = hydra.utils.instantiate(config.dataset.val)

In [None]:
# Build the dataloaders, training callbacks and the Lightning Trainer from `config`.
# Optional config keys (`num_workers`, `max_epochs`) fall back to the notebook's
# previous hard-coded value / Lightning's own default when they are not set.
num_workers = config.num_workers if "num_workers" in config else 4
accelerator_name = str(config.accelerator)
# Pinned host memory only pays off when batches are copied to a CUDA device.
pin_memory = accelerator_name in ("gpu", "cuda") or (
    accelerator_name == "auto" and torch.cuda.is_available()
)
loader_kwargs = dict(
    batch_size=config.batch_size,
    pin_memory=pin_memory,
    num_workers=num_workers,
    # DataLoader rejects persistent_workers=True when there are no worker processes.
    persistent_workers=num_workers > 0,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)

# Save weights to disk during training; files are named after epoch and validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath="./",
    filename="epoch={epoch}-valid_loss={val/loss_epoch}",
    monitor="val/loss_epoch",
    auto_insert_metric_name=False,  # the filename template already spells out the names
    save_last=True,
)
ddpm_logger = LoggerCallback(config.freq_logging, config.freq_logging_norm_grad, config.batch_size_gen_images)
callbacks = [checkpoint_callback, ddpm_logger]
# Optional callbacks, toggled from the config.
if config.ema:
    callbacks.append(EMA(config.ema_decay))
if config.early_stop:
    callbacks.append(EarlyStopping("val/loss_epoch", min_delta=config.min_delta, patience=config.patience))

trainer = pl.Trainer(
    callbacks=callbacks,
    accelerator=config.accelerator,
    devices=config.devices if config.devices is not None else "auto",
    # None keeps Lightning's default (1000 epochs) when the config does not set it.
    max_epochs=config.max_epochs if "max_epochs" in config else None,
    gradient_clip_val=config.gradient_clip_val,
    gradient_clip_algorithm=config.gradient_clip_algorithm,
)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.


/usersoftware/peerd/doughet/.conda/envs/ddpm-tutorial/lib/python3.12/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /usersoftware/peerd/doughet/.conda/envs/ddpm-tutoria ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [31]:
# Train the model; the callbacks configured above handle checkpointing, logging, EMA and early stopping.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

/usersoftware/peerd/doughet/.conda/envs/ddpm-tutorial/lib/python3.12/site-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/usersoftware/peerd/doughet/.conda/envs/ddpm-tutorial/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /home/doughet/diffusion/ddpm-tutorial/docs/notebooks exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name            | Type         | Params | Mode 
---------------------------------------------------------
0 | denoiser_

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usersoftware/peerd/doughet/.conda/envs/ddpm-tutorial/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/usersoftware/peerd/doughet/.conda/envs/ddpm-tutorial/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

# Sampling from the trained model

In [None]:
# Sample on the GPU when one is available; fall back to the CPU so this cell
# also runs on machines without CUDA.
sampling_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(sampling_device)
model.device  # display where the model now lives

device(type='cuda', index=0)

In [None]:
# Generate with T=1000 steps, keeping the intermediate steps so the denoising
# trajectory can be animated. Each element is indexed as X_noise[i][sample, channel]
# further down (presumably (batch, channel, H, W) tensors — TODO confirm).
X_noise = model.generate(get_intermediate_steps=True, T=1000)

In [130]:
# Animate the denoising trajectory of the first generated sample and save it as a GIF.
interval = 5  # delay between animation frames in ms (also sets the saved GIF's frame rate)
every_nth_image = 2  # render only every n-th step of the trajectory
hold_fraction = 0.2  # extra frames, as a fraction of the trajectory, that hold the final image
n_steps = len(X_noise)
assert n_steps > 0, "generate() returned no steps to animate"


def frame_image(i):
    """Return the image displayed for animation frame ``i``.

    Takes the first sample and first channel of trajectory step ``i``
    (``X_noise[i][0, 0]``). Indices past the last step are clamped to it, so the
    trailing frames hold the finished image. Intermediate steps are mapped from
    [-1, 1] to [0, 1] and clipped; the last step is shown unchanged (presumably
    ``generate`` already returns it in [0, 1] — TODO confirm).
    """
    i = min(i, n_steps - 1)
    arr = X_noise[i][0, 0].detach().cpu()
    if i < n_steps - 1:
        arr = (arr / 2 + 0.5).clamp(0, 1)
    return arr


def update(i):
    """FuncAnimation callback: draw frame ``i`` and return the changed artists for blitting."""
    im.set_array(frame_image(i))
    return [im]


f = plt.figure()
plt.axis("off")
# Same normalisation as every later frame, so the first frame is not drawn raw.
im = plt.imshow(frame_image(0), cmap="gray", vmin=0, vmax=1)

anim = animation.FuncAnimation(
    f,
    update,
    frames=np.arange(0, int(n_steps * (1 + hold_fraction)), every_nth_image),
    interval=interval,
    repeat=True,
    blit=True,
)
# NOTE(review): GIF stores frame delays in 1/100 s and many viewers slow down very
# short delays, so playback may be slower than `interval` suggests.
anim.save("generation.gif", writer="pillow")
plt.close(f)  # the GIF below is the output; don't also render the live figure

HTML('<img src="generation.gif" width="400">')

<IPython.core.display.Javascript object>