In [1]:
from os.path import exists

import numpy as np
import torch
from diffusers import DDPMPipeline, DDPMScheduler
from diffusers.models import UNet2DModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

from dlai_lib.diffusion_utilities import CustomDataset, transform
from plotly_tools import plot_generated_images

  from .autonotebook import tqdm as notebook_tqdm


# Hyperparameters

In [2]:
# Tunable settings for the whole notebook.
batch_size = 100  # number of images per optimisation step
num_epochs = 32   # full passes over the dataset during training
seed = 42         # RNG seed shared by torch and numpy

# Training shuffles the data and samples noise/timesteps, and inference samples
# noise as well; seeding up front makes a Restart & Run All repeatable on the
# same setup (GPU kernels may still introduce small nondeterminism).
torch.manual_seed(seed)
np.random.seed(seed)

# Dataset

In [3]:
# Sprite images and their labels, taken from the deeplearning.ai course material.
sprites_path = "dlai_lib/sprites_1788_16x16.npy"
labels_path = "dlai_lib/sprite_labels_nc_1788_16x16.npy"
dataset = CustomDataset(sprites_path, labels_path, transform, null_context=False)

# Shuffled mini-batches, loaded by a single worker process.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


# Model definition

In [4]:
# UNet used as the denoiser.
# Note: this model uses fewer parameters than the one from the deeplearning.ai
# course, because a network that large turned out not to be necessary here.
unet_config = dict(
    sample_size=(16, 16),                             # Input image size
    in_channels=3,                                    # Number of input channels (e.g., 3 for RGB)
    out_channels=3,                                   # Number of output channels
    layers_per_block=2,                               # Layers per block in the UNet
    block_out_channels=(128, 64),                     # Channels in each block
    down_block_types=("DownBlock2D", "DownBlock2D"),  # Types of down blocks
    up_block_types=("UpBlock2D", "UpBlock2D"),        # Types of up blocks
)
model = UNet2DModel(**unet_config)

# DDPM noise scheduler with 500 training timesteps
noise_scheduler = DDPMScheduler(num_train_timesteps=500)

In [5]:
# A CUDA-capable GPU is recommended: it makes both training and inference
# faster. Fall back to the CPU when none is available.
if torch.cuda.is_available():
    model = model.to("cuda")
else:
    model = model.to("cpu")

# Train or load the previously learned model

In [6]:

def train(unet: UNet2DModel, noise_scheduler: DDPMScheduler, dataloader: DataLoader, num_epochs: int, lr: float) -> np.ndarray:
    """Train the unet to predict the noise that the scheduler added to images.

    For every batch, random timesteps and Gaussian noise are sampled, the
    images are noised with ``noise_scheduler.add_noise`` and the unet is
    optimised (Adam, MSE loss) to recover that noise.

    Parameters
    ----------
    unet : UNet2DModel
        The denoising network. It is switched to training mode and its
        parameters are updated in place; batches are moved to ``unet.device``.
    noise_scheduler : DDPMScheduler
        The scheduler used to noise the images. Its
        ``config.num_train_timesteps`` is the exclusive upper bound of the
        sampled timesteps.
    dataloader : DataLoader
        The dataloader containing the images to reproduce. Each batch must
        unpack into two items; only the first one (the images) is used.
    num_epochs : int
        The number of epochs to train the unet.
    lr : float
        The learning rate to use to train the unet.

    Returns
    -------
    np.ndarray
        Array of shape ``(num_epochs,)`` holding the mean batch loss of each
        epoch, e.g. for plotting the training curve.
    """
    losses = np.zeros(num_epochs)
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    optimizer = Adam(unet.parameters(), lr=lr)
    unet.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in tqdm(dataloader):
            optimizer.zero_grad()

            # Only the images are needed here; the second item of the batch
            # (presumably the labels given to the dataset) is ignored.
            images, _ = batch
            images = images.to(unet.device)

            # Sample the target noise directly on the images' device instead
            # of creating it on the CPU and copying it over for every batch.
            noise = torch.randn(images.shape, device=images.device)

            # One random timestep per image (torch.randint yields int64 by default)
            timesteps = torch.randint(0, num_train_timesteps, (images.shape[0],), device=images.device)

            # Noise the images and let the model predict the noise that was added
            noisy_images = noise_scheduler.add_noise(images, noise, timesteps)
            predicted_noise = unet(noisy_images, timesteps).sample

            # Compute loss (mean squared error between actual and predicted noise)
            loss = torch.nn.functional.mse_loss(predicted_noise, noise)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch_loss /= len(dataloader)
        losses[epoch] = epoch_loss
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

    return losses

In [7]:
# Where the trained weights are stored: the notebook checks this path to decide
# whether to train or to reload, and uses it as the save/load target.
# NOTE(review): the path is handed to save_pretrained/from_pretrained, which per
# the diffusers docs work on a directory, so despite the ".pth" suffix this
# presumably names a folder rather than a single weights file — confirm.
pre_trained_model = "weights/non_conditional_model.pth"

In [8]:
# Train from scratch and save the result, unless weights were already saved
# by a previous run, in which case they are simply reloaded.
if not exists(pre_trained_model):
    train(model, noise_scheduler, dataloader, num_epochs, lr=1e-4)
    model.save_pretrained(pre_trained_model)
else:
    model = UNet2DModel.from_pretrained(pre_trained_model)

# Leave the model in the same state whichever branch ran: on the preferred
# device and in evaluation mode (training leaves it in train mode, whereas a
# freshly loaded model has not been moved to the GPU yet).
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

100%|██████████| 894/894 [02:08<00:00,  6.98it/s]


Epoch 1/32, Loss: 0.14205249846481638


100%|██████████| 894/894 [02:05<00:00,  7.13it/s]


Epoch 2/32, Loss: 0.09112003538549213


100%|██████████| 894/894 [01:55<00:00,  7.71it/s]


Epoch 3/32, Loss: 0.07890236825667639


100%|██████████| 894/894 [01:53<00:00,  7.89it/s]


Epoch 4/32, Loss: 0.07194209593021096


100%|██████████| 894/894 [01:53<00:00,  7.89it/s]


Epoch 5/32, Loss: 0.06634500102708804


100%|██████████| 894/894 [01:53<00:00,  7.89it/s]


Epoch 6/32, Loss: 0.06253846915609618


100%|██████████| 894/894 [01:57<00:00,  7.59it/s]


Epoch 7/32, Loss: 0.05900279245853957


100%|██████████| 894/894 [01:57<00:00,  7.58it/s]


Epoch 8/32, Loss: 0.05604800331495886


100%|██████████| 894/894 [02:06<00:00,  7.07it/s]


Epoch 9/32, Loss: 0.05314782570830861


100%|██████████| 894/894 [02:04<00:00,  7.17it/s]


Epoch 10/32, Loss: 0.05155665113880264


100%|██████████| 894/894 [02:04<00:00,  7.17it/s]


Epoch 11/32, Loss: 0.04969527572601767


100%|██████████| 894/894 [02:04<00:00,  7.17it/s]


Epoch 12/32, Loss: 0.04728619077386672


100%|██████████| 894/894 [02:04<00:00,  7.20it/s]


Epoch 13/32, Loss: 0.04584672643909318


100%|██████████| 894/894 [01:50<00:00,  8.09it/s]


Epoch 14/32, Loss: 0.045034246625466234


100%|██████████| 894/894 [01:50<00:00,  8.11it/s]


Epoch 15/32, Loss: 0.04340523061127937


100%|██████████| 894/894 [01:59<00:00,  7.50it/s]


Epoch 16/32, Loss: 0.04184080003000279


100%|██████████| 894/894 [02:06<00:00,  7.06it/s]


Epoch 17/32, Loss: 0.04120134479955872


100%|██████████| 894/894 [02:05<00:00,  7.11it/s]


Epoch 18/32, Loss: 0.03976706568760093


100%|██████████| 894/894 [02:05<00:00,  7.11it/s]


Epoch 19/32, Loss: 0.038842404452913024


100%|██████████| 894/894 [02:05<00:00,  7.10it/s]


Epoch 20/32, Loss: 0.038221116021475536


100%|██████████| 894/894 [02:04<00:00,  7.20it/s]


Epoch 21/32, Loss: 0.037277157645530884


100%|██████████| 894/894 [01:50<00:00,  8.09it/s]


Epoch 22/32, Loss: 0.03640188445530882


100%|██████████| 894/894 [01:50<00:00,  8.11it/s]


Epoch 23/32, Loss: 0.03583769883827142


100%|██████████| 894/894 [01:50<00:00,  8.11it/s]


Epoch 24/32, Loss: 0.03517967219187376


100%|██████████| 894/894 [01:50<00:00,  8.10it/s]


Epoch 25/32, Loss: 0.03445411324959407


100%|██████████| 894/894 [01:50<00:00,  8.10it/s]


Epoch 26/32, Loss: 0.0339977684195343


100%|██████████| 894/894 [01:50<00:00,  8.10it/s]


Epoch 27/32, Loss: 0.032896325465133394


100%|██████████| 894/894 [01:50<00:00,  8.10it/s]


Epoch 28/32, Loss: 0.032666923649092384


100%|██████████| 894/894 [01:53<00:00,  7.91it/s]


Epoch 29/32, Loss: 0.03209584288640777


100%|██████████| 894/894 [02:02<00:00,  7.31it/s]


Epoch 30/32, Loss: 0.031808388721556204


100%|██████████| 894/894 [01:52<00:00,  7.98it/s]


Epoch 31/32, Loss: 0.03151194132881143


100%|██████████| 894/894 [01:55<00:00,  7.75it/s]

Epoch 32/32, Loss: 0.031045322552306673





# Try some inferences with our model

In [9]:
# Bundle the UNet and its noise scheduler into a DDPM sampling pipeline and
# place it on the GPU when one is available.
pipeline_device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = DDPMPipeline(scheduler=noise_scheduler, unet=model)
pipeline.to(pipeline_device)

DDPMPipeline {
  "_class_name": "DDPMPipeline",
  "_diffusers_version": "0.27.2",
  "scheduler": [
    "diffusers",
    "DDPMScheduler"
  ],
  "unet": [
    "diffusers",
    "UNet2DModel"
  ]
}

In [10]:
# Sample 16 images, denoising over as many steps as the scheduler was trained
# with (read from its config so the two values cannot drift apart).
generated_image = pipeline(
    batch_size=16,
    num_inference_steps=noise_scheduler.config.num_train_timesteps,
)

100%|██████████| 500/500 [00:05<00:00, 96.83it/s]


In [11]:
# Lay the 16 generated images out on a 4 x 4 grid and display the figure.
grid_side = 4
fig = plot_generated_images(generated_image.images, grid_side, grid_side)
fig.show()