This tutorial accompanies the following blog post:
[Conditional model](https://website.vincent-roger.fr/blog/deeplearning/python/2024/06/16/diffusers-conditional_model.html).

Read it for more explanation and context.

In [1]:
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import DDPMPipeline, DDPMScheduler
from diffusers.models import UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from torch import GradScaler, autocast, nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
from os.path import exists

from itertools import chain

from dlai_lib.diffusion_utilities import CustomDataset, transform
from diffusers_tutorials.tools.plotly import plot_generated_images

  from .autonotebook import tqdm as notebook_tqdm


# Hyperparameters

In [2]:
# Number of samples per mini-batch handed to the `DataLoader`.
batch_size = 128
# Number of full passes over the dataloader performed by `train`.
num_epochs = 32
# Location of the saved UNet; passed to `save_pretrained` / `from_pretrained`.
pre_trained_unet = "weights/conditional_unet.pth"
# Location of the embedding network's `state_dict`, written with `torch.save`.
pre_trained_emb_net = "weights/emb_net.pth"

# Dataset

In [3]:
# Sprite images and their labels, read from the two `.npy` files below and
# passed through `transform`.
# NOTE(review): `null_context=False` presumably keeps the real labels instead
# of blanking them — confirm against `CustomDataset`.
dataset = CustomDataset(
    "./dlai_lib/sprites_1788_16x16.npy",
    "./dlai_lib/sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)

# Shuffled mini-batches of `batch_size` samples. `pin_memory=True` pairs with
# the `non_blocking=True` host-to-device copies done in the training loop.
dataloader = DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


# Model definition

In [4]:
class UnsqueezeLayer(nn.Module):
    """Module wrapper around `unsqueeze`, usable as a step in a sequential model.

    Args:
        dim: Index at which the new axis of size one is inserted.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` with an extra axis of size one inserted at `self.dim`."""
        return x.unsqueeze(self.dim)


# This custom subclass is needed so that our sequential model can be used inside our pipeline.
class CustomSequential(nn.Sequential):
    """`nn.Sequential` that additionally exposes `device` and `dtype` properties.

    Both values are read from the first parameter of the model, on the
    assumption that every parameter lives on the same device and shares the
    same dtype.
    """

    def _first_parameter(self) -> nn.Parameter:
        """Return the first parameter yielded by `self.parameters()`."""
        return next(self.parameters())

    @property
    def device(self):
        """Device of the model's first parameter."""
        return self._first_parameter().device

    @property
    def dtype(self):
        """Dtype of the model's first parameter."""
        return self._first_parameter().dtype

In [5]:
# Length of each label vector fed to the embedding network.
num_classes = 5
# Size of the label embedding produced by `emb_net` and expected by the UNet.
class_emb_size = 64

# Small MLP mapping a label vector of length `num_classes` to an embedding of
# size `class_emb_size`; the final layer inserts a new axis at position 1.
# NOTE(review): for a (batch, num_classes) input this presumably yields
# (batch, 1, class_emb_size), i.e. a length-1 sequence for cross-attention —
# confirm.
emb_net = CustomSequential(
    nn.Linear(num_classes, class_emb_size),
    nn.GELU(),
    nn.Linear(class_emb_size, class_emb_size),
    UnsqueezeLayer(dim=1),
)

# Define the UNet model
# `encoder_hid_dim` is set to the embedding size produced by `emb_net`.
unet = UNet2DConditionModel(
    encoder_hid_dim=class_emb_size,
    sample_size=16,  # Input image size
    in_channels=3,  # Number of input channels (e.g., 3 for RGB)
    out_channels=3,  # Number of output channels
    layers_per_block=2,  # Layers per block in the UNet
    block_out_channels=(64, 128),  # Channels in each block
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),  # Types of down blocks
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),  # Types of up blocks
)

In [6]:
# Define the DDPM scheduler, using 500 diffusion steps for training
noise_scheduler = DDPMScheduler(num_train_timesteps=500)

In [7]:
# Move both networks to the GPU when one is available, otherwise keep them on the CPU.
unet = unet.to("cuda" if torch.cuda.is_available() else "cpu")
# The moved embedding network is bound to a second name, `class_net`, which is
# the name handed to the pipeline later on.
# NOTE(review): `nn.Module.to` is documented to return the module itself, so
# `class_net` and `emb_net` should refer to the same object — confirm.
class_net = emb_net.to("cuda" if torch.cuda.is_available() else "cpu")

# Training

In [8]:
def train(
    unet: UNet2DConditionModel,
    emb_net: nn.Module,
    noise_scheduler: DDPMScheduler,
    dataloader: DataLoader,
    num_epochs: int,
    lr: float,
) -> np.ndarray:
    """Jointly train the UNet and the label-embedding network to predict noise.

    For every batch, random noise is added to the images at randomly drawn
    timesteps, the labels are embedded with `emb_net`, and the UNet is asked to
    predict the added noise. The loss is the mean squared error between the
    predicted and the actual noise.

    Args:
        unet: Noise-prediction model; its `device` decides where batches are moved.
        emb_net: Network turning the label vectors into the conditioning
            embeddings passed to the UNet.
        noise_scheduler: Scheduler used to add noise (`add_noise`) and providing
            the number of training timesteps.
        dataloader: Yields `(images, labels)` batches.
        num_epochs: Number of passes over `dataloader`.
        lr: Learning rate of the Adam optimizer shared by both networks.

    Returns:
        Array of length `num_epochs` holding the mean batch loss of each epoch.
    """
    # Loop invariants, resolved once instead of on every batch.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    device = unet.device
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    losses = np.zeros(num_epochs)

    # A single optimizer updates the parameters of both networks.
    optimizer = Adam(chain(unet.parameters(), emb_net.parameters()), lr=lr)
    scaler = GradScaler(device_type)  # For mixed precision

    # Both networks are optimised, so both are switched to training mode.
    unet.train()
    emb_net.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0

        for images, labels in tqdm(dataloader):
            optimizer.zero_grad()

            images = images.to(device, non_blocking=True)
            labels = labels.to(dtype=torch.float32, device=device, non_blocking=True)

            with autocast(device_type):  # Mixed precision
                # Generate random noise
                noise = torch.randn(images.shape, device=device)

                # One random timestep per sample of the batch
                timesteps = torch.randint(
                    0,
                    num_train_timesteps,
                    (images.shape[0],),
                    device=device,
                )

                noisy_images = noise_scheduler.add_noise(images, noise, timesteps)

                # Compute the class embeddings
                enc_labels = emb_net(labels)

                # Forward pass through the model with the label embeddings
                predicted_noise = unet(
                    noisy_images, timesteps, enc_labels, class_labels=labels
                ).sample

                # Mean squared error between actual and predicted noise
                loss = torch.nn.functional.mse_loss(predicted_noise, noise)

            # Backward pass and optimization through the gradient scaler
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        epoch_loss = epoch_loss / len(dataloader)
        losses[epoch] = epoch_loss
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

    return losses

In [9]:
# Retrain when EITHER checkpoint is missing: the two networks are trained
# jointly, and loading would fail on whichever file is absent.
if not (exists(pre_trained_unet) and exists(pre_trained_emb_net)):
    train(unet, emb_net, noise_scheduler, dataloader, num_epochs, lr=1e-4)
    unet.save_pretrained(pre_trained_unet)
    torch.save(emb_net.state_dict(), pre_trained_emb_net)
else:
    # Reload onto the device `emb_net` already lives on, so the freshly loaded
    # UNet keeps the placement chosen earlier and a checkpoint saved on a GPU
    # can still be read on a CPU-only machine.
    unet = UNet2DConditionModel.from_pretrained(pre_trained_unet).to(emb_net.device)
    emb_net.load_state_dict(
        torch.load(pre_trained_emb_net, map_location=emb_net.device, weights_only=True)
    )

# Only inference follows, so leave both networks in evaluation mode whichever
# branch was taken.
unet.eval()
emb_net.eval()

# Load the model and try it

In [10]:
class ConditionalDDPMPipeline(DDPMPipeline):
    """DDPM pipeline whose denoising is conditioned on embedded class labels.

    On top of the UNet and scheduler handled by `DDPMPipeline`, it registers a
    `class_net` module that turns the requested labels into the embeddings
    passed to the UNet at every denoising step.
    """

    def __init__(
        self, unet: UNet2DConditionModel, class_net: CustomSequential, scheduler: DDPMScheduler
    ) -> None:
        super().__init__(unet=unet, scheduler=scheduler)
        self.class_net = class_net
        self.register_modules(class_net=class_net)

    @torch.no_grad()
    def __call__(
        self,
        class_label: list[list[float]],
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            class_label (list[list[float]]):
                list of one-hot examples. len(class_label) represents the number of examples to generate.
                The values are converted to the dtype of `class_net`, so integer one-hot vectors are accepted too.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
                NOTE(review): `DDPMScheduler.set_timesteps` is documented to reject a value larger than the
                scheduler's `num_train_timesteps`; with a scheduler trained on fewer than 1000 steps the default
                presumably has to be overridden — confirm.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images

        Raises:
            ValueError: If `class_label` is empty.
        """
        batch_size = len(class_label)
        if batch_size == 0:
            # Fail early with a clear message instead of a shape error inside `class_net`.
            raise ValueError("`class_label` must contain at least one label vector.")

        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        # The noise is created in the UNet's dtype so that the pipeline also
        # works after being cast to a reduced precision.
        if self.device.type == "mps":
            # randn does not work reproducibly on mps
            image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
            image = image.to(self.device)
        else:
            image = randn_tensor(
                image_shape, generator=generator, device=self.device, dtype=self.unet.dtype
            )

        # Build the labels in the embedding network's dtype: a bare
        # `torch.tensor` would infer an integer dtype from integer one-hot
        # vectors, which the linear layers of `class_net` cannot consume.
        labels = torch.tensor(class_label, dtype=self.class_net.dtype, device=self.device)
        enc_labels = self.class_net(labels)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t, enc_labels, class_labels=labels, return_dict=False)[
                0
            ]

            # 2. compute previous image: x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

        # Map the samples from [-1, 1] to [0, 1] and switch to channels-last numpy arrays.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

In [11]:
# Create the DDPM pipeline
# It bundles the UNet, the label-embedding network (`class_net`) and the noise
# scheduler; the pipeline is then moved to the GPU when one is available,
# otherwise to the CPU.
pipeline = ConditionalDDPMPipeline(unet=unet, class_net=class_net, scheduler=noise_scheduler)
pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

ConditionalDDPMPipeline {
  "_class_name": "ConditionalDDPMPipeline",
  "_diffusers_version": "0.27.2",
  "class_net": [
    "__main__",
    "CustomSequential"
  ],
  "scheduler": [
    "diffusers",
    "DDPMScheduler"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ]
}

In [12]:
# Request 4 samples for each of the 5 classes: every row is a one-hot vector of
# length 5, and rows of the same class are grouped together (class 0 first).
generated_image = pipeline(
    [
        [1.0 if position == class_index else 0.0 for position in range(5)]
        for class_index in range(5)
        for _ in range(4)
    ],
    num_inference_steps=100,
)

  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 100/100 [00:01<00:00, 85.50it/s] 


In [13]:
# Lay the generated images out in a figure and display it.
# NOTE(review): the arguments 5 and 4 presumably describe the grid (5 classes
# with 4 samples each, matching the 20 label vectors passed to the pipeline) —
# confirm against `plot_generated_images`.
fig = plot_generated_images(generated_image.images, 5, 4)
fig.show()