In [1]:
import argparse
import copy
import itertools
import logging
import math
import os
import random
import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path

import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast

import diffusers
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    SD3Transformer2DModel,
    StableDiffusion3Pipeline,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
    _set_state_dict_into_text_encoder,
    cast_training_params,
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
    free_memory,
)
from diffusers.utils import (
    check_min_version,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module

if is_wandb_available():
    import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import sys
import warnings

# Ignore warnings
warnings.filterwarnings("ignore")

# Simulate command-line arguments
sys.argv = [
    "notebook_name",
    "--train_text_encoder",
    "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3.5-medium",
    "--instance_data_dir", "dog",
    "--output_dir", "trained-sd3-lora",
    "--mixed_precision", "fp16",
    "--instance_prompt", "a photo of sks dog",
    "--resolution", "512",
    "--train_batch_size", "1",
    "--gradient_accumulation_steps", "4",
    "--learning_rate", "4e-4",
    "--report_to", "tensorboard",
    "--lr_scheduler", "constant",
    "--lr_warmup_steps", "0",
    "--max_train_steps", "500",
    "--validation_prompt", "A photo of sks dog in a bucket",
    "--validation_epochs", "25",
    "--seed", "0",
    "--push_to_hub",
]

from args import parse_args

# Parse the arguments
args = parse_args()

# Use the arguments
print(vars(args))

{'pretrained_model_name_or_path': 'stabilityai/stable-diffusion-3.5-medium', 'revision': None, 'variant': None, 'dataset_name': None, 'dataset_config_name': None, 'instance_data_dir': 'dog', 'cache_dir': None, 'image_column': 'image', 'caption_column': None, 'repeats': 1, 'class_data_dir': None, 'instance_prompt': 'a photo of sks dog', 'class_prompt': None, 'max_sequence_length': 77, 'validation_prompt': 'A photo of sks dog in a bucket', 'num_validation_images': 4, 'validation_epochs': 25, 'rank': 4, 'with_prior_preservation': False, 'prior_loss_weight': 1.0, 'num_class_images': 100, 'output_dir': 'trained-sd3-lora', 'seed': 0, 'resolution': 512, 'center_crop': False, 'random_flip': False, 'train_text_encoder': False, 'train_batch_size': 1, 'sample_batch_size': 4, 'num_train_epochs': 1, 'max_train_steps': 500, 'checkpointing_steps': 500, 'checkpoints_total_limit': None, 'resume_from_checkpoint': None, 'gradient_accumulation_steps': 4, 'gradient_checkpointing': False, 'learning_rate': 0

In [3]:
check_min_version("0.32.0.dev0")

logger = get_logger(__name__)

In [4]:
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        class_prompt,
        class_data_root=None,
        class_num=None,
        size=1024,
        repeats=1,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop

        self.instance_prompt = instance_prompt
        self.custom_instance_prompts = None
        self.class_prompt = class_prompt

        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
        # we load the training data using load_dataset
        if args.dataset_name is not None:
            try:
                from datasets import load_dataset
            except ImportError:
                raise ImportError(
                    "You are trying to load your data using the datasets library. If you wish to train using custom "
                    "captions please install the datasets library: `pip install datasets`. If you wish to load a "
                    "local folder containing images only, specify --instance_data_dir instead."
                )
            # Downloading and loading a dataset from the hub.
            # See more about loading custom images at
            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
            dataset = load_dataset(
                args.dataset_name,
                args.dataset_config_name,
                cache_dir=args.cache_dir,
            )
            # Preprocessing the datasets.
            column_names = dataset["train"].column_names

            # 6. Get the column names for input/target.
            if args.image_column is None:
                image_column = column_names[0]
                logger.info(f"image column defaulting to {image_column}")
            else:
                image_column = args.image_column
                if image_column not in column_names:
                    raise ValueError(
                        f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                    )
            instance_images = dataset["train"][image_column]

            if args.caption_column is None:
                logger.info(
                    "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
                    "contains captions/prompts for the images, make sure to specify the "
                    "column as --caption_column"
                )
                self.custom_instance_prompts = None
            else:
                if args.caption_column not in column_names:
                    raise ValueError(
                        f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
                    )
                custom_instance_prompts = dataset["train"][args.caption_column]
                # create final list of captions according to --repeats
                self.custom_instance_prompts = []
                for caption in custom_instance_prompts:
                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
        else:
            self.instance_data_root = Path(instance_data_root)
            if not self.instance_data_root.exists():
                raise ValueError("Instance images root doesn't exists.")

            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
            self.custom_instance_prompts = None

        self.instance_images = []
        for img in instance_images:
            self.instance_images.extend(itertools.repeat(img, repeats))

        self.pixel_values = []
        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
        train_flip = transforms.RandomHorizontalFlip(p=1.0)
        train_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        for image in self.instance_images:
            image = exif_transpose(image)
            if not image.mode == "RGB":
                image = image.convert("RGB")
            image = train_resize(image)
            if args.random_flip and random.random() < 0.5:
                # flip
                image = train_flip(image)
            if args.center_crop:
                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
                image = train_crop(image)
            else:
                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
                image = crop(image, y1, x1, h, w)
            image = train_transforms(image)
            self.pixel_values.append(image)

        self.num_instance_images = len(self.instance_images)
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = self.pixel_values[index % self.num_instance_images]
        example["instance_images"] = instance_image

        if self.custom_instance_prompts:
            caption = self.custom_instance_prompts[index % self.num_instance_images]
            if caption:
                example["instance_prompt"] = caption
            else:
                example["instance_prompt"] = self.instance_prompt

        else:  # custom prompts were provided, but length does not match size of image dataset
            example["instance_prompt"] = self.instance_prompt

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            class_image = exif_transpose(class_image)

            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt"] = self.class_prompt

        return example

def collate_fn(examples, with_prior_preservation=False):
    pixel_values = [example["instance_images"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        pixel_values += [example["class_images"] for example in examples]
        prompts += [example["class_prompt"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    batch = {"pixel_values": pixel_values, "prompts": prompts}
    return batch

class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example

In [5]:
if args.report_to == "wandb" and args.hub_token is not None:
    raise ValueError(
        "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
        " Please use `huggingface-cli login` to authenticate with the Hub."
    )

if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
    # due to pytorch#99272, MPS does not yet support bfloat16.
    raise ValueError(
        "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
    )

logging_dir = Path(args.output_dir, args.logging_dir)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    mixed_precision=args.mixed_precision,
    log_with=args.report_to,
    project_config=accelerator_project_config,
    kwargs_handlers=[kwargs],
)

# Disable AMP for MPS.
if torch.backends.mps.is_available():
    accelerator.native_amp = False

if args.report_to == "wandb":
    if not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
11/19/2024 16:24:54 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16



In [6]:
# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
    class_images_dir = Path(args.class_data_dir)
    if not class_images_dir.exists():
        class_images_dir.mkdir(parents=True)
    cur_class_images = len(list(class_images_dir.iterdir()))

    if cur_class_images < args.num_class_images:
        has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
        torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
        if args.prior_generation_precision == "fp32":
            torch_dtype = torch.float32
        elif args.prior_generation_precision == "fp16":
            torch_dtype = torch.float16
        elif args.prior_generation_precision == "bf16":
            torch_dtype = torch.bfloat16
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            revision=args.revision,
            variant=args.variant,
        )
        pipeline.set_progress_bar_config(disable=True)

        num_new_images = args.num_class_images - cur_class_images
        logger.info(f"Number of class images to sample: {num_new_images}.")

        sample_dataset = PromptDataset(args.class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

        sample_dataloader = accelerator.prepare(sample_dataloader)
        pipeline.to(accelerator.device)

        for example in tqdm(
            sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
        ):
            images = pipeline(example["prompt"]).images

            for i, image in enumerate(images):
                hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
                image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                image.save(image_filename)

        del pipeline
        free_memory()

In [43]:
# Handle the repository creation
if accelerator.is_main_process:
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.push_to_hub:
        repo_id = create_repo(
            repo_id=args.hub_model_id or Path(args.output_dir).name,
            exist_ok=True,
        ).repo_id

In [7]:
# Load tokenizers
from models import load_tokenizers

tokenizer_one, tokenizer_two, tokenizer_three = load_tokenizers(args)

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [8]:
# Import correct text encoder
from models import import_model_class_from_model_name_or_path

text_encoder_cls_one = import_model_class_from_model_name_or_path(
    args.pretrained_model_name_or_path, args.revision
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
    args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
)
text_encoder_cls_three = import_model_class_from_model_name_or_path(
    args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
)

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.


You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.


In [9]:
# Load scheduler and models
from models import load_text_encoders

noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
    text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, args
)
vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="vae",
    revision=args.revision,
    variant=args.variant,
)

{'base_shift', 'invert_sigmas', 'base_image_seq_len', 'use_dynamic_shifting', 'max_shift', 'max_image_seq_len'} was not found in config. Values will be initialized to default values.
Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 3315.66it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s]


In [10]:
transformer = SD3Transformer2DModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
)

In [11]:
transformer.requires_grad_(False)
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
text_encoder_three.requires_grad_(False)

T5EncoderModel(
  (shared): Embedding(32128, 4096)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 4096)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=4096, out_features=4096, bias=False)
              (k): Linear(in_features=4096, out_features=4096, bias=False)
              (v): Linear(in_features=4096, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=4096, bias=False)
              (relative_attention_bias): Embedding(32, 64)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
              (wo

In [12]:
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
    # due to pytorch#99272, MPS does not yet support bfloat16.
    raise ValueError(
        "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
    )

In [13]:
print("accelerator.device", accelerator.device)
vae.to(accelerator.device, dtype=torch.float32)
transformer.to(accelerator.device, dtype=weight_dtype)
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
text_encoder_three.to(accelerator.device, dtype=weight_dtype)

accelerator.device cuda


T5EncoderModel(
  (shared): Embedding(32128, 4096)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 4096)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=4096, out_features=4096, bias=False)
              (k): Linear(in_features=4096, out_features=4096, bias=False)
              (v): Linear(in_features=4096, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=4096, bias=False)
              (relative_attention_bias): Embedding(32, 64)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
              (wo

In [14]:
if args.gradient_checkpointing:
    transformer.enable_gradient_checkpointing()
    if args.train_text_encoder:
        text_encoder_one.gradient_checkpointing_enable()
        text_encoder_two.gradient_checkpointing_enable()
if args.lora_layers is not None:
    target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
    target_modules = [
        "attn.add_k_proj",
        "attn.add_q_proj",
        "attn.add_v_proj",
        "attn.to_add_out",
        "attn.to_k",
        "attn.to_out.0",
        "attn.to_q",
        "attn.to_v",
    ]

if args.lora_blocks is not None:
    target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")]
    target_modules = [
        f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
    ]

In [15]:
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.rank,
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)

In [16]:
if args.train_text_encoder:
    text_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    )
    text_encoder_one.add_adapter(text_lora_config)
    text_encoder_two.add_adapter(text_lora_config)

In [17]:
def unwrap_model(model):
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        transformer_lora_layers_to_save = None
        text_encoder_one_lora_layers_to_save = None
        text_encoder_two_lora_layers_to_save = None

        for model in models:
            if isinstance(model, type(unwrap_model(transformer))):
                transformer_lora_layers_to_save = get_peft_model_state_dict(model)
            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()

        StableDiffusion3Pipeline.save_lora_weights(
            output_dir,
            transformer_lora_layers=transformer_lora_layers_to_save,
            text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
            text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
        )

def load_model_hook(models, input_dir):
    transformer_ = None
    text_encoder_one_ = None
    text_encoder_two_ = None

    while len(models) > 0:
        model = models.pop()

        if isinstance(model, type(unwrap_model(transformer))):
            transformer_ = model
        elif isinstance(model, type(unwrap_model(text_encoder_one))):
            text_encoder_one_ = model
        elif isinstance(model, type(unwrap_model(text_encoder_two))):
            text_encoder_two_ = model
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

    lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)

    transformer_state_dict = {
        f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
    }
    transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
    incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
    if incompatible_keys is not None:
        # check only for unexpected keys
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            logger.warning(
                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                f" {unexpected_keys}. "
            )
    if args.train_text_encoder:
        # Do we need to call `scale_lora_layers()` here?
        _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)

        _set_state_dict_into_text_encoder(
            lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
        )

    # Make sure the trainable params are in float32. This is again needed since the base models
    # are in `weight_dtype`. More details:
    # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
    if args.mixed_precision == "fp16":
        models = [transformer_]
        if args.train_text_encoder:
            models.extend([text_encoder_one_, text_encoder_two_])
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params(models)

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

<torch.utils.hooks.RemovableHandle at 0x7f956fbc82e0>

In [18]:
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32 and torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True

if args.scale_lr:
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

In [19]:
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
    models = [transformer]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    # only upcast trainable parameters (LoRA) into fp32
    cast_training_params(models, dtype=torch.float32)

In [20]:
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
if args.train_text_encoder:
    text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
    text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))

In [21]:
# Optimization parameters
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
if args.train_text_encoder:
    # different learning rate for text encoder and unet
    text_lora_parameters_one_with_lr = {
        "params": text_lora_parameters_one,
        "weight_decay": args.adam_weight_decay_text_encoder,
        "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
    }
    text_lora_parameters_two_with_lr = {
        "params": text_lora_parameters_two,
        "weight_decay": args.adam_weight_decay_text_encoder,
        "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
    }
    params_to_optimize = [
        transformer_parameters_with_lr,
        text_lora_parameters_one_with_lr,
        text_lora_parameters_two_with_lr,
    ]
else:
    params_to_optimize = [transformer_parameters_with_lr]

In [22]:
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
    logger.warning(
        f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
        "Defaulting to adamW"
    )
    args.optimizer = "adamw"

if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
    logger.warning(
        f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
        f"set to {args.optimizer.lower()}"
    )

if args.optimizer.lower() == "adamw":
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        params_to_optimize,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

if args.optimizer.lower() == "prodigy":
    try:
        import prodigyopt
    except ImportError:
        raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

    optimizer_class = prodigyopt.Prodigy

    if args.learning_rate <= 0.1:
        logger.warning(
            "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
        )
    if args.train_text_encoder and args.text_encoder_lr:
        logger.warning(
            f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
            f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
            f"When using prodigy only learning_rate is used as the initial learning rate."
        )
        # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
        # --learning_rate
        params_to_optimize[1]["lr"] = args.learning_rate
        params_to_optimize[2]["lr"] = args.learning_rate

    optimizer = optimizer_class(
        params_to_optimize,
        betas=(args.adam_beta1, args.adam_beta2),
        beta3=args.prodigy_beta3,
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        decouple=args.prodigy_decouple,
        use_bias_correction=args.prodigy_use_bias_correction,
        safeguard_warmup=args.prodigy_safeguard_warmup,
    )

In [23]:
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
    instance_data_root=args.instance_data_dir,
    instance_prompt=args.instance_prompt,
    class_prompt=args.class_prompt,
    class_data_root=args.class_data_dir if args.with_prior_preservation else None,
    class_num=args.num_class_images,
    size=args.resolution,
    repeats=args.repeats,
    center_crop=args.center_crop,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
    num_workers=args.dataloader_num_workers,
)

In [24]:
from models import encode_prompt

if not args.train_text_encoder:
    tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
    text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]

    def compute_text_embeddings(prompt, text_encoders, tokenizers):
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds = encode_prompt(
                text_encoders, tokenizers, prompt, args.max_sequence_length
            )
            prompt_embeds = prompt_embeds.to(accelerator.device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
        return prompt_embeds, pooled_prompt_embeds

In [25]:
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
    instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
        args.instance_prompt, text_encoders, tokenizers
    )

# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
    if not args.train_text_encoder:
        class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
            args.class_prompt, text_encoders, tokenizers
        )

# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
    # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
    del tokenizers, text_encoders
    del text_encoder_one, text_encoder_two, text_encoder_three
    free_memory()

In [26]:
from models import tokenize_prompt
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.

if not train_dataset.custom_instance_prompts:
    if not args.train_text_encoder:
        prompt_embeds = instance_prompt_hidden_states
        pooled_prompt_embeds = instance_pooled_prompt_embeds
        if args.with_prior_preservation:
            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
    # batch prompts on all training steps
    else:
        tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
        tokens_three = tokenize_prompt(tokenizer_three, args.instance_prompt)
        if args.with_prior_preservation:
            class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
            class_tokens_three = tokenize_prompt(tokenizer_three, args.class_prompt)
            tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
            tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
            tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)


In [27]:
vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
if args.cache_latents:
    latents_cache = []
    for batch in tqdm(train_dataloader, desc="Caching latents"):
        with torch.no_grad():
            batch["pixel_values"] = batch["pixel_values"].to(
                accelerator.device, non_blocking=True, dtype=weight_dtype
            )
            latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

    if args.validation_prompt is None:
        del vae
        free_memory()

In [28]:
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
    num_cycles=args.lr_num_cycles,
    power=args.lr_power,
)

In [29]:
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
    (
        transformer,
        text_encoder_one,
        text_encoder_two,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
    )
    assert text_encoder_one is not None
    assert text_encoder_two is not None
    assert text_encoder_three is not None
else:
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

In [30]:
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
    tracker_name = "dreambooth-sd3-lora"
    accelerator.init_trackers(tracker_name, config=vars(args))

In [31]:
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
logger.info(f"  Num Epochs = {args.num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0

11/19/2024 16:25:21 - INFO - __main__ - ***** Running training *****
11/19/2024 16:25:21 - INFO - __main__ -   Num examples = 5
11/19/2024 16:25:21 - INFO - __main__ -   Num batches each epoch = 5
11/19/2024 16:25:21 - INFO - __main__ -   Num Epochs = 250
11/19/2024 16:25:21 - INFO - __main__ -   Instantaneous batch size per device = 1
11/19/2024 16:25:21 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 4
11/19/2024 16:25:21 - INFO - __main__ -   Gradient Accumulation steps = 4
11/19/2024 16:25:21 - INFO - __main__ -   Total optimization steps = 500


In [32]:
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)
    else:
        # Get the mos recent checkpoint
        dirs = os.listdir(args.output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        args.resume_from_checkpoint = None
        initial_global_step = 0
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])

        initial_global_step = global_step
        first_epoch = global_step // num_update_steps_per_epoch

else:
    initial_global_step = 0

In [33]:
progress_bar = tqdm(
    range(0, args.max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)

Steps:   0%|          | 0/500 [00:00<?, ?it/s]

In [34]:
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
    schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
    timesteps = timesteps.to(accelerator.device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma

In [39]:
def log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args,
    epoch,
    torch_dtype,
    is_final_validation=False,
    logger=None,
):
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
    autocast_ctx = nullcontext()

    with autocast_ctx:
        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

    for tracker in accelerator.trackers:
        phase_name = "test" if is_final_validation else "validation"
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    phase_name: [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                    ]
                }
            )

    del pipeline
    free_memory()

    return images

In [48]:
from itertools import accumulate


for epoch in range(first_epoch, args.num_train_epochs):
    transformer.train()
    if args.train_text_encoder:
        print("Training text encoders")
        text_encoder_one.train()
        text_encoder_two.train()
        
        # set top parameter requires_grad = True for gradient checkpointing works
        accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
        accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)

    # accumulated_gradients = {name: torch.zeros_like(param) for name, param in transformer.named_parameters()}
    
    for step, batch in enumerate(train_dataloader):
        models_to_accumulate = [transformer]
        if args.train_text_encoder:
            models_to_accumulate.extend([text_encoder_one, text_encoder_two])
        with accelerator.accumulate(models_to_accumulate):
            prompts = batch["prompts"]

            # encode batch prompts when custom prompts are provided for each image -
            if train_dataset.custom_instance_prompts:
                if not args.train_text_encoder:
                    prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
                        prompts, text_encoders, tokenizers
                    )
                else:
                    tokens_one = tokenize_prompt(tokenizer_one, prompts)
                    tokens_two = tokenize_prompt(tokenizer_two, prompts)
                    tokens_three = tokenize_prompt(tokenizer_three, prompts)
                    prompt_embeds, pooled_prompt_embeds = encode_prompt(
                        text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
                        tokenizers=[None, None, None],
                        prompt=prompts,
                        max_sequence_length=args.max_sequence_length,
                        text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                    )
            else:
                if args.train_text_encoder:
                    prompt_embeds, pooled_prompt_embeds = encode_prompt(
                        text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
                        tokenizers=[None, None, tokenizer_three],
                        prompt=args.instance_prompt,
                        max_sequence_length=args.max_sequence_length,
                        text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                    )

            # Convert images to latent space
            if args.cache_latents:
                model_input = latents_cache[step].sample()
            else:
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                model_input = vae.encode(pixel_values).latent_dist.sample()

            model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
            model_input = model_input.to(dtype=weight_dtype)

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]

            # Sample a random timestep for each image
            # for weighting schemes where we sample timesteps non-uniformly
            u = compute_density_for_timestep_sampling(
                weighting_scheme=args.weighting_scheme,
                batch_size=bsz,
                logit_mean=args.logit_mean,
                logit_std=args.logit_std,
                mode_scale=args.mode_scale,
            )
            indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
            timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

            # Add noise according to flow matching.
            # zt = (1 - texp) * x + texp * z1
            sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
            noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

            # Predict the noise residual
            model_pred = transformer(
                hidden_states=noisy_model_input,
                timestep=timesteps,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                return_dict=False,
            )[0]

            # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
            # Preconditioning of the model outputs.
            if args.precondition_outputs:
                model_pred = model_pred * (-sigmas) + noisy_model_input

            # these weighting schemes use a uniform timestep sampling
            # and instead post-weight the loss
            weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

            # flow matching loss
            if args.precondition_outputs:
                target = model_input
            else:
                target = noise - model_input

            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute prior loss
                prior_loss = torch.mean(
                    (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
                        target_prior.shape[0], -1
                    ),
                    1,
                )
                prior_loss = prior_loss.mean()

            # Compute regular loss.
            loss = torch.mean(
                (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                1,
            )
            loss = loss.mean()

            if args.with_prior_preservation:
                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss

            accelerator.backward(loss)
            if accelerator.is_main_process:
                gradients = {}
                def save_model_gradients(model, model_name):
                    for name, param in model.named_parameters():
                        if param.grad is not None:
                            gradients[f"{model_name}.{name}"] = param.grad.clone().detach().cpu()
                save_model_gradients(transformer, "transformer")
                if args.train_text_encoder:
                    save_model_gradients(text_encoder_one, "text_encoder_one")
                    save_model_gradients(text_encoder_two, "text_encoder_two")
                torch.save(gradients, f"gradients_{step}.pt")
                break
            
            if accelerator.sync_gradients:
                params_to_clip = (
                    itertools.chain(
                        transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
                    )
                    if args.train_text_encoder
                    else transformer_lora_parameters
                )
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            if accelerator.is_main_process:
                if global_step % args.checkpointing_steps == 0:
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                    if args.checkpoints_total_limit is not None:
                        checkpoints = os.listdir(args.output_dir)
                        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                        if len(checkpoints) >= args.checkpoints_total_limit:
                            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                            removing_checkpoints = checkpoints[0:num_to_remove]

                            logger.info(
                                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                            )
                            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                            for removing_checkpoint in removing_checkpoints:
                                removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                shutil.rmtree(removing_checkpoint)

                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")
    
        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)

        if global_step >= args.max_train_steps:
            break

    # if accelerator.is_main_process:
    #     if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
    #         if not args.train_text_encoder:
    #             # create pipeline
    #             text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
    #                 text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three, args
    #             )
    #             text_encoder_one.to(weight_dtype)
    #             text_encoder_two.to(weight_dtype)
    #         pipeline = StableDiffusion3Pipeline.from_pretrained(
    #             args.pretrained_model_name_or_path,
    #             vae=vae,
    #             text_encoder=accelerator.unwrap_model(text_encoder_one),
    #             text_encoder_2=accelerator.unwrap_model(text_encoder_two),
    #             text_encoder_3=accelerator.unwrap_model(text_encoder_three),
    #             transformer=accelerator.unwrap_model(transformer),
    #             revision=args.revision,
    #             variant=args.variant,
    #             torch_dtype=weight_dtype,
    #         )
    #         pipeline_args = {"prompt": args.validation_prompt}

    #         images = log_validation(
    #             pipeline=pipeline,
    #             args=args,
    #             accelerator=accelerator,
    #             pipeline_args=pipeline_args,
    #             epoch=epoch,
    #             torch_dtype=weight_dtype,
    #             logger=logger,
    #         )
    #         if not args.train_text_encoder:
    #             del text_encoder_one, text_encoder_two, text_encoder_three
    #             free_memory()
    
    break

In [45]:
from hf import save_model_card

In [47]:
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    transformer = unwrap_model(transformer)
    if args.upcast_before_saving:
        transformer.to(torch.float32)
    else:
        transformer = transformer.to(weight_dtype)
    transformer_lora_layers = get_peft_model_state_dict(transformer)

    if args.train_text_encoder:
        text_encoder_one = unwrap_model(text_encoder_one)
        text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
        text_encoder_two = unwrap_model(text_encoder_two)
        text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
    else:
        text_encoder_lora_layers = None
        text_encoder_2_lora_layers = None

    StableDiffusion3Pipeline.save_lora_weights(
        save_directory=args.output_dir,
        transformer_lora_layers=transformer_lora_layers,
        text_encoder_lora_layers=text_encoder_lora_layers,
        text_encoder_2_lora_layers=text_encoder_2_lora_layers,
    )

    # Final inference
    # Load previous pipeline
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    # load attention processors
    pipeline.load_lora_weights(args.output_dir)

    # run inference
    images = []
    if args.validation_prompt and args.num_validation_images > 0:
        pipeline_args = {"prompt": args.validation_prompt}
        # images = log_validation(
        #     pipeline=pipeline,
        #     args=args,
        #     accelerator=accelerator,
        #     pipeline_args=pipeline_args,
        #     epoch=epoch,
        #     is_final_validation=True,
        #     torch_dtype=weight_dtype,
        # )

    if args.push_to_hub:
        # save_model_card(
        #     repo_id,
        #     images=images,
        #     base_model=args.pretrained_model_name_or_path,
        #     instance_prompt=args.instance_prompt,
        #     validation_prompt=args.validation_prompt,
        #     train_text_encoder=args.train_text_encoder,
        #     repo_folder=args.output_dir,
        # )
        upload_folder(
            repo_id=repo_id,
            folder_path=args.output_dir,
            commit_message="End of training",
            ignore_patterns=["step_*", "epoch_*"],
        )

accelerator.end_training()

Model weights saved in trained-sd3-lora/pytorch_lora_weights.safetensors


{'base_shift', 'invert_sigmas', 'base_image_seq_len', 'use_dynamic_shifting', 'max_shift', 'max_image_seq_len'} was not found in config. Values will be initialized to default values.
Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of stabilityai/stable-diffusion-3.5-medium.
Loaded vae as AutoencoderKL from `vae` subfolder of stabilityai/stable-diffusion-3.5-medium.
Loaded tokenizer_3 as T5TokenizerFast from `tokenizer_3` subfolder of stabilityai/stable-diffusion-3.5-medium.

[A
[A
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.11it/s]
Loaded text_encoder_3 as T5EncoderModel from `text_encoder_3` subfolder of stabilityai/stable-diffusion-3.5-medium.
Loaded text_encoder_2 as CLIPTextModelWithProjection from `text_encoder_2` subfolder of stabilityai/stable-diffusion-3.5-medium.
Loaded tokenizer_2 as CLIPTokenizer from `tokenizer_2` subfolder of stabilityai/stable-diffusion-3.5-medium.
Loaded text_encoder as CLIPTextModelWithProjection fro