In [1]:
from os.path import exists

import numpy as np
import torch
from diffusers import DDPMPipeline, DDPMScheduler
from diffusers.models import UNet2DModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

from dlai_lib.diffusion_utilities import CustomDataset, transform
from plotly_tools import plot_generated_images

  from .autonotebook import tqdm as notebook_tqdm


# Hyperparameters

In [2]:
batch_size = 100   # images per optimisation step
num_epochs = 32    # full passes over the sprite dataset

# Seed the random generators so weight init, data shuffling and noise/timestep
# sampling repeat between runs (bit-exact GPU reproducibility is not guaranteed).
# np.random.seed is last so the cell does not display torch's Generator object.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Dataset

In [3]:
# Load the 16x16 sprite images and their labels from the deeplearning.ai course material
dataset = CustomDataset(
    "dlai_lib/sprites_1788_16x16.npy",
    "dlai_lib/sprite_labels_nc_1788_16x16.npy",
    transform,
    null_context=False,
)
# Pinned host memory only helps when batches are copied to a CUDA device;
# requesting it on a CPU-only machine merely triggers a warning.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    pin_memory=torch.cuda.is_available(),
)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


# Model definition

In [4]:
# UNet that predicts the noise present in a noisy 16x16 RGB sprite.
# Note: it uses fewer parameters than the deeplearning.ai course model, because a
# model that large is not needed for this task.
model = UNet2DModel(
    sample_size=(16, 16),               # spatial size of the input images
    in_channels=3,                      # RGB input
    out_channels=3,                     # RGB output (the predicted noise)
    layers_per_block=2,                 # layers inside every UNet block
    block_out_channels=(128, 64),       # output channels of each resolution level
    down_block_types=("DownBlock2D",) * 2,
    up_block_types=("UpBlock2D",) * 2,
)

# DDPM noise scheduler: images are diffused over 500 timesteps during training
noise_scheduler = DDPMScheduler(num_train_timesteps=500)

In [5]:
# Training and sampling are much faster on a CUDA-capable GPU; fall back to the CPU otherwise
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Train the model, or load previously trained weights

In [6]:

def train(unet: UNet2DModel, noise_scheduler: DDPMScheduler, dataloader: DataLoader, num_epochs: int, lr: float) -> np.ndarray:
    """Train the unet to predict the noise that noise_scheduler adds to the images.

    At every step each image gets a uniformly sampled timestep and is corrupted with
    Gaussian noise via ``noise_scheduler.add_noise``. The unet's prediction is then
    regressed onto that noise with a mean squared error loss.

    Parameters
    ----------
    unet : UNet2DModel
        The model to train; its parameters are updated in place.
    noise_scheduler : DDPMScheduler
        Noise scheduler used to corrupt the images; its
        ``config.num_train_timesteps`` bounds the sampled timesteps.
    dataloader : DataLoader
        Yields ``(images, labels)`` batches; the labels are ignored.
    num_epochs : int
        The number of full passes over ``dataloader``.
    lr : float
        The learning rate of the Adam optimizer.

    Returns
    -------
    np.ndarray
        The mean training loss of every epoch, shape ``(num_epochs,)``.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batch during an epoch.
    """
    losses = np.zeros(num_epochs)
    # Loop-invariant lookups, hoisted out of the batch loop
    device = unet.device
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    optimizer = Adam(unet.parameters(), lr=lr)
    unet.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        # The dataloader provides (images, labels); labels are not needed here
        for images, _ in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            optimizer.zero_grad()
            images = images.to(device, non_blocking=True)

            # Corrupt each image at its own random timestep
            noise = torch.randn_like(images)
            timesteps = torch.randint(0, num_train_timesteps, (images.shape[0],), device=device)
            noisy_images = noise_scheduler.add_noise(images, noise, timesteps)

            # The unet learns to recover the noise that was added
            predicted_noise = unet(noisy_images, timesteps).sample
            loss = torch.nn.functional.mse_loss(predicted_noise, noise)

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; cannot train")
        epoch_loss /= num_batches
        losses[epoch] = epoch_loss
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

    return losses

In [7]:
# Location of the saved weights; save_pretrained()/from_pretrained() treat it as a
# directory, despite the ".pth" suffix
weights_dir = "weights"
pre_trained_model = f"{weights_dir}/non_conditional_model.pth"

In [8]:
# Train from scratch only when no saved weights exist; otherwise reuse them
if not exists(pre_trained_model):
    train(model, noise_scheduler, dataloader, num_epochs, lr=1e-4)
    model.save_pretrained(pre_trained_model)
else:
    # The reloaded model is a new object: place it on the same device as the
    # freshly-built model so both branches leave `model` on that device
    model = UNet2DModel.from_pretrained(pre_trained_model).to(model.device)

# Sampling comes next: leave training mode (assignment avoids displaying the model repr)
model = model.eval()

  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 894/894 [00:49<00:00, 18.00it/s]


Epoch 1/32, Loss: 0.14239213854007807


100%|██████████| 894/894 [00:49<00:00, 18.20it/s]


Epoch 2/32, Loss: 0.09175340646265337


100%|██████████| 894/894 [00:49<00:00, 18.00it/s]


Epoch 3/32, Loss: 0.0783772516737315


100%|██████████| 894/894 [00:49<00:00, 17.99it/s]


Epoch 4/32, Loss: 0.0714373614564038


100%|██████████| 894/894 [00:49<00:00, 18.18it/s]


Epoch 5/32, Loss: 0.0665116410169092


100%|██████████| 894/894 [00:49<00:00, 18.11it/s]


Epoch 6/32, Loss: 0.06217689540912241


100%|██████████| 894/894 [00:49<00:00, 18.16it/s]


Epoch 7/32, Loss: 0.05870652072114966


100%|██████████| 894/894 [00:49<00:00, 18.23it/s]


Epoch 8/32, Loss: 0.056081055457259985


100%|██████████| 894/894 [00:49<00:00, 18.06it/s]


Epoch 9/32, Loss: 0.05328660390196031


100%|██████████| 894/894 [00:48<00:00, 18.36it/s]


Epoch 10/32, Loss: 0.05204211333370715


100%|██████████| 894/894 [00:48<00:00, 18.42it/s]


Epoch 11/32, Loss: 0.049226853275625766


100%|██████████| 894/894 [00:48<00:00, 18.27it/s]


Epoch 12/32, Loss: 0.0483685890733529


100%|██████████| 894/894 [00:49<00:00, 17.95it/s]


Epoch 13/32, Loss: 0.04655174907480244


100%|██████████| 894/894 [00:49<00:00, 18.01it/s]


Epoch 14/32, Loss: 0.044993478706902436


100%|██████████| 894/894 [00:49<00:00, 18.06it/s]


Epoch 15/32, Loss: 0.043496422788390776


100%|██████████| 894/894 [00:49<00:00, 18.22it/s]


Epoch 16/32, Loss: 0.04201661035943551


100%|██████████| 894/894 [00:48<00:00, 18.25it/s]


Epoch 17/32, Loss: 0.0413837184758661


100%|██████████| 894/894 [00:49<00:00, 18.20it/s]


Epoch 18/32, Loss: 0.039628636623238955


100%|██████████| 894/894 [00:48<00:00, 18.50it/s]


Epoch 19/32, Loss: 0.03936571770991455


100%|██████████| 894/894 [00:48<00:00, 18.25it/s]


Epoch 20/32, Loss: 0.03822532262219745


100%|██████████| 894/894 [00:49<00:00, 18.13it/s]


Epoch 21/32, Loss: 0.03731627515840117


100%|██████████| 894/894 [00:49<00:00, 18.20it/s]


Epoch 22/32, Loss: 0.03674632924456791


100%|██████████| 894/894 [00:49<00:00, 18.13it/s]


Epoch 23/32, Loss: 0.03604246864316181


100%|██████████| 894/894 [00:47<00:00, 18.68it/s]


Epoch 24/32, Loss: 0.03495211016503423


100%|██████████| 894/894 [00:49<00:00, 18.20it/s]


Epoch 25/32, Loss: 0.03454564049448276


100%|██████████| 894/894 [00:49<00:00, 18.24it/s]


Epoch 26/32, Loss: 0.033996071137271204


100%|██████████| 894/894 [00:47<00:00, 18.99it/s]


Epoch 27/32, Loss: 0.03312669113336727


100%|██████████| 894/894 [00:45<00:00, 19.63it/s]


Epoch 28/32, Loss: 0.0324657446906904


100%|██████████| 894/894 [00:45<00:00, 19.64it/s]


Epoch 29/32, Loss: 0.03251824290252039


100%|██████████| 894/894 [00:45<00:00, 19.67it/s]


Epoch 30/32, Loss: 0.032081249251261656


100%|██████████| 894/894 [00:46<00:00, 19.21it/s]


Epoch 31/32, Loss: 0.031478460196830683


100%|██████████| 894/894 [00:48<00:00, 18.31it/s]


Epoch 32/32, Loss: 0.03103836457884685


# Generate images with the trained model

In [9]:
# Create the DDPM pipeline
pipeline = DDPMPipeline(unet=model, scheduler=noise_scheduler)
pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

DDPMPipeline {
  "_class_name": "DDPMPipeline",
  "_diffusers_version": "0.27.2",
  "scheduler": [
    "diffusers",
    "DDPMScheduler"
  ],
  "unet": [
    "diffusers",
    "UNet2DModel"
  ]
}

In [10]:
# Sample 16 sprites, running the reverse process over every training timestep of
# the pipeline's scheduler instead of repeating its length as a magic number.
# A dedicated seeded generator makes the samples reproducible between runs.
sampling_generator = torch.Generator(device="cpu").manual_seed(0)
generated_image = pipeline(
    batch_size=16,
    num_inference_steps=pipeline.scheduler.config.num_train_timesteps,
    generator=sampling_generator,
)

100%|██████████| 500/500 [00:02<00:00, 170.12it/s]


In [11]:
# Lay the 16 generated sprites out on a 4 x 4 grid and render it
grid_shape = (4, 4)
fig = plot_generated_images(generated_image.images, *grid_shape)
fig.show()