This tutorial accompanies the following blog post:
[Improve the conditional model](https://website.vincent-roger.fr/blog/deeplearning/python/2024/09/08/diffusers-obtain-better-results.html).

Read it for more explanation and context.

In [1]:
from itertools import chain
from math import ceil
from os.path import exists
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import DDIMPipeline, DDIMScheduler
from diffusers.models import UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from torch import GradScaler, autocast, nn
from torch._C import device, dtype
from torch.optim import AdamW
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm

from diffusers_tutorials.datasets import SpritesDataset
from diffusers_tutorials.tools.plotly import plot_generated_images

  from .autonotebook import tqdm as notebook_tqdm


# Hyperparameters

In [2]:
batch_size = 128  # samples per batch; also used to round the sampler length to a multiple of it
num_epochs = 100  # total number of training epochs passed to `train`
warmup = 80  # number of epochs before `train` freezes the embedding network
# Checkpoint paths of the UNet and of the embedding network (written after training, read back otherwise)
pre_trained_unet = "weights/better_conditional_unet.pth"
pre_trained_emb_net = "weights/better_emb_net.pth"

# Dataset

In [3]:
dataset = SpritesDataset(
    "./dlai_lib/sprites_1788_16x16.npy",
    "./dlai_lib/sprite_labels_nc_1788_16x16.npy",
    null_context=False,
    clean_version=True,
)

# Create a dataset sampler to compensate for the unbalanced dataset.
# The label matrix is reduced to one class index per sample.
labels = dataset.slabels.argmax(axis=1)
# `label_positions[i]` is the position of sample i's class inside `u_labels` / `class_counts`.
# Indexing the weights with it (instead of the raw label value) stays correct even when
# some class index never appears in the data.
u_labels, label_positions, class_counts = np.unique(labels, return_inverse=True, return_counts=True)
label_positions = label_positions.reshape(-1)
# The more frequent a class is, the smaller its sampling weight.
class_weights = 1 - class_counts / class_counts.sum()
sample_weights = tuple(class_weights[label_positions])
# num_samples is a multiple of batch_size so every batch has the same shape (avoids recompiling)
num_samples = ceil(len(labels) / batch_size) * batch_size

dataset_sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples, replacement=True)

dataloader = DataLoader(
    dataset, sampler=dataset_sampler, batch_size=batch_size, num_workers=1, pin_memory=True
)

# Model definition

In [4]:
class UnsqueezeLayer(nn.Module):
    """Module wrapper around ``unsqueeze`` so it can be used inside a sequential model.

    Parameters
    ----------
    dim : int
        Position at which the new axis of size one is inserted.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with an extra axis of size one inserted at ``self.dim``."""
        return x.unsqueeze(self.dim)


# The pipeline expects its modules to expose `device` and `dtype`, which a plain
# `nn.Sequential` does not provide; this subclass adds them.
class CustomSequential(nn.Sequential):
    """Sequential container exposing `device` and `dtype` properties.

    Both properties are read from the first parameter of the container, so they
    assume that all parameters share the same device and the same dtype.
    """

    def _first_parameter(self) -> torch.Tensor:
        """Return the first parameter yielded by ``self.parameters()``."""
        return next(self.parameters())

    @property
    def device(self) -> device:
        """Device of the first parameter."""
        reference = self._first_parameter()
        return reference.device

    @property
    def dtype(self) -> dtype:
        """Dtype of the first parameter."""
        reference = self._first_parameter()
        return reference.dtype

In [5]:
num_classes = 5  # length of the one-hot label vectors fed to the embedding network
class_emb_size = 64  # size of the label embedding handed to the UNet

# Embedding network: turns a label vector into the conditioning embedding of the UNet.
emb_net = CustomSequential(
    nn.Linear(num_classes, class_emb_size//2), # bottleneck to force better embedding quality
    nn.Dropout(0.1),
    nn.GELU(),
    nn.Linear(class_emb_size//2, class_emb_size),
    # adds an axis at position 1; for (batch, num_classes) inputs the output is (batch, 1, class_emb_size)
    UnsqueezeLayer(dim=1),
)

# Define the UNet model
unet = UNet2DConditionModel(
    encoder_hid_dim=class_emb_size,  # must match the output size of emb_net
    sample_size=16,  # Input image size
    in_channels=3,  # Number of input channels (e.g., 3 for RGB)
    out_channels=3,  # Number of output channels
    layers_per_block=2,  # Layers per block in the UNet
    block_out_channels=(64, 128, 256),  # Channels in each block
    down_block_types=("DownBlock2D", "DownBlock2D", "CrossAttnDownBlock2D"),  # Types of down blocks
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D", "UpBlock2D"),  # Types of up blocks
    dropout=0.2, # dropout added to regularize the model
)

In [6]:
# Use DDIMScheduler instead of DDPMScheduler. The same scheduler instance is used to add
# noise to the images during training and is later handed to the generation pipeline.
noise_scheduler = DDIMScheduler(num_train_timesteps=1000)

In [7]:
# Move both networks to the GPU when one is available, otherwise keep them on the CPU.
unet = unet.to("cuda" if torch.cuda.is_available() else "cpu")
# `nn.Module.to` returns the module itself, so `class_net` is another name for `emb_net`
# (the name used when building the pipeline), not a copy.
class_net = emb_net.to("cuda" if torch.cuda.is_available() else "cpu")

# Training

In [8]:
def train(
    unet: UNet2DConditionModel,
    emb_net: nn.Module,
    noise_scheduler: DDIMScheduler,
    dataloader: DataLoader,
    num_epochs: int,
    lr: float = 1e-4,
    warmup: int = 4,
) -> np.ndarray[np.float64]:
    """Train a conditional diffusion model.

    The UNet learns to predict the noise added to the images, conditioned on the
    embedding of their labels. Both networks are optimised jointly during the
    first ``warmup`` epochs; afterwards the embedding network is frozen and, for
    each batch, the labels of about 5% of the samples are zeroed out.

    Parameters
    ----------
    unet : UNet2DConditionModel
        The UNet that will generate new samples.
    emb_net : nn.Module
        The embedding model used to encode the labels for conditional generation.
    noise_scheduler : DDIMScheduler
        Noise scheduler, here for the Denoising Diffusion Implicit Models framework.
    dataloader : DataLoader
        The DataLoader that provides the training samples as (images, labels) batches.
    num_epochs : int
        The total number of epochs to train the UNet.
    lr : float, optional
        Learning rate shared by both models, by default 1e-4
    warmup : int, optional
        The number of epochs before freezing the embedding network's parameters, by default 4

    Returns
    -------
    np.ndarray[np.float64]
        Mean batch loss of each epoch.

    Notes
    -----
    The parameters of ``emb_net`` are left with ``requires_grad=False`` once the
    warmup is over; they are not unfrozen before returning.
    """
    losses = np.zeros(num_epochs)
    warmup_done = False
    device_type = "cuda" if torch.cuda.is_available() else "cpu"

    # compile networks for faster execution
    unet = torch.compile(unet, fullgraph=True)
    emb_net = torch.compile(emb_net, fullgraph=True)

    # here we use AdamW to add weight decay in the training to avoid overfitting
    optimizer = AdamW(chain(unet.parameters(), emb_net.parameters()), lr=lr, weight_decay=1e-4)
    scaler = GradScaler(device_type)  # For mixed precision
    # Both networks contain dropout layers, so both are explicitly put in training mode.
    unet.train()
    emb_net.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0

        if not warmup_done and epoch >= warmup:
            warmup_done = True
            # freeze all layers of the embedding network (done once, when the warmup ends)
            for param in emb_net.parameters():
                param.requires_grad = False

        for batch in tqdm(dataloader):
            optimizer.zero_grad()

            # The dataloader provides the images and their associated labels
            images, labels = batch
            images = images.to(unet.device, non_blocking=True)
            labels = labels.to(dtype=torch.float32, device=unet.device, non_blocking=True)

            with autocast(device_type):
                # Generate random noise
                noise = torch.randn(images.shape, device=unet.device)

                # Generate random timesteps and apply the noise scheduler
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (images.shape[0],),
                    device=unet.device,
                )

                noisy_images = noise_scheduler.add_noise(images, noise, timesteps)

                if warmup_done:
                    # Randomly mask out labels (context): each sample keeps its label with
                    # probability 0.95. The mask is sampled directly on the training device
                    # to avoid a host-to-device copy for every batch.
                    keep_probability = torch.zeros(labels.shape[0], device=unet.device) + 0.95
                    context_mask = torch.bernoulli(keep_probability)
                    labels = labels * context_mask.unsqueeze(-1)

                # Compute the class embeddings
                enc_labels = emb_net(labels)

                # Forward pass through the model with labels embeddings
                predicted_noise = unet(
                    noisy_images, timesteps, enc_labels, class_labels=labels
                ).sample

                # Compute l1 loss instead of mse, it provides better robustness to noises
                # which is suited for diffusion models, but the cost is slower training
                loss = torch.nn.functional.l1_loss(predicted_noise, noise)

            # Backward pass and optimization (the loss is scaled for mixed precision)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        epoch_loss = epoch_loss / len(dataloader)
        losses[epoch] = epoch_loss
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

    return losses

In [9]:
# Train unless BOTH checkpoints are available. The two networks are trained jointly, so
# when only one checkpoint exists the pair is retrained instead of attempting to load a
# file that is missing.
if not (exists(pre_trained_unet) and exists(pre_trained_emb_net)):
    train(unet, emb_net, noise_scheduler, dataloader, num_epochs, lr=1e-4, warmup=warmup)
    unet.save_pretrained(pre_trained_unet)
    torch.save(emb_net.state_dict(), pre_trained_emb_net)
else:
    unet = UNet2DConditionModel.from_pretrained(pre_trained_unet)
    # map_location="cpu" lets a checkpoint saved on a GPU be read on a CPU-only machine
    # (load_state_dict then copies the values onto the device of emb_net);
    # weights_only=True restricts unpickling to tensors and plain containers.
    emb_net.load_state_dict(torch.load(pre_trained_emb_net, map_location="cpu", weights_only=True))

100%|██████████| 423/423 [01:29<00:00,  4.71it/s]


Epoch 1/100, Loss: 0.2870138998784072


100%|██████████| 423/423 [00:24<00:00, 17.13it/s]


Epoch 2/100, Loss: 0.18763427992943613


100%|██████████| 423/423 [00:24<00:00, 17.06it/s]


Epoch 3/100, Loss: 0.1540072543225108


100%|██████████| 423/423 [00:24<00:00, 17.01it/s]


Epoch 4/100, Loss: 0.13363078682205637


100%|██████████| 423/423 [00:24<00:00, 17.05it/s]


Epoch 5/100, Loss: 0.12198314468770444


100%|██████████| 423/423 [00:25<00:00, 16.72it/s]


Epoch 6/100, Loss: 0.11215285791291131


100%|██████████| 423/423 [00:24<00:00, 17.16it/s]


Epoch 7/100, Loss: 0.1050925317174154


100%|██████████| 423/423 [00:24<00:00, 16.94it/s]


Epoch 8/100, Loss: 0.10051046544355703


100%|██████████| 423/423 [00:24<00:00, 17.07it/s]


Epoch 9/100, Loss: 0.09519060244937878


100%|██████████| 423/423 [00:24<00:00, 17.15it/s]


Epoch 10/100, Loss: 0.09135213864187822


100%|██████████| 423/423 [00:25<00:00, 16.76it/s]


Epoch 11/100, Loss: 0.08800398044059181


100%|██████████| 423/423 [00:24<00:00, 16.98it/s]


Epoch 12/100, Loss: 0.08537342695268334


100%|██████████| 423/423 [00:25<00:00, 16.86it/s]


Epoch 13/100, Loss: 0.08288691944968897


100%|██████████| 423/423 [00:26<00:00, 15.99it/s]


Epoch 14/100, Loss: 0.0806036891930915


100%|██████████| 423/423 [00:23<00:00, 17.66it/s]


Epoch 15/100, Loss: 0.07935814741761127


100%|██████████| 423/423 [00:23<00:00, 17.71it/s]


Epoch 16/100, Loss: 0.07655676252629176


100%|██████████| 423/423 [00:24<00:00, 17.30it/s]


Epoch 17/100, Loss: 0.0744079325973424


100%|██████████| 423/423 [00:24<00:00, 17.55it/s]


Epoch 18/100, Loss: 0.07297358727941276


100%|██████████| 423/423 [00:23<00:00, 17.69it/s]


Epoch 19/100, Loss: 0.07215310338153624


100%|██████████| 423/423 [00:24<00:00, 17.51it/s]


Epoch 20/100, Loss: 0.07068668293544304


100%|██████████| 423/423 [00:24<00:00, 17.46it/s]


Epoch 21/100, Loss: 0.0696069407353869


100%|██████████| 423/423 [00:24<00:00, 17.11it/s]


Epoch 22/100, Loss: 0.06919419449307113


100%|██████████| 423/423 [00:24<00:00, 17.56it/s]


Epoch 23/100, Loss: 0.06686992307433563


100%|██████████| 423/423 [00:24<00:00, 17.37it/s]


Epoch 24/100, Loss: 0.06690771312366986


100%|██████████| 423/423 [00:24<00:00, 17.58it/s]


Epoch 25/100, Loss: 0.0652566823324554


100%|██████████| 423/423 [00:24<00:00, 17.35it/s]


Epoch 26/100, Loss: 0.06436926682191256


100%|██████████| 423/423 [00:24<00:00, 17.28it/s]


Epoch 27/100, Loss: 0.06313995059676486


100%|██████████| 423/423 [00:24<00:00, 17.55it/s]


Epoch 28/100, Loss: 0.06256039932933823


100%|██████████| 423/423 [00:24<00:00, 17.60it/s]


Epoch 29/100, Loss: 0.06177253579945429


100%|██████████| 423/423 [00:24<00:00, 17.53it/s]


Epoch 30/100, Loss: 0.061591448072986964


100%|██████████| 423/423 [00:23<00:00, 17.68it/s]


Epoch 31/100, Loss: 0.06065226686198097


100%|██████████| 423/423 [00:24<00:00, 17.08it/s]


Epoch 32/100, Loss: 0.05997328428511924


100%|██████████| 423/423 [00:24<00:00, 17.60it/s]


Epoch 33/100, Loss: 0.058710472010974346


100%|██████████| 423/423 [00:24<00:00, 17.41it/s]


Epoch 34/100, Loss: 0.05844710914343243


100%|██████████| 423/423 [00:24<00:00, 17.56it/s]


Epoch 35/100, Loss: 0.058691845999823675


100%|██████████| 423/423 [00:25<00:00, 16.65it/s]


Epoch 36/100, Loss: 0.057845165790113715


100%|██████████| 423/423 [00:24<00:00, 16.95it/s]


Epoch 37/100, Loss: 0.056533919995205906


100%|██████████| 423/423 [00:24<00:00, 17.34it/s]


Epoch 38/100, Loss: 0.056245609290070005


100%|██████████| 423/423 [00:25<00:00, 16.83it/s]


Epoch 39/100, Loss: 0.055637341776497255


100%|██████████| 423/423 [00:24<00:00, 17.20it/s]


Epoch 40/100, Loss: 0.055750721171674435


100%|██████████| 423/423 [00:24<00:00, 17.52it/s]


Epoch 41/100, Loss: 0.05541238420363295


100%|██████████| 423/423 [00:24<00:00, 17.03it/s]


Epoch 42/100, Loss: 0.05399428885738337


100%|██████████| 423/423 [00:24<00:00, 17.60it/s]


Epoch 43/100, Loss: 0.05333889222948264


100%|██████████| 423/423 [00:23<00:00, 17.69it/s]


Epoch 44/100, Loss: 0.054250268393304046


100%|██████████| 423/423 [00:23<00:00, 17.63it/s]


Epoch 45/100, Loss: 0.05346561626621454


100%|██████████| 423/423 [00:24<00:00, 17.54it/s]


Epoch 46/100, Loss: 0.0522174744846973


100%|██████████| 423/423 [00:24<00:00, 17.19it/s]


Epoch 47/100, Loss: 0.052526666225764206


100%|██████████| 423/423 [00:25<00:00, 16.66it/s]


Epoch 48/100, Loss: 0.05254916994706959


100%|██████████| 423/423 [00:24<00:00, 17.23it/s]


Epoch 49/100, Loss: 0.05140245502405133


100%|██████████| 423/423 [00:24<00:00, 17.52it/s]


Epoch 50/100, Loss: 0.05114247580334086


100%|██████████| 423/423 [00:24<00:00, 17.47it/s]


Epoch 51/100, Loss: 0.051196262002625365


100%|██████████| 423/423 [00:24<00:00, 17.33it/s]


Epoch 52/100, Loss: 0.05030989821286912


100%|██████████| 423/423 [00:24<00:00, 17.15it/s]


Epoch 53/100, Loss: 0.05046602478581117


100%|██████████| 423/423 [00:24<00:00, 17.50it/s]


Epoch 54/100, Loss: 0.04999417767597992


100%|██████████| 423/423 [00:23<00:00, 17.65it/s]


Epoch 55/100, Loss: 0.049485269308653844


100%|██████████| 423/423 [00:24<00:00, 17.48it/s]


Epoch 56/100, Loss: 0.049517637158327917


100%|██████████| 423/423 [00:24<00:00, 17.60it/s]


Epoch 57/100, Loss: 0.04875610974874902


100%|██████████| 423/423 [00:24<00:00, 17.11it/s]


Epoch 58/100, Loss: 0.04857597931620641


100%|██████████| 423/423 [00:24<00:00, 17.34it/s]


Epoch 59/100, Loss: 0.04892002533126103


100%|██████████| 423/423 [00:24<00:00, 17.50it/s]


Epoch 60/100, Loss: 0.047860044616122616


100%|██████████| 423/423 [00:24<00:00, 17.52it/s]


Epoch 61/100, Loss: 0.04812704289142685


100%|██████████| 423/423 [00:24<00:00, 17.60it/s]


Epoch 62/100, Loss: 0.04801926852575994


100%|██████████| 423/423 [00:24<00:00, 17.21it/s]


Epoch 63/100, Loss: 0.047075113559022864


100%|██████████| 423/423 [00:24<00:00, 17.58it/s]


Epoch 64/100, Loss: 0.04721243698017817


100%|██████████| 423/423 [00:24<00:00, 17.54it/s]


Epoch 65/100, Loss: 0.047114028549152066


100%|██████████| 423/423 [00:24<00:00, 17.53it/s]


Epoch 66/100, Loss: 0.046737903266722426


100%|██████████| 423/423 [00:24<00:00, 17.56it/s]


Epoch 67/100, Loss: 0.04669207673581498


100%|██████████| 423/423 [00:24<00:00, 17.34it/s]


Epoch 68/100, Loss: 0.0461828406133973


100%|██████████| 423/423 [00:24<00:00, 17.18it/s]


Epoch 69/100, Loss: 0.045915982331508155


100%|██████████| 423/423 [00:24<00:00, 17.13it/s]


Epoch 70/100, Loss: 0.04557445388172817


100%|██████████| 423/423 [00:24<00:00, 17.31it/s]


Epoch 71/100, Loss: 0.04555809599992513


100%|██████████| 423/423 [00:24<00:00, 17.48it/s]


Epoch 72/100, Loss: 0.04578078593662445


100%|██████████| 423/423 [00:23<00:00, 17.64it/s]


Epoch 73/100, Loss: 0.04518012978118926


100%|██████████| 423/423 [00:24<00:00, 17.23it/s]


Epoch 74/100, Loss: 0.04467465387696915


100%|██████████| 423/423 [00:24<00:00, 17.61it/s]


Epoch 75/100, Loss: 0.045210662639169265


100%|██████████| 423/423 [00:24<00:00, 17.54it/s]


Epoch 76/100, Loss: 0.044861474983728805


100%|██████████| 423/423 [00:24<00:00, 17.54it/s]


Epoch 77/100, Loss: 0.044368763755892465


100%|██████████| 423/423 [00:24<00:00, 17.42it/s]


Epoch 78/100, Loss: 0.04463015638520813


100%|██████████| 423/423 [00:24<00:00, 16.93it/s]


Epoch 79/100, Loss: 0.044500109424757336


100%|██████████| 423/423 [00:24<00:00, 17.48it/s]


Epoch 80/100, Loss: 0.04397887296657613


100%|██████████| 423/423 [01:00<00:00,  6.97it/s]


Epoch 81/100, Loss: 0.04419304048268226


100%|██████████| 423/423 [00:24<00:00, 17.38it/s]


Epoch 82/100, Loss: 0.04354199155037848


100%|██████████| 423/423 [00:24<00:00, 17.43it/s]


Epoch 83/100, Loss: 0.04327129266152145


100%|██████████| 423/423 [00:24<00:00, 17.26it/s]


Epoch 84/100, Loss: 0.04377851804308858


100%|██████████| 423/423 [00:25<00:00, 16.65it/s]


Epoch 85/100, Loss: 0.04353387995532782


100%|██████████| 423/423 [00:24<00:00, 17.54it/s]


Epoch 86/100, Loss: 0.04364471552325479


100%|██████████| 423/423 [00:24<00:00, 17.56it/s]


Epoch 87/100, Loss: 0.0425736058468162


100%|██████████| 423/423 [00:24<00:00, 17.23it/s]


Epoch 88/100, Loss: 0.04326137867767196


100%|██████████| 423/423 [00:24<00:00, 17.50it/s]


Epoch 89/100, Loss: 0.04221380081899623


100%|██████████| 423/423 [00:25<00:00, 16.55it/s]


Epoch 90/100, Loss: 0.04286487013031123


100%|██████████| 423/423 [00:25<00:00, 16.46it/s]


Epoch 91/100, Loss: 0.0429401301446395


100%|██████████| 423/423 [00:25<00:00, 16.91it/s]


Epoch 92/100, Loss: 0.04261350636997967


100%|██████████| 423/423 [00:26<00:00, 16.25it/s]


Epoch 93/100, Loss: 0.041972883600503844


100%|██████████| 423/423 [00:25<00:00, 16.83it/s]


Epoch 94/100, Loss: 0.042327682485464896


100%|██████████| 423/423 [00:25<00:00, 16.46it/s]


Epoch 95/100, Loss: 0.04187175005797658


100%|██████████| 423/423 [00:24<00:00, 17.32it/s]


Epoch 96/100, Loss: 0.04198601044027518


100%|██████████| 423/423 [00:24<00:00, 17.41it/s]


Epoch 97/100, Loss: 0.04165022790291067


100%|██████████| 423/423 [00:25<00:00, 16.80it/s]


Epoch 98/100, Loss: 0.04148492707604494


100%|██████████| 423/423 [00:24<00:00, 16.97it/s]


Epoch 99/100, Loss: 0.04155582878254266


100%|██████████| 423/423 [00:25<00:00, 16.90it/s]


Epoch 100/100, Loss: 0.04101051134332556


# Load the model and try it

In [10]:
class ConditionalDDIMPipeline(DDIMPipeline):
    """DDIM generation pipeline conditioned on class labels.

    The labels are encoded by ``class_net`` and the resulting embeddings are given to
    the UNet at every denoising step. Both networks are switched to evaluation mode
    when the pipeline is built.
    """

    def __init__(
        self, unet: UNet2DConditionModel, class_net: CustomSequential, scheduler: DDIMScheduler
    ) -> None:
        super().__init__(unet=unet, scheduler=scheduler)
        self.class_net = class_net
        self.class_net.eval()
        # registering the network makes it part of the pipeline's modules (see the pipeline config)
        self.register_modules(class_net=class_net)
        self.unet.eval()

    @torch.no_grad()
    def __call__(
        self,
        class_label: list[list[float]],
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        num_inference_steps: int = 1000,
        use_clipped_model_output: Optional[bool] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            class_label (list[list[float]]):
                list of one-hot examples. len(class_label) represents the number of examples to generate.
                The values are converted to the dtype of `class_net`, so integer one-hot rows are accepted too.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to
                DDIM and `1` corresponds to DDPM.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
                If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed
                downstream to the scheduler (use `None` for schedulers which don't support this argument).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """
        batch_size = len(class_label)
        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)

        # Build the label tensor with the dtype of the embedding network so that the
        # result does not depend on whether the caller wrote 1 or 1.0 in the lists.
        labels = torch.tensor(class_label, dtype=self.class_net.dtype, device=self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        # autocast expects a device *type* ("cuda", "cpu"), not a device string such as "cuda:0"
        with autocast(self.unet.device.type):
            # get encoded labels
            enc_labels = self.class_net(labels)

            for t in self.progress_bar(self.scheduler.timesteps):
                # 1. predict noise model_output
                model_output = self.unet(image, t, enc_labels, class_labels=labels, return_dict=False)[
                    0
                ]

                # 2. predict previous mean of image x_t-1 and add variance depending on eta
                # eta corresponds to η in paper and should be between [0, 1]
                # do x_t -> x_t-1
                image = self.scheduler.step(
                    model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
                ).prev_sample

        # map the samples from [-1, 1] to [0, 1] and move the channel axis last
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

In [11]:
# Create the DDIM pipeline from the UNet, the embedding network and the scheduler used for training
pipeline = ConditionalDDIMPipeline(unet=unet, class_net=class_net, scheduler=noise_scheduler)
# Move the pipeline to the GPU when one is available, otherwise to the CPU
pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

ConditionalDDIMPipeline {
  "_class_name": "ConditionalDDIMPipeline",
  "_diffusers_version": "0.30.0",
  "class_net": [
    "__main__",
    "CustomSequential"
  ],
  "scheduler": [
    "diffusers",
    "DDIMScheduler"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ]
}

In [12]:
# Ask for four sprites per class, class after class: each row is the one-hot
# vector of its class (1.0 at the class position, 0.0 elsewhere).
generated_image = pipeline(
    [
        [float(position == class_index) for position in range(num_classes)]
        for class_index in range(num_classes)
        for _ in range(4)
    ],
    num_inference_steps=10,
)

100%|██████████| 10/10 [00:00<00:00, 62.50it/s]


In [13]:
# Display the generated sprites.
# NOTE(review): 5 and 4 presumably match the 5 classes x 4 samples per class requested
# above — confirm the meaning and order of these arguments in `plot_generated_images`.
fig = plot_generated_images(generated_image.images, 5, 4)
fig.show()