## Train a diffusion model

(The notebook is copied from https://huggingface.co/docs/diffusers/en/tutorials/basic_training).

Unconditional image generation is a popular application of diffusion models that generates images that look like those in the dataset used for training. Typically, the best results are obtained from fine-tuning a pretrained model on a specific dataset. You can find many of these checkpoints on the [Hub](https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model), but if you can't find one you like, you can always train your own!

This tutorial will teach you how to train a [UNet2DModel](https://huggingface.co/docs/diffusers/v0.32.2/en/api/models/unet2d#diffusers.UNet2DModel) from scratch on a subset of the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset to generate your own butterflies.

This training tutorial is based on the [Training with Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook. For more details and context about diffusion models, such as how they work, check out that notebook.


Before you begin, make sure you have 🤗 Datasets installed to load and preprocess image datasets, and 🤗 Accelerate to simplify training on any number of GPUs. The following command will also install [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize training metrics (you can also use [Weights & Biases](https://docs.wandb.ai/) to track your training).

In [None]:
!pip install diffusers[training]

### Training configuration  

For convenience, create a `TrainingConfig` class containing the training hyperparameters (feel free to adjust them):

In [None]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    image_size = 128  # the generated image resolution
    train_batch_size = 16
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 50
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = 'fp16'   # 'no' for float32, 'fp16' for automatic mixed precision
    output_dir = 'ddpm-butterflies-128'  # the model name locally and on the HF Hub

    push_to_hub = False  # whether to upload the saved model to the HF Hub
    hub_model_id = '<your-username>/<my-awesome-model>'  # the name of the repository to create on the HF Hub
    hub_private_repo = None
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0

config = TrainingConfig()


### Load the dataset

You can easily load the Smithsonian Butterflies dataset with the 🤗 Datasets library:



In [None]:
from datasets import load_dataset

config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")



🤗 Datasets uses the [Image](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.Image) feature to automatically decode the image data and load it as a [PIL.Image](https://pillow.readthedocs.io/en/stable/reference/Image.html) which we can visualize:

In [None]:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]["image"]):
    axs[i].imshow(image)
    axs[i].set_axis_off()
fig.show()

### Preprocess data

The images are all different sizes, though, so you'll need to preprocess them first:
- `Resize` changes the image size to the one defined in `config.image_size`.
- `RandomHorizontalFlip` augments the dataset by randomly mirroring the images.
- `ToTensor` converts each PIL image into a tensor with pixel values in the [0, 1] range.
- `Normalize` rescales the pixel values into the [-1, 1] range, which is what the model expects.

In [None]:
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5]),
])
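The value arithmetic behind `ToTensor` followed by `Normalize([0.5], [0.5])` can be checked in isolation. The following is a small NumPy stand-in (not the torchvision code) that only reproduces the value scaling and ignores the HWC to CHW transpose that `ToTensor` also performs. The inverse, `(x + 1) * 127.5`, is the mapping the notebook uses later to turn a noisy tensor back into a viewable image.

```python
import numpy as np

def to_unit_range(pixels_uint8):
    # Mimics ToTensor's value scaling only: uint8 [0, 255] -> float [0, 1]
    return pixels_uint8.astype(np.float32) / 255.0

def normalize(x, mean=0.5, std=0.5):
    # Mimics Normalize([0.5], [0.5]): (x - mean) / std on every channel
    return (x - mean) / std

pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = normalize(to_unit_range(pixels))
print(scaled)  # 0 -> -1.0, 128 -> about 0.0039, 255 -> 1.0

# Inverse mapping from [-1, 1] back to [0, 255]
back = ((scaled + 1.0) * 127.5).round().astype(np.uint8)
print(back)
```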

Use Datasets' [set_transform](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.Dataset.set_transform) method to apply the `preprocess` function on the fly during training.

In [None]:
def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

dataset.set_transform(transform)
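Because the transform runs when an item is accessed rather than once up front, `RandomHorizontalFlip` draws a fresh flip every time an image is loaded, so each epoch sees slightly different data. The toy class below is a hypothetical stand-in that only illustrates this lazy, batch-style behavior; it is not how 🤗 Datasets is implemented. As with `set_transform`, the transform receives a dict of lists (a batch), even when a single row is indexed.

```python
class LazyTransformedList:
    """Hypothetical toy dataset that applies a transform only on access."""

    def __init__(self, rows, transform):
        self.rows = rows            # raw, untransformed records
        self.transform = transform  # applied lazily in __getitem__
        self.calls = 0              # counts how often the transform ran

    def __getitem__(self, idx):
        self.calls += 1
        # Wrap the single row as a batch of one: {"image": [value]}
        batch = {key: [value] for key, value in self.rows[idx].items()}
        out = self.transform(batch)
        # Unwrap the batch of one again
        return {key: values[0] for key, values in out.items()}

    def __len__(self):
        return len(self.rows)


def fake_transform(examples):
    # Hypothetical stand-in for the notebook's `transform`: "image" -> "images"
    return {"images": [img.upper() for img in examples["image"]]}


ds = LazyTransformedList([{"image": "a"}, {"image": "b"}], fake_transform)
print(ds.calls)          # 0: nothing has been transformed yet
print(ds[1]["images"])   # B, computed on access
print(ds.calls)          # 1
```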

Now you’re ready to wrap the dataset in a [DataLoader](https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader) for training!

In [None]:
import torch
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
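`len(train_dataloader)` is used later to size the learning rate schedule. With the default `drop_last=False`, it is the number of examples divided by the batch size, rounded up. A small sketch, assuming a dataset of 1,000 images purely for illustration (check `len(dataset)` for the real count):

```python
import math

def batches_per_epoch(num_examples, batch_size, drop_last=False):
    # drop_last=False keeps the final, smaller batch, hence the ceiling
    if drop_last:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)

# Assumption: 1,000 images; batch size and epochs taken from TrainingConfig
steps_per_epoch = batches_per_epoch(1000, batch_size=16)
total_steps = steps_per_epoch * 50  # config.num_epochs
print(steps_per_epoch, total_steps)  # 63 3150
```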

### Create a UNet2D Model

Models in 🤗 Diffusers are easily created from their model class with the parameters you want. For example, to create a `UNet2DModel` from scratch (with randomly initialized weights):

In [None]:
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per U-Net block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each U-Net block
    down_block_types=(
        'DownBlock2D',  # a regular ResNet downsampling block
        'DownBlock2D',
        'DownBlock2D',
        'DownBlock2D',
        'AttnDownBlock2D',  # a ResNet downsampling block with spatial self-attention
        'DownBlock2D',
    ),
    up_block_types=(
        'UpBlock2D',  # a regular ResNet upsampling block
        'AttnUpBlock2D',  # a ResNet upsampling block with spatial self-attention
        'UpBlock2D',
        'UpBlock2D',
        'UpBlock2D',
        'UpBlock2D',
    ),
)
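It can help to know the spatial size each down block operates at. The sketch below assumes that every down block except the last one halves the height and width, which is how `UNet2DModel` arranges its blocks to the best of my knowledge. Under that assumption, six blocks give five downsamplings, so `image_size` should be divisible by 2 ** 5 = 32 (a safe rule of thumb to keep the skip connections aligned), and the `AttnDownBlock2D` in this configuration would work at 8 x 8. `block_resolutions` is a hypothetical helper.

```python
def block_resolutions(image_size, num_blocks):
    """Hypothetical helper: resolution each down block sees, assuming all but the last halve H and W."""
    if image_size % 2 ** (num_blocks - 1) != 0:
        raise ValueError("image_size should be divisible by 2 ** (num_blocks - 1)")
    return [image_size // 2 ** i for i in range(num_blocks)]

down_blocks = ["DownBlock2D"] * 4 + ["AttnDownBlock2D", "DownBlock2D"]
for block, res in zip(down_blocks, block_resolutions(128, len(down_blocks))):
    print(f"{block:16s} {res}x{res}")
```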

It is often a good idea to quickly check that the sample image shape matches the model output shape:

In [None]:
sample_image = dataset[0]['images'].unsqueeze(0)
print("Input shape:", sample_image.shape)

print("Output shape:", model(sample_image, timestep=0).sample.shape)

### Create a scheduler
You'll need a scheduler to add noise to the images. The scheduler behaves differently depending on whether you're using the model for training or inference. During inference, the scheduler generates an image from noise. During training, the scheduler takes a model output (or a sample) from a specific point in the diffusion process and applies noise to the image according to a noise schedule and an update rule.

Let's take a look at the [DDPMScheduler](https://huggingface.co/docs/diffusers/v0.32.2/en/api/schedulers/ddpm#diffusers.DDPMScheduler) and use the `add_noise` method to add some random noise to the sample image from before: 

In [None]:
import torch

from PIL import Image
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
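noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)

# Map values from [-1, 1] back to [0, 255]; clamp first, since the added noise can push pixels outside that range
Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).clamp(0, 255).type(torch.uint8).numpy()[0])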
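For intuition about what `add_noise` computes, here is a NumPy sketch of the closed-form DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, where alpha_bar_t is the cumulative product of (1 - beta). It assumes the `DDPMScheduler` defaults of a linear beta schedule from 1e-4 to 0.02 over 1000 steps and uses float64 for simplicity; it is an illustration, not the library code. Printing the signal weight sqrt(alpha_bar_t) shows why `timesteps = 50` still leaves a recognizable butterfly while late timesteps are almost pure noise.

```python
import numpy as np

# Assumption: DDPMScheduler defaults (linear betas, 1e-4 to 0.02, 1000 steps)
num_train_timesteps = 1000
betas = np.linspace(1e-4, 0.02, num_train_timesteps, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)  # alpha_bar_t for every timestep

def add_noise(x0, noise, t):
    # Closed-form forward diffusion: sqrt(a) * x0 + sqrt(1 - a) * noise
    a = alphas_cumprod[t]
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(3, 8, 8))  # stand-in for a normalized image
noise = rng.standard_normal(x0.shape)

for t in (0, 50, 500, 999):
    # The weight on the clean image shrinks as t grows
    print(t, round(float(np.sqrt(alphas_cumprod[t])), 4))

x_50 = add_noise(x0, noise, 50)
```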

The training objective of the model is to predict the noise added to the image. The loss at this step can be calculated as: 

In [None]:
import torch.nn.functional as F

noise_pred = model(noisy_image, timesteps).sample
loss = F.mse_loss(noise_pred, noise)
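Predicting the noise is a sufficient objective because, given the noisy image, the timestep, and the noise, the clean image can be recovered by inverting the forward process. The NumPy sketch below assumes the same default linear beta schedule and shows that a perfect noise prediction gives an MSE of 0 and reconstructs x_0 exactly; `predict_x0` and `mse` are hypothetical helpers, with `mse` computing the same quantity as `F.mse_loss` under its default `reduction="mean"`.

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)  # assumed DDPMScheduler default schedule
alphas_cumprod = np.cumprod(1.0 - betas)

def predict_x0(x_t, noise_pred, t):
    # Solve x_t = sqrt(a) * x0 + sqrt(1 - a) * eps for x0, using a noise estimate
    a = alphas_cumprod[t]
    return (x_t - np.sqrt(1.0 - a) * noise_pred) / np.sqrt(a)

def mse(pred, target):
    # Mean squared error over all elements
    return float(np.mean((pred - target) ** 2))

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(3, 8, 8))
noise = rng.standard_normal(x0.shape)
t = 50
a = alphas_cumprod[t]
x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise

print(mse(noise, noise))                           # 0.0 for a perfect prediction
print(np.allclose(predict_x0(x_t, noise, t), x0))  # True
```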


### Train the model

First, you'll need an optimizer and a learning rate scheduler:

In [None]:
from diffusers.optimization import get_cosine_schedule_with_warmup
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)
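To see the shape of this schedule, the pure-Python sketch below re-implements the learning rate multiplier as I understand `get_cosine_schedule_with_warmup` to define it: a linear warmup from 0 to 1 over `num_warmup_steps`, then a half-cosine decay to 0 by `num_training_steps` (with the default `num_cycles=0.5`). It is written for illustration, not copied from the library, and the total of 3,150 steps assumes 63 batches per epoch for 50 epochs.

```python
import math

def cosine_with_warmup_factor(step, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """LR multiplier: linear warmup 0 -> 1, then cosine decay 1 -> 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

# learning_rate and lr_warmup_steps from TrainingConfig; total steps is an assumption
base_lr, warmup, total = 1e-4, 500, 3150
for step in (0, 250, 500, 1825, 3150):
    print(step, base_lr * cosine_with_warmup_factor(step, warmup, total))
```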

Then, you'll need a way to evaluate the model. For evaluation, you can use the [DDPMPipeline](https://huggingface.co/docs/diffusers/v0.32.2/en/api/pipelines/ddpm#diffusers.DDPMPipeline) to generate a batch of sample images and save it as a grid:

In [None]:
from diffusers import DDPMPipeline
from diffusers.utils import make_image_grid
import os

def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.Generator(device='cpu').manual_seed(config.seed),  # use a separate torch generator to avoid rewinding the random state of the main training loop
    ).images

    # make a grid out of the images
    image_grid = make_image_grid(images, rows=4, cols=4)

    # save the images
    test_dir = os.path.join(config.output_dir, "test_images")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")
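`make_image_grid` tiles the images row by row and expects `rows * cols` to equal the number of images, which is why `eval_batch_size = 16` pairs with a 4 x 4 grid. The sketch below is a hypothetical NumPy helper, `image_grid`, that follows the same row-major layout on plain arrays; it is a stand-in for illustration, not the diffusers function, which works on PIL images.

```python
import numpy as np

def image_grid(images, rows, cols):
    """Hypothetical helper: tile equally sized HxWxC arrays in row-major order."""
    if len(images) != rows * cols:
        raise ValueError(f"expected {rows * cols} images, got {len(images)}")
    h, w, c = images[0].shape
    grid = np.zeros((rows * h, cols * w, c), dtype=images[0].dtype)
    for i, img in enumerate(images):
        r, col = divmod(i, cols)  # row index, column index
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = img
    return grid

# 16 tiny 2x2 RGB "images", each filled with its own index
images = [np.full((2, 2, 3), i, dtype=np.uint8) for i in range(16)]
grid = image_grid(images, rows=4, cols=4)
print(grid.shape)  # (8, 8, 3)
```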

Now you can wrap all these components together in a training loop with `Accelerate` for easy `TensorBoard` logging, gradient accumulation, and mixed precision training.

In [None]:
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from pathlib import Path
import os
import torch
import torch.nn.functional as F

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    # Initialize accelerator and TensorBoard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with='tensorboard',
        project_dir=os.path.join(config.output_dir, "logs"),
    )

    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        if config.push_to_hub:
            repo_id = create_repo(
                repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
            ).repo_id
        accelerator.init_trackers('train_example')

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_images = batch["images"]
            # Sample noise to add to the images
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
                dtype=torch.int64
            )

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {'loss': loss.detach().item(), 'lr': lr_scheduler.get_last_lr()[0], 'step': global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch, optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                # Save locally first so that there is a model in output_dir to upload
                pipeline.save_pretrained(config.output_dir)
                if config.push_to_hub:
                    upload_folder(
                        repo_id=repo_id,
                        folder_path=config.output_dir,
                        commit_message=f"Epoch {epoch}",
                        ignore_patterns=["step_*", "epoch_*"],
                    )
                



In [None]:
from accelerate import notebook_launcher
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
notebook_launcher(train_loop, args, num_processes=1)
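Two settings interact here: `num_processes` in `notebook_launcher` and `gradient_accumulation_steps` in the config. Each optimizer update sees `train_batch_size * gradient_accumulation_steps * num_processes` samples, and inside `accelerator.accumulate(model)` gradients are only synchronized every `gradient_accumulation_steps` micro-batches (and, by default, at the end of the dataloader). The helpers below are a simplified, hypothetical model of that bookkeeping for a single epoch, not Accelerate's implementation.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_processes):
    # Samples that contribute to each optimizer update
    return per_device_batch * grad_accum_steps * num_processes

def sync_steps(num_batches, grad_accum_steps):
    """Hypothetical helper: micro-batch indices after which gradients sync (simplified)."""
    return [
        i for i in range(num_batches)
        if (i + 1) % grad_accum_steps == 0 or i == num_batches - 1
    ]

print(effective_batch_size(16, 1, 1))  # 16 with the notebook's settings
print(effective_batch_size(16, 4, 2))  # 128
print(sync_steps(10, 4))               # [3, 7, 9]
```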


Once training is complete, take a look at the final images generated by your diffusion model:

In [None]:
import glob
from PIL import Image

# evaluate() saves its grids under <output_dir>/test_images
sample_images = sorted(glob.glob(f"{config.output_dir}/test_images/*.png"))
Image.open(sample_images[-1])