
# DreamBooth with Stable Diffusion

[DreamBooth] is a personalization technique, introduced by Google Research and Boston University in 2022, that fine-tunes a full text-to-image diffusion model (such as Imagen or Stable Diffusion) on just a handful of reference images (typically 3–5). By pairing a rare identifier token with the subject's class noun (for example, “a photo of \[XYZ] dog”), the model learns to generate that specific subject faithfully in novel contexts, poses, and lighting conditions. This enables high-fidelity, subject-driven image synthesis from minimal training data.
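
To make the identifier idea concrete, here is a minimal sketch of how the instance prompt and the class prompt relate. The helper `make_prompts` and the identifier `sks` are illustrative assumptions (`sks` is a common choice in community examples), not part of this notebook.

```python
# Hypothetical helper illustrating DreamBooth-style prompt pairs.
# Assumption: "sks" serves as the rare identifier; any rarely used token works similarly.

def make_prompts(identifier: str, class_noun: str) -> tuple[str, str]:
    """Return (instance_prompt, class_prompt) for one subject."""
    # The instance prompt binds the rare identifier to the class noun,
    # so the subject's appearance gets attached to "sks dog".
    instance_prompt = f"a photo of {identifier} {class_noun}"
    # The class prompt describes generic members of the same class; images
    # generated from it are used for prior preservation.
    class_prompt = f"a photo of a {class_noun}"
    return instance_prompt, class_prompt


instance_prompt, class_prompt = make_prompts("sks", "dog")
print(instance_prompt)  # a photo of sks dog
print(class_prompt)     # a photo of a dog
```

Keeping the class noun in both prompts lets the model reuse what it already knows about dogs and learn only what makes this particular dog different.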

[DreamBooth]: https://en.wikipedia.org/wiki/DreamBooth "DreamBooth"

## What We're Going to Do

1. Install dependencies
2. Upload your training images (3-10 high-quality images of your subject)
3. Train a DreamBooth model on your subject
4. Generate new images of your subject


In [None]:
# @title Install Dependencies {"display-mode":"form"}

# @markdown **Run this first.** This cell installs the required packages and imports the Python libraries used for DreamBooth training and image generation: standard-library helpers for files, math, and randomness; `numpy`; PyTorch for deep learning; Hugging Face's `diffusers`, `transformers`, and `accelerate` for Stable Diffusion, text encoding, and training utilities such as gradient checkpointing and learning-rate scheduling; Pillow and `torchvision` for image handling; and `tqdm` for progress bars. Together these set up everything needed to prepare datasets, train a DreamBooth model, and generate images from text prompts.

%pip install -q git+https://github.com/huggingface/diffusers
%pip install -q accelerate tensorboard transformers ftfy bitsandbytes
%pip install xformers -q --index-url https://download.pytorch.org/whl/cu124

# ✅ Verify
import torch, torchvision, torchaudio, xformers
from google.colab import (files, output)

output.clear()

from pathlib import Path
import argparse # Used for parsing command-line arguments (though used here via Namespace)
import itertools # Provides tools for creating iterators for efficient looping
import math # Provides mathematical functions
import os # Provides a way to interact with the operating system (e.g., file paths)
from contextlib import nullcontext # A context manager that does nothing, useful as a placeholder
import gc # Garbage collection so we don't run out of memory.

import random # Used for generating random numbers, potentially for seeding
import numpy as np # A fundamental package for scientific computing with Python, used for numerical operations
import torch # The core PyTorch library for deep learning
import torch.nn.functional as F # Provides a collection of functions for neural networks
import torch.utils.checkpoint # Used for gradient checkpointing to save memory during training
from torch.utils.data import Dataset # An abstract class representing a dataset, used for creating custom datasets

import PIL # Pillow library, used for image manipulation
from accelerate import Accelerator # Hugging Face library for simplifying distributed training
from accelerate.logging import get_logger # Function to get a logger for logging training progress
from accelerate.utils import set_seed # Function to set the random seed for reproducibility
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel # Diffusers library components for diffusion models
from diffusers.optimization import get_scheduler # Function to get a learning rate scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker # Component for checking generated images for safety
from PIL import Image # Image class from Pillow
from torchvision import transforms # Provides common image transformations
from tqdm.auto import tqdm # A library to display progress bars
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer # Transformers components for the CLIP text encoder, tokenizer, and image processor

print("👨‍🎤 Environment ready to rock.")

## Upload Your Images

Next up, we're going to upload some pictures of our subject. It can be a person, a dog, your favorite sock puppet. I don't care. Choose 3–10 images of your subject to upload. **Fun Fact**: Google Colab has a built-in way for you to upload files directly into your notebook's storage.

This code snippet creates a dedicated folder for your training images and checks whether it is empty. If it is, a file picker opens so you can upload pictures from your computer; otherwise the upload is skipped. Status messages confirm each step, so you know your images are in place for DreamBooth training.
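
The upload cell below treats any entry in the folder (via `os.listdir`) as a sign that images are already there. A slightly stricter variant, sketched here under the assumption that only common image suffixes matter, counts only image files, so a stray hidden folder or text file doesn't suppress the upload prompt. `list_images` and `needs_upload` are hypothetical helpers, not part of the notebook.

```python
from pathlib import Path

# Assumption: these suffixes cover the formats you plan to upload.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def list_images(folder: str) -> list[Path]:
    """Return image files in `folder`, ignoring subfolders and other files."""
    path = Path(folder)
    if not path.is_dir():
        return []
    return sorted(
        p for p in path.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


def needs_upload(folder: str) -> bool:
    """True when the folder holds no images yet, so the upload prompt should open."""
    return len(list_images(folder)) == 0
```

In the notebook you could replace `if not os.listdir(instance_data_dir):` with `if needs_upload(instance_data_dir):` to get this behavior.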


In [None]:
# Name your concept (e.g., your dog's name)
concept_name = "my-dog" #@param {"type": "string"}
instance_data_dir = os.path.join("/content/", concept_name)

# Create the directory
os.makedirs(instance_data_dir, exist_ok=True)

print(f"Directory '{instance_data_dir}' created for your images. 💪")

# Check if the directory is empty before prompting for upload.
# If you already have images, I'm not going to make you keep re-uploading them.
if not os.listdir(instance_data_dir):
  print(f"Uploading images to {instance_data_dir}…")
  uploaded = files.upload(instance_data_dir)
else:
  print(f"Directory '{instance_data_dir}' is not empty. Skipping image upload.")

output.clear()
print(f"🩻 Images ready in {instance_data_dir}.")

Next, we'll define two custom PyTorch Dataset classes—`DreamBoothDataset` and `PromptDataset`—which handle the data preparation for DreamBooth training.

`DreamBoothDataset` loads and preprocesses your subject images (and optionally, class images for prior preservation), applies resizing and normalization transformations, and tokenizes prompts for use in model training. It ensures a balanced dataset when both instance and class images are used and returns processed image–prompt pairs for each training step.

`PromptDataset` is a simpler utility that yields the same text prompt a fixed number of times, along with each sample's index; we use it below to generate class images in batches.

Together, these classes make it easy to feed properly formatted images and prompts into the DreamBooth model.
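
The balancing trick in `DreamBoothDataset` is easy to miss: the dataset length is the larger of the two image counts, and `__getitem__` wraps each list with the modulo operator. The toy class below, a hypothetical `PairedIndexer` that works on plain strings instead of images and prompts, isolates just that indexing.

```python
class PairedIndexer:
    """Toy stand-in for DreamBoothDataset's indexing (no images or tokenizer).

    The length is the larger of the two lists; each index wraps around the
    shorter list with the modulo operator, so the longer list is traversed
    once per epoch while the shorter one repeats.
    """

    def __init__(self, instance_items, class_items):
        self.instance_items = list(instance_items)
        self.class_items = list(class_items)

    def __len__(self):
        return max(len(self.instance_items), len(self.class_items))

    def __getitem__(self, index):
        return (
            self.instance_items[index % len(self.instance_items)],
            self.class_items[index % len(self.class_items)],
        )


pairs = PairedIndexer(["me_1", "me_2", "me_3"], [f"dog_{i}" for i in range(5)])
print(len(pairs))  # 5
print(pairs[3])    # ('me_1', 'dog_3')
```

With 3 subject photos and 5 class images, each epoch shows every class image once and cycles the subject photos, so neither set is dropped.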


In [None]:
#@title Setup the Classes {"display-mode":"form"}

# Import necessary libraries and modules
from pathlib import Path # Used for working with file paths in a platform-independent way
from torchvision import transforms # Provides common image transformations

# Define a custom dataset class for DreamBooth training
class DreamBoothDataset(Dataset):
    # Initialize the dataset
    def __init__(
        self,
        instance_data_dir, # Directory containing the instance images (images of your subject)
        instance_prompt, # Prompt describing the instance images (e.g., "a photo of sks dog")
        tokenizer, # Tokenizer for processing text prompts
        class_data_root=None, # Optional directory for class images (for prior preservation)
        class_prompt=None, # Optional prompt describing the class images
        size=512, # The resolution to resize images to (e.g., 512x512)
        center_crop=False, # Whether to center crop images after resizing
    ):
        self.size = size # Store the image size
        self.center_crop = center_crop # Store the center crop setting
        self.tokenizer = tokenizer # Store the tokenizer

        # Set up paths for instance images
        self.instance_data_dir = Path(instance_data_dir)
        if not self.instance_data_dir.exists():
            # Raise an error if the instance data directory doesn't exist
            raise ValueError("Instance images root doesn't exist.")

        # Get a list of all image paths in the instance directory
        self.instance_images_path = list(Path(instance_data_dir).iterdir())
        self.num_instance_images = len(self.instance_images_path) # Count the number of instance images
        self.instance_prompt = instance_prompt # Store the instance prompt
        self._length = self.num_instance_images # Initial dataset length based on instance images

        # Set up paths for class images if prior preservation is used
        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True) # Create the class image directory if it doesn't exist
            self.class_images_path = list(Path(class_data_root).iterdir()) # Get a list of class image paths
            self.num_class_images = len(self.class_images_path) # Count the number of class images
            # Update dataset length to be the maximum of instance and class images for balancing
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt # Store the class prompt
        else:
            self.class_data_root = None # Set class data root to None if not used

        # Define the image transformations to apply to images
        self.image_transforms = transforms.Compose(
            [
                # Resize the image to the specified size using bilinear interpolation
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                # Apply center crop or random crop based on the setting
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(), # Convert the image to a PyTorch tensor
                # Normalize the image with mean 0.5 and standard deviation 0.5
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    # Return the length of the dataset
    def __len__(self):
        return self._length

    # Get an item from the dataset at the given index
    def __getitem__(self, index):
        example = {} # Initialize a dictionary to store the example data
        # Load and process the instance image
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB") # Convert to RGB if not already
        example["instance_images"] = self.image_transforms(instance_image) # Apply transformations

        # Tokenize the instance prompt
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids # Get the input IDs from the tokenized prompt

        # If prior preservation is used, load and process a class image and tokenize the class prompt
        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB") # Convert to RGB if not already
            example["class_images"] = self.image_transforms(class_image) # Apply transformations
            # Tokenize the class prompt
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids # Get the input IDs from the tokenized prompt

        return example # Return the example dictionary


# Define a simple dataset for generating prompts
class PromptDataset(Dataset):
    # Initialize the dataset
    def __init__(self, prompt, num_samples):
        self.prompt = prompt # Store the prompt
        self.num_samples = num_samples # Store the number of samples to generate

    # Return the length of the dataset
    def __len__(self):
        return self.num_samples

    # Get an item from the dataset at the given index
    def __getitem__(self, index):
        example = {} # Initialize a dictionary to store the example data
        example["prompt"] = self.prompt # Store the prompt in the example
        example["index"] = index # Store the index in the example
        return example # Return the example dictionary

## Generating the Class Images

This code block ensures you have enough "class images" (generic images of your subject’s category, like other cats or dogs) to help prevent overfitting during DreamBooth training. If the target folder has fewer images than required, it loads a pre-trained Stable Diffusion model and automatically generates the missing images using a class prompt you provide. These images are saved to a specified directory and later used alongside your subject’s photos to maintain the model’s ability to generate diverse examples from the broader class, reducing the risk of catastrophic forgetting.
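
The top-up logic in the cell below boils down to a little arithmetic: generate only the shortfall, in batches of `sample_batch_size`, numbering new files after the existing ones. Here is a standalone sketch of that plan; `plan_class_images` is a hypothetical helper, and the counts in the example (5 existing images, a target of 12) are made up for illustration.

```python
import math


def plan_class_images(target: int, existing: int, batch_size: int):
    """Return (num_new, num_batches, filenames) for topping up a class folder.

    Only the shortfall is generated, in batches of `batch_size`, and new
    files are numbered after the existing ones.
    """
    num_new = max(target - existing, 0)
    num_batches = math.ceil(num_new / batch_size) if num_new else 0
    filenames = [f"{existing + i}.jpg" for i in range(num_new)]
    return num_new, num_batches, filenames


num_new, num_batches, names = plan_class_images(target=12, existing=5, batch_size=2)
print(num_new, num_batches)  # 7 4
print(names[0], names[-1])   # 5.jpg 11.jpg
```

With 7 images to create and a batch size of 2, the last batch holds a single prompt, because PyTorch's `DataLoader` keeps an incomplete final batch by default.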

In [None]:
#@title Generate Class Images {"display-mode":"form"}

#@markdown `pretrained_model_name_or_path` selects which Stable Diffusion checkpoint to fine-tune.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2" #@param ["stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base", "CompVis/stable-diffusion-v1-4", "stable-diffusion-v1-5/stable-diffusion-v1-5"] {allow-input: true}

#@markdown If `prior_preservation_class_folder` holds fewer than `num_class_images` images, the missing ones are generated from the class prompt. You can also fill this folder yourself with images of items in the same class as your concept (but not images of the concept itself).
prior_preservation_class_folder = "./class_images" #@param {type:"string"}
prior_preservation_class_prompt = "a photo of a pit bull" #@param {type:"string"}

#@markdown `prior_loss_weight` sets how strongly the prior-preservation loss on class images counts relative to the loss on your subject's images.
prior_loss_weight = 0.5 #@param {type: "number"}

# Set the desired number of class images to generate
num_class_images = 12

# Set the batch size for generating class images
sample_batch_size = 2


# Directory to save class images
class_images_dir = Path(prior_preservation_class_folder)

# Create the directory if it doesn't exist
class_images_dir.mkdir(parents=True, exist_ok=True)

# Count existing class images
cur_class_images = len(list(class_images_dir.iterdir()))

# Generate class images if the current number is less than the desired number
if cur_class_images < num_class_images:
    # Load the Stable Diffusion pipeline for image generation
    # Load in half precision to reduce GPU memory use
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path, torch_dtype=torch.float16
    ).to("cuda")
    pipeline.enable_attention_slicing()
    pipeline.set_progress_bar_config(disable=True)

    # Calculate the number of new images to generate
    num_new_images = num_class_images - cur_class_images
    print(f"Number of class images to sample: {num_new_images}.")

    # Create a dataset and dataloader for generating images
    sample_dataset = PromptDataset(prior_preservation_class_prompt, num_new_images)
    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)

    # Generate and save the images
    for example in tqdm(sample_dataloader, desc="Generating class images"):
        images = pipeline(example["prompt"]).images

        for i, image in enumerate(images):
            # Save the image to the class images directory
            image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")

    # Clean up the pipeline and free memory
    del pipeline
    gc.collect()
    torch.cuda.empty_cache()

# Update the number of class images after generation
num_class_images = len(list(class_images_dir.iterdir()))

# Print the total number of class images
print(f"Total number of class images: {num_class_images}")

Next, we load the key pre-trained components of Stable Diffusion that DreamBooth training relies on:

- the **CLIP tokenizer and text encoder** to process and embed text prompts,
- the **VAE (Variational Autoencoder)** to map images to and from a compact latent space for efficient training and generation,
- and the **U-Net model** that performs the core denoising steps, turning random latent noise into an image latent guided by the text embedding.

These modules form the backbone of Stable Diffusion. In this notebook, DreamBooth fine-tunes only the U-Net by default (set `train_text_encoder=True` to also tune the text encoder), while the other components condition generation on your prompt and reconstruct images from latents.
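
To see what "embed text prompts" means in terms of tensors, the sketch below computes the shape of the text embeddings the U-Net is conditioned on. It assumes CLIP's 77-token context length and the hidden sizes of the standard text encoders (768 for the CLIP ViT-L/14 encoder in Stable Diffusion 1.x, 1024 for the OpenCLIP ViT-H/14 encoder in 2.x); in the notebook you can confirm them with `tokenizer.model_max_length` and `text_encoder.config.hidden_size`. The helper and its family labels are hypothetical.

```python
# Assumed values: CLIP's context length and the text-encoder hidden sizes
# for Stable Diffusion 1.x and 2.x.
CONTEXT_LENGTH = 77
HIDDEN_SIZE = {"sd1": 768, "sd2": 1024}


def text_embedding_shape(batch_size: int, family: str) -> tuple[int, int, int]:
    """Shape of the encoder_hidden_states tensor fed to the U-Net's cross-attention."""
    # Every prompt is padded to the full context length, and each token
    # becomes one hidden-size vector.
    return (batch_size, CONTEXT_LENGTH, HIDDEN_SIZE[family])


print(text_embedding_shape(4, "sd2"))  # (4, 77, 1024)
```

For example, with `train_batch_size=2` and prior preservation enabled, each training batch holds 4 prompts, so on Stable Diffusion 2 the U-Net receives a (4, 77, 1024) tensor.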

## U-Net

The **U-Net model** is the part of the system that **actually shapes the noise in latent space into a meaningful image** based on your text prompt.

Here’s how it works:

1. **Starting Point – Random Noise:** Generation begins with pure Gaussian noise in latent space (a random tensor in the VAE's compressed representation rather than actual pixels).
2. **U-Net’s Job – Denoising:** At each step, the U-Net looks at the noisy latent and predicts the noise it contains (or, for v-prediction checkpoints such as `stabilityai/stable-diffusion-2`, a related "velocity" target). The scheduler uses that prediction to remove some of the noise, gradually sculpting an image that matches your prompt.
3. **Guidance:** Through cross-attention layers, it uses the text encoder's output (your prompt turned into numbers) to decide how the image should evolve at each step.

It’s called **U-Net** because of its **U-shaped architecture**:

* It has a **downsampling path** (analyzes the noisy image at multiple scales, like zooming in and out).
* A **bottleneck** (understands global context).
* An **upsampling path** (rebuilds the image, refining it step by step).

Imagine sculpting a statue from a rough block of stone. The U-Net is the sculptor — it removes the random noise bit by bit until a detailed, prompt-matching image emerges in latent space.
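
The denoising loop can be sketched in plain Python. Everything here is an assumption made for illustration: the "latent" is three numbers, the noise schedule is the linear beta schedule from the original DDPM paper, the update is DDIM-style with no added noise (eta = 0), and `oracle_noise_prediction` is a hypothetical perfect predictor that cheats by knowing the clean sample, standing in for the trained U-Net.

```python
import math
import random

# Toy setup: linear beta schedule over T = 1000 steps.
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]

# alpha_bars[t] is the cumulative product of (1 - beta);
# sqrt(alpha_bars[t]) scales the clean signal at step t.
alpha_bars = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alpha_bars.append(running)


def oracle_noise_prediction(x_t, t, x0):
    """Stand-in for a trained U-Net: returns exactly the noise contained in x_t."""
    a = alpha_bars[t]
    return [(xt - math.sqrt(a) * c) / math.sqrt(1.0 - a) for xt, c in zip(x_t, x0)]


def denoise(x_start, x0, num_steps=50):
    """Walk from t = T - 1 down to a clean sample, removing predicted noise each step."""
    timesteps = list(range(T - 1, -1, -(T // num_steps)))
    x = list(x_start)
    for i, t in enumerate(timesteps):
        eps = oracle_noise_prediction(x, t, x0)
        a_t = alpha_bars[t]
        # Noise level of the next (less noisy) step; 1.0 means fully clean.
        a_prev = alpha_bars[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        # Estimate the clean sample, then re-noise it to the next noise level.
        x0_hat = [(xi - math.sqrt(1.0 - a_t) * e) / math.sqrt(a_t) for xi, e in zip(x, eps)]
        x = [math.sqrt(a_prev) * c + math.sqrt(1.0 - a_prev) * e for c, e in zip(x0_hat, eps)]
    return x


# Forward process: bury a clean 3-value "latent" in noise at the last timestep.
random.seed(0)
clean = [0.5, -1.0, 2.0]
noise = [random.gauss(0.0, 1.0) for _ in clean]
a_T = alpha_bars[-1]
x_T = [math.sqrt(a_T) * c + math.sqrt(1.0 - a_T) * n for c, n in zip(clean, noise)]

restored = denoise(x_T, clean)
print([round(v, 6) for v in restored])  # [0.5, -1.0, 2.0]
```

With a perfect predictor the very first estimate of the clean sample is already exact; a real U-Net's estimates are imperfect, which is why generation refines them over many steps. Stable Diffusion's own schedulers use a different ("scaled linear") beta schedule.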

## Variational Autoencoder

A **VAE (Variational Autoencoder)** is the component responsible for **translating images to and from latent space**:

* **Encoder:** Takes a high-resolution image and compresses it into a **latent representation** (a much smaller, information-rich version of the image).
* **Decoder:** Takes that latent representation and reconstructs it back into a full-resolution image that humans can see.

The reason this matters is efficiency:

* Instead of doing heavy computations on millions of pixels, Stable Diffusion works in the compressed **latent space**, where generating an image is faster and requires less GPU memory.
* The VAE ensures that this compression and decompression keep the image realistic and detailed, even though it's heavily reduced internally.

Think of the VAE as a **translator between human-readable images and the “secret shorthand language” (latent space)** that Stable Diffusion uses to dream up pictures.
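
The efficiency argument is easy to quantify. The sketch below assumes the standard Stable Diffusion VAE layout (8× downsampling per side and 4 latent channels, which holds for SD 1.x and 2.x) and compares pixel values to latent values; the helper names are hypothetical.

```python
# Assumption: standard Stable Diffusion VAE layout.
DOWNSAMPLE = 8
LATENT_CHANNELS = 4


def latent_shape(height: int, width: int) -> tuple[int, int, int]:
    """(channels, height, width) of the latent for an RGB image of the given size."""
    if height % DOWNSAMPLE or width % DOWNSAMPLE:
        raise ValueError("image sides must be multiples of 8")
    return (LATENT_CHANNELS, height // DOWNSAMPLE, width // DOWNSAMPLE)


def compression_ratio(height: int, width: int) -> float:
    """How many RGB pixel values correspond to each latent value."""
    c, h, w = latent_shape(height, width)
    return (3 * height * width) / (c * h * w)


print(latent_shape(512, 512))       # (4, 64, 64)
print(3 * 512 * 512, 4 * 64 * 64)   # 786432 16384
print(compression_ratio(768, 768))  # 48.0
```

The ratio is 48 at any resolution, so the 768×768 images used with `stabilityai/stable-diffusion-2` become 96×96×4 latents.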



In [None]:
# @title Load the Model

# Load the text encoder model from the pretrained model path and subfolder
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
)

# Load the variational autoencoder (VAE) model from the pretrained model path and subfolder
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)

# Load the U-Net model from the pretrained model path and subfolder
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet"
)

# Load the CLIP tokenizer from the pretrained model path and subfolder
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
)

Below, we define a **Namespace object** (`args`) that holds all the configuration settings needed to fine-tune Stable Diffusion with DreamBooth. It specifies key parameters such as which pre-trained model to start from, where to find your subject images and class images, the target image resolution, learning rate, number of training steps, batch sizes, and optimization details (like mixed precision, gradient accumulation, and 8-bit Adam). It also includes prior preservation options to prevent overfitting, seed values for reproducibility, and the output directory where checkpoints and the final trained model will be saved. These settings act as the blueprint for how your DreamBooth training session will run.
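
Because `max_train_steps` counts optimizer updates rather than epochs, it helps to see how the batch settings translate into epochs. The sketch below mirrors the arithmetic `training_function` performs for `num_update_steps_per_epoch` and `num_train_epochs`. The helper name is hypothetical, and the example assumes a single GPU (as on Colab), 5 uploaded subject images, and the 12 class images generated earlier.

```python
import math


def training_schedule(num_instance, num_class, train_batch_size,
                      grad_accum, max_train_steps, with_prior_preservation=True):
    """Step/epoch arithmetic for a single training process."""
    # DreamBoothDataset's length: the larger image set when class images are used.
    dataset_len = max(num_instance, num_class) if with_prior_preservation else num_instance
    batches_per_epoch = math.ceil(dataset_len / train_batch_size)
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
    num_epochs = math.ceil(max_train_steps / updates_per_epoch)
    # Subject images per optimizer update, plus as many class images with prior preservation.
    instance_per_update = train_batch_size * grad_accum
    images_per_update = instance_per_update * (2 if with_prior_preservation else 1)
    return {
        "dataset_len": dataset_len,
        "updates_per_epoch": updates_per_epoch,
        "num_epochs": num_epochs,
        "instance_per_update": instance_per_update,
        "images_per_update": images_per_update,
    }


schedule = training_schedule(num_instance=5, num_class=12, train_batch_size=2,
                             grad_accum=2, max_train_steps=300)
print(schedule)
# {'dataset_len': 12, 'updates_per_epoch': 3, 'num_epochs': 100, 'instance_per_update': 4, 'images_per_update': 8}
```

Under these assumptions, 300 steps means about 100 passes over the data, and each optimizer update averages gradients over 8 images (4 subject, 4 class).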


In [None]:
# @title Set Up All of the Training Parameters {"display-mode":"form"}

from argparse import Namespace
args = Namespace(
    # The name or path of the pretrained Stable Diffusion model to use.
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    # The resolution for training and image generation.
    resolution=vae.sample_size,
    # Whether to center crop the images after resizing (otherwise a random crop is used).
    center_crop=True,
    # Whether to train the text encoder. Training the text encoder can improve results but requires more memory.
    train_text_encoder=False,
    # The directory containing the instance images (images of your subject).
    instance_data_dir=instance_data_dir,
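    # The prompt that describes your subject. This uses the concept name as-is; pairing a rare
    # identifier with the class noun (e.g., "a photo of sks dog") usually binds the subject better.
    instance_prompt=concept_name,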
    # The learning rate for the optimizer.
    learning_rate=5e-06,
    # The total number of training steps.
    max_train_steps=300,
    # Save the model checkpoint every N steps.
    save_steps=50,
    # Instance images per batch. With prior preservation, collate_fn appends as many class images, doubling what the U-Net sees.
    train_batch_size=2,
    # Number of batches whose gradients are accumulated before each optimizer update.
    gradient_accumulation_steps=2,
    # Maximum gradient norm for gradient clipping.
    max_grad_norm=1.0,
    # Whether to use mixed precision training ("fp16" or "bf16"). "fp16" is recommended for faster training and lower memory usage.
    mixed_precision="fp16",
    # Whether to use gradient checkpointing to save memory.
    gradient_checkpointing=True,
    # Whether to use 8-bit Adam optimizer from bitsandbytes for lower memory usage.
    use_8bit_adam=True,
    # The seed for reproducible training.
    seed=3434554,
    # Whether to use prior preservation loss. This helps prevent the model from overfitting to the instance images.
    with_prior_preservation=True,
    # The weight of the prior preservation loss.
    prior_loss_weight=prior_loss_weight,
    # The batch size for generating class images (used for prior preservation).
    sample_batch_size=2,
    # The directory containing the class images (images of the subject's class, but not the subject itself).
    class_data_dir=prior_preservation_class_folder,
    # The prompt that describes the class of your subject (e.g., "a photo of a dog").
    class_prompt=prior_preservation_class_prompt,
    # The number of class images to generate or use for prior preservation.
    num_class_images=num_class_images,
    # The learning rate scheduler to use. Note that "constant" ignores warmup.
    lr_scheduler="constant",
    # Warmup steps; only used by schedulers with warmup, such as "constant_with_warmup".
    lr_warmup_steps=100,
    # The directory to save the trained model and checkpoints.
    output_dir="dreambooth-concept",
)

## Defining the Training Function

`training_function` runs the full DreamBooth fine-tuning process for Stable Diffusion. It prepares the models, optimizer, data, and scheduler, then executes a multi-step training loop. The function freezes non-trainable components (like the VAE), sets up gradient accumulation and checkpointing to manage memory, and builds a dataset of subject and class images. During training, it repeatedly encodes images into latents, adds noise, predicts the noise (or velocity, for v-prediction models) with the U-Net, optionally training the text encoder too, computes the loss (including prior preservation), backpropagates gradients, and updates the model weights. Progress and losses are logged, checkpoints are saved at intervals, and the final fine-tuned pipeline is stored at the end.
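
The prior-preservation part of the loss can be illustrated without PyTorch. In the training loop, `collate_fn` puts instance examples in the first half of the batch and class examples in the second, and `torch.chunk` splits predictions and targets accordingly. The sketch below uses a hypothetical `dreambooth_loss` helper on plain Python lists and assumes the two halves are combined as instance loss plus `prior_loss_weight` times prior loss, the weighted objective described in the DreamBooth paper.

```python
def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def dreambooth_loss(noise_pred, target, prior_loss_weight):
    """Weighted DreamBooth loss for a batch whose first half is instance examples.

    noise_pred and target are lists of per-example vectors (flattened latents).
    """
    half = len(noise_pred) // 2
    inst_pred, prior_pred = noise_pred[:half], noise_pred[half:]
    inst_target, prior_target = target[:half], target[half:]
    # Average the per-example MSE over each half of the batch.
    instance_loss = sum(mse(p, t) for p, t in zip(inst_pred, inst_target)) / half
    prior_loss = sum(mse(p, t) for p, t in zip(prior_pred, prior_target)) / half
    return instance_loss + prior_loss_weight * prior_loss


# One instance example and one class example, with prior_loss_weight = 0.5.
print(dreambooth_loss([[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 2.0]], 0.5))  # 2.0
```

Because every example has the same number of elements, averaging per-example MSE (as the notebook does for the instance term) and averaging over all elements (as it does for the prior term) give the same value. Raising `prior_loss_weight` pulls the model harder toward its original behavior on the class.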


In [None]:
# @title Training Function {"display-mode":"form"}

from accelerate.utils import set_seed
import bitsandbytes as bnb # Import bitsandbytes with the alias bnb
import torch # Import torch here as well, just in case

def training_function(text_encoder, vae, unet):
    # Get a logger for logging training progress
    logger = get_logger(__name__)

    # Set the random seed for reproducibility
    set_seed(args.seed)

    # Initialize Accelerator for distributed training and mixed precision
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    # Freeze the VAE model as it's not being trained
    vae.requires_grad_(False)
    # Freeze the text encoder if not training it
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    # Enable gradient checkpointing to save memory during training
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Determine which parameters to optimize (unet, or unet and text_encoder)
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
    )

    # Initialize the optimizer
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
    )

    # Load the noise scheduler from the pretrained model config
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # Refresh the class-image count in case the folder changed after args was created
    args.num_class_images = len(list(Path(args.class_data_dir).iterdir()))

    # Create the DreamBooth dataset
    train_dataset = DreamBoothDataset(
        instance_data_dir=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    # Define the collation function for the dataloader
    def collate_fn(examples):
        # Extract input IDs and pixel values from examples
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # concat class and instance examples for prior preservation
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        # Stack pixel values into a single tensor
        pixel_values = torch.stack(pixel_values)
        # Convert pixel values to contiguous format and float type
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        # Pad input IDs to the maximum length
        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            return_tensors="pt",
            max_length=tokenizer.model_max_length
        ).input_ids

        # Create a batch dictionary
        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    # Create the training dataloader
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )

    # Initialize the learning rate scheduler
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare models, optimizer, and dataloader for distributed training with Accelerator
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    # Determine the weight data type for mixed precision training
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    vae.decoder.to("cpu") # Move VAE decoder to CPU to save GPU memory
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)


    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    # Log training information
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    # Start the training loop
    for epoch in range(num_train_epochs):
        # Set the unet (and text_encoder if training) to training mode
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()
        # Iterate over the training dataloader
        for step, batch in enumerate(train_dataloader):
            # Accumulate gradients over gradient_accumulation_steps
            with accelerator.accumulate(unet):
                # Convert images to latent space using the VAE encoder
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                # Scale latents by the VAE's scaling factor (0.18215 for SD 1.x and 2.x)
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning from the text encoder
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual using the U-Net model
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise # Target is the noise itself for epsilon prediction
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps) # Target is the velocity for v_prediction
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Compute loss, including prior preservation loss if enabled
                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    # The first half is for instance images, the second half is for class images.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss (loss on your subject's images)
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss (loss on the class images)
                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss with a specified weight.
                    # This helps the model retain knowledge of the class while learning the instance.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    # If prior preservation is not used, the loss is just the MSE between predicted and target noise
                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                # Perform backward pass to compute gradients
                accelerator.backward(loss)

                # Clip gradients to prevent exploding gradients
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                # Perform optimizer step to update model weights
                optimizer.step()
                # Zero out gradients after the optimizer step
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                # Update the progress bar and global step counter
                progress_bar.update(1)
                global_step += 1

                # Save model checkpoint at specified intervals
                if global_step % args.save_steps == 0:
                    if accelerator.is_main_process:
                        # Create a StableDiffusionPipeline from the trained models
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet), # Unwrap the accelerated UNet model
                            text_encoder=accelerator.unwrap_model(text_encoder), # Unwrap the accelerated Text Encoder model
                        )
                        # Define the save path for the checkpoint
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        # Save the pipeline
                        pipeline.save_pretrained(save_path)

            # Log the current loss
            logs = {"loss": loss.detach().item()}
            # Update the progress bar with the current loss
            progress_bar.set_postfix(**logs)

            # Break the loop if the maximum training steps are reached
            if global_step >= args.max_train_steps:
                break

        # Wait for all processes to finish in distributed training
        accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        # Create a StableDiffusionPipeline from the trained models
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet), # Unwrap the accelerated UNet model
            text_encoder=accelerator.unwrap_model(text_encoder), # Unwrap the accelerated Text Encoder model
        )
        # Save the final trained pipeline
        pipeline.save_pretrained(args.output_dir)
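
The loss in the loop above depends on two scheduler formulas: how `noise_scheduler.add_noise` mixes clean latents with noise, and which target the U-Net regresses onto. The plain-Python sketch below reproduces that arithmetic on scalars standing in for latent tensors. It assumes a DDPM-style `scaled_linear` beta schedule with `beta_start=0.00085`, `beta_end=0.012`, and 1000 training timesteps, the values commonly shipped with Stable Diffusion v1 (check your model's scheduler config). The helpers `alphas_cumprod`, `add_noise`, `velocity`, and `dreambooth_loss` are hypothetical illustrations, not diffusers functions.

```python
import math

def alphas_cumprod(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative products a_bar_t for an assumed "scaled_linear" beta schedule."""
    start, end = math.sqrt(beta_start), math.sqrt(beta_end)
    products, running = [], 1.0
    for i in range(num_steps):
        # Betas are evenly spaced in sqrt-space, then squared
        beta = (start + (end - start) * i / (num_steps - 1)) ** 2
        running *= 1.0 - beta
        products.append(running)
    return products

ALPHA_BAR = alphas_cumprod()

def add_noise(x0, eps, t):
    # Forward diffusion: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps
    a = ALPHA_BAR[t]
    return math.sqrt(a) * x0 + math.sqrt(1.0 - a) * eps

def velocity(x0, eps, t):
    # v-prediction target: v = sqrt(a_bar_t) * eps - sqrt(1 - a_bar_t) * x0
    a = ALPHA_BAR[t]
    return math.sqrt(a) * eps - math.sqrt(1.0 - a) * x0

def mse(pred, target):
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)

def dreambooth_loss(pred, target, with_prior_preservation=True, prior_loss_weight=1.0):
    if not with_prior_preservation:
        return mse(pred, target)
    # First half of the batch = instance examples, second half = class examples
    half = len(pred) // 2
    instance_loss = mse(pred[:half], target[:half])
    prior_loss = mse(pred[half:], target[half:])
    return instance_loss + prior_loss_weight * prior_loss

# Usage: one scalar "latent" noised at an arbitrary timestep
x0, eps, t = 0.5, -1.2, 500
print(f"x_t = {add_noise(x0, eps, t):.4f}")
print(f"epsilon target = {eps}, v target = {velocity(x0, eps, t):.4f}")
```

Because `x0 = sqrt(a_bar_t) * x_t - sqrt(1 - a_bar_t) * v`, a v-prediction model still lets the sampler recover the clean latent; the prior-preservation term simply adds a weighted MSE computed on the class half of the batch.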

# Run the Training Process

This cell checks whether **prior preservation** is enabled and, if so, makes sure the required class images exist, generating any missing ones with a pre-trained Stable Diffusion pipeline and the provided class prompt. After confirming the dataset is complete, it prints diagnostic information and then calls `accelerate.notebook_launcher` to execute `training_function`, which fine-tunes Stable Diffusion on your subject images. Once training finishes, it deletes the parameter gradients and empties the CUDA cache to free GPU memory.
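
The top-up step samples only the difference between `args.num_class_images` and the files already in `args.class_data_dir`, and it numbers new files after the existing count, so earlier samples are not overwritten as long as a previous run named them `0.jpg`, `1.jpg`, and so on. A stdlib-only sketch of that bookkeeping follows; `plan_class_images` is a hypothetical helper, and the image generation itself (done by `StableDiffusionPipeline` in the notebook) is left out.

```python
import tempfile
from pathlib import Path

def plan_class_images(class_dir, num_class_images):
    """Hypothetical helper: list the file paths still needed to reach num_class_images."""
    class_dir = Path(class_dir)
    class_dir.mkdir(parents=True, exist_ok=True)
    cur = len(list(class_dir.iterdir()))
    num_new = max(num_class_images - cur, 0)
    # New files are numbered after the existing count, mirroring index + cur_class_images
    return [class_dir / f"{i + cur}.jpg" for i in range(num_new)]

with tempfile.TemporaryDirectory() as tmp:
    # Pretend two class images already exist from an earlier run
    for name in ("0.jpg", "1.jpg"):
        (Path(tmp) / name).touch()
    todo = plan_class_images(tmp, num_class_images=5)
    print([p.name for p in todo])  # → ['2.jpg', '3.jpg', '4.jpg']
```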

In [None]:
#@title Run training

import gc  # Garbage collection when releasing the class-image pipeline
import itertools  # Chaining U-Net and text encoder parameters
import os  # Listing directory contents
from pathlib import Path

import accelerate
import bitsandbytes as bnb  # 8-bit optimizer support
import torch
from diffusers import StableDiffusionPipeline  # Generates the class images
from tqdm.auto import tqdm  # Progress bar for class-image generation


# Ensure class images are generated before training
if args.with_prior_preservation:
    class_images_dir = Path(args.class_data_dir)
    if not class_images_dir.exists():
        class_images_dir.mkdir(parents=True)

    cur_class_images = len(list(class_images_dir.iterdir()))
    if cur_class_images < args.num_class_images:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path, revision="fp16", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.enable_attention_slicing()
        pipeline.set_progress_bar_config(disable=True)

        num_new_images = args.num_class_images - cur_class_images
        print(f"Number of class images to sample: {num_new_images}.")

        sample_dataset = PromptDataset(args.class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

        for example in tqdm(sample_dataloader, desc="Generating class images"):
            images = pipeline(example["prompt"]).images

            for i, image in enumerate(images):
                image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
        # Release the sampling pipeline before training to free GPU memory
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()

    # Update args.num_class_images with the actual number of generated images
    args.num_class_images = len(list(class_images_dir.iterdir()))
    print(f"Total number of class images: {args.num_class_images}")

# Diagnostic output: confirm the class images are in place before training
if args.with_prior_preservation:
    print(f"Contents of {args.class_data_dir}: {os.listdir(args.class_data_dir)}")
print(f"args.num_class_images before training: {args.num_class_images}")


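# Launch fine-tuning; training_function receives the three models as arguments
accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet))

# Drop leftover gradients, then release cached GPU memory once
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
    if param.grad is not None:
        del param.grad  # free some memory
torch.cuda.empty_cache()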
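
`training_function` converts `args.max_train_steps` into a number of epochs, and every optimizer update (one progress-bar step) consumes `train_batch_size × num_processes × gradient_accumulation_steps` examples. The arithmetic below mirrors those lines with made-up numbers (5 batches per epoch, batch size 1, one process, accumulation of 2, 400 steps). `training_schedule` is a hypothetical helper; in the real loop, `len(train_dataloader)` depends on your dataset size and batch size.

```python
import math

def training_schedule(num_batches, grad_accum, max_train_steps, batch_size, num_processes=1):
    """Hypothetical helper mirroring the step/epoch bookkeeping in training_function."""
    # Optimizer updates per pass over the dataloader
    updates_per_epoch = math.ceil(num_batches / grad_accum)
    # Enough epochs to reach max_train_steps; the loop breaks early once it does
    num_epochs = math.ceil(max_train_steps / updates_per_epoch)
    # Examples contributing to one optimizer update across all devices
    total_batch_size = batch_size * num_processes * grad_accum
    return updates_per_epoch, num_epochs, total_batch_size

print(training_schedule(num_batches=5, grad_accum=2, max_train_steps=400, batch_size=1))
# → (3, 134, 2)
```

Rounding the epoch count up means the last epoch is usually cut short by the `global_step >= args.max_train_steps` check rather than running to completion.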

# Use the Fine-Tuned Model

Next, load the fine-tuned model from the output directory into an inference pipeline.

In [None]:
from diffusers import DiffusionPipeline

# Load the pipeline from the local output directory
model_path = args.output_dir # Use the output directory from the training step
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to("cuda")

## Generate New Images

The moment we've been waiting for. Use your trained model to create new images of your subject.

In [None]:
suffix = "in space" # @param {"type":"string"}
prompt = f"a photo of {concept_name} {suffix}" # Using f-string for interpolation

# Generate the image
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

# Display the image
image
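
`guidance_scale=7.5` controls classifier-free guidance: at each denoising step the pipeline predicts noise both with and without the prompt, then extrapolates from the unconditional prediction toward the conditional one. The sketch below shows that combination on plain Python lists standing in for noise tensors; `apply_cfg` is a hypothetical helper, not a diffusers function. A scale of 1.0 returns the conditional prediction unchanged, and larger values follow the prompt more closely, typically at some cost in diversity.

```python
def apply_cfg(noise_uncond, noise_text, guidance_scale):
    """Hypothetical helper: classifier-free guidance on flat lists of noise values."""
    # eps = eps_uncond + s * (eps_text - eps_uncond)
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_text)]

uncond = [0.10, -0.20]  # prediction for the empty (unconditional) prompt
text = [0.30, 0.00]     # prediction for the text prompt
for scale in (1.0, 7.5):
    print(scale, [round(x, 3) for x in apply_cfg(uncond, text, scale)])
```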