In [1]:
# List each visible GPU with its name, total memory and currently free memory (CSV, no header).
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

Tesla V100-PCIE-32GB, 32510 MiB, 17111 MiB
Tesla V100-PCIE-32GB, 32510 MiB, 22419 MiB
Tesla V100-PCIE-32GB, 32510 MiB, 22707 MiB
Tesla V100-PCIE-32GB, 32510 MiB, 24223 MiB
Tesla V100-PCIE-32GB, 32510 MiB, 18100 MiB
Tesla V100-PCIE-32GB, 32510 MiB, 32506 MiB
Tesla V100-PCIE-32GB, 32510 MiB, 30717 MiB


In [2]:
import argparse
import logging
import math
import os
import random
import warnings
from pathlib import Path
from typing import Optional

import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

if is_wandb_available():
    import wandb

# Map interpolation names to Pillow resampling filters. From Pillow 9.1.0 on the filters are
# read from ``PIL.Image.Resampling`` (where "linear" resolves to BILINEAR); on older versions
# they are read straight from ``PIL.Image`` (where "linear" resolves to LINEAR).
if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        key: getattr(PIL.Image, key.upper())
        for key in ("linear", "bilinear", "bicubic", "lanczos", "nearest")
    }
else:
    PIL_INTERPOLATION = {
        key: getattr(PIL.Image.Resampling, "BILINEAR" if key == "linear" else key.upper())
        for key in ("linear", "bilinear", "bicubic", "lanczos", "nearest")
    }

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Errors out if the installed diffusers is older than the minimum version this notebook was
# written against. Remove at your own risk.
check_min_version("0.14.0.dev0")

# Logger obtained through accelerate's `get_logger`; the helper functions and the training
# cells below all log through it.
logger = get_logger(__name__)

In [4]:
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
    """Generate validation images with the current weights and send them to the trackers.

    A ``DiffusionPipeline`` is assembled around the passed-in ``tokenizer``, ``unet`` and
    ``vae`` plus the unwrapped ``text_encoder``, switched to a DPM-Solver multistep scheduler,
    and run ``args_num_validation_images`` times on ``args_validation_prompt``. The images are
    logged to every tensorboard / wandb tracker registered on ``accelerator`` under the key
    "validation", using ``epoch`` as the tensorboard step.

    All settings are read from the notebook's module-level ``args_*`` variables; the ``args``
    parameter is accepted but not used by this function.
    """
    logger.info(
        f"Running validation... \n Generating {args_num_validation_images} images with prompt:"
        f" {args_validation_prompt}."
    )

    # Build the inference pipeline from the modules that are currently being trained.
    pipeline = DiffusionPipeline.from_pretrained(
        args_pretrained_model_name_or_path,
        text_encoder=accelerator.unwrap_model(text_encoder),
        tokenizer=tokenizer,
        unet=unet,
        vae=vae,
        revision=args_revision,
        torch_dtype=weight_dtype,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # A seeded generator is only created when a seed was configured.
    if args_seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args_seed)

    images = []
    for _ in range(args_num_validation_images):
        with torch.autocast("cuda"):
            result = pipeline(args_validation_prompt, num_inference_steps=25,
                              generator=generator)
        images.append(result.images[0])

    for tracker in accelerator.trackers:
        tracker_name = tracker.name
        if tracker_name == "tensorboard":
            # Stack the PIL images into one (N, H, W, C) array for tensorboard.
            stacked = np.stack([np.asarray(picture) for picture in images])
            tracker.writer.add_images("validation", stacked, epoch, dataformats="NHWC")
        if tracker_name == "wandb":
            captioned = []
            for index, picture in enumerate(images):
                captioned.append(
                    wandb.Image(picture, caption=f"{index}: {args_validation_prompt}")
                )
            tracker.log({"validation": captioned})

    # Drop the pipeline and release cached CUDA memory before training resumes.
    del pipeline
    torch.cuda.empty_cache()


def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
    """Write the placeholder token's learned embedding to ``save_path``.

    The row at ``placeholder_token_id`` is taken from the input-embedding matrix of the
    unwrapped ``text_encoder``, detached, moved to CPU and saved with ``torch.save`` as a
    one-entry dict keyed by the module-level ``args_placeholder_token``. The ``args``
    parameter is accepted but not used by this function.
    """
    logger.info("Saving embeddings")
    embedding_matrix = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    placeholder_embedding = embedding_matrix[placeholder_token_id].detach().cpu()
    torch.save({args_placeholder_token: placeholder_embedding}, save_path)

In [5]:
# Define common settings (shared by every dataset trained with this notebook)
args_revision = None  # Revision of pretrained model identifier from huggingface.co/models
args_tokenizer_name = None  # Pretrained tokenizer name or path if not the same as model_name
args_repeats = 100  # How many times to repeat the training data
args_center_crop = False  # Whether to center crop images before resizing to resolution
# Number of training epochs; only used to derive args_max_train_steps when that is None, and
# recomputed from args_max_train_steps in the setup cell.
args_num_train_epochs = 100
# Whether or not to use gradient checkpointing to save memory at the expense of a slower
# backward pass
args_gradient_checkpointing = False
args_dataloader_num_workers = 0  # num_workers passed to the training DataLoader
# AdamW hyper-parameters
args_adam_beta1 = 0.9
args_adam_beta2 = 0.999
args_adam_weight_decay = 1e-2
args_adam_epsilon = 1e-08
# Hugging Face Hub settings; the hub token and model id are only read when args_push_to_hub
# is True
args_push_to_hub = False
args_hub_token = None
args_hub_model_id = None
args_logging_dir = "logs"  # Joined onto args_output_dir to form the tracker logging directory
args_model_dir = "models"  # Where the full pipeline is saved when a full model save is done
args_mixed_precision = "no"  # choose from: ["no", "fp16", "bf16"]
args_allow_tf32 = False  # When True, sets torch.backends.cuda.matmul.allow_tf32
# choose from: ["tensorboard", "default", "wandb", "comet_ml", "all"]
args_report_to = "tensorboard"
args_validation_prompt = None  # Validation is skipped during training while this is None
args_num_validation_images = 4  # Images generated per validation run
args_validation_steps = 100  # Run validation every this many optimizer updates
# Deprecated: when not None, the setup cell overwrites args_validation_steps from it
args_validation_epochs = None
args_local_rank = -1  # Replaced by the LOCAL_RANK environment variable when that is set
args_checkpoints_total_limit = None  # Passed to ProjectConfiguration(total_limit=...)

In [6]:
# Define special settings for different datasets
# A learned_embeds-steps-<N>.bin file is written every args_save_steps optimizer updates
args_save_steps = 500
args_only_save_embeds = True  # Save only the embeddings for the new concept
# Path to pretrained model or model identifier from huggingface.co/models
args_pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
args_train_data_dir = "datasets/colorful_teapot"  # A folder containing the training data
args_placeholder_token = "<colorful_teapot>"  # A token to use as a placeholder for the concept
args_initializer_token = "teapot"  # A token to use as initializer word
args_learnable_property = "object"  # choose between "object" and "style"
# The output directory where the model predictions and checkpoints will be written
args_output_dir = "outputs/colorful_teapot"
args_seed = None  # A seed for reproducible training
# The resolution for input images, all the images in the train/validation dataset will be
# resized to this resolution
args_resolution = 512
args_train_batch_size = 1  # Batch size (per device) for the training dataloader
# Total number of training steps to perform.  If provided, overrides num_train_epochs
args_max_train_steps = 3000
# Number of updates steps to accumulate before performing a backward/update pass
args_gradient_accumulation_steps = 4
args_learning_rate = 5.0e-04
# When True, the setup cell multiplies args_learning_rate by the gradient accumulation steps,
# the per-device batch size and the number of processes
args_scale_lr = True
# choose from: ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant",
# "constant_with_warmup"]
args_lr_scheduler = "constant"
args_lr_warmup_steps = 0
args_checkpointing_steps = 500  # accelerator.save_state is called every this many updates
# Checkpoint directory name to resume from, or "latest" for the highest-numbered one
args_resume_from_checkpoint = None
args_enable_xformers_memory_efficient_attention = False

In [7]:
# Let a LOCAL_RANK environment variable, when present, take precedence over the configured
# local rank.
env_local_rank = int(os.getenv("LOCAL_RANK", -1))
if env_local_rank not in (-1, args_local_rank):
    args_local_rank = env_local_rank

In [8]:
# Caption templates used by TextualInversionDataset when learnable_property is not "style".
# The "{}" slot is filled with the placeholder token.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Caption templates used by TextualInversionDataset when learnable_property == "style".
# The "{}" slot is filled with the placeholder token.
# NOTE(review): "a cropped painting ..." and "a close-up painting ..." each appear twice, so
# random.choice picks them twice as often — confirm whether that weighting is intended.
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

In [9]:
class TextualInversionDataset(Dataset):
    """Dataset pairing each training image with a randomly templated caption.

    Every item is a dict with:
      * ``"input_ids"``: the tokenized caption, padded/truncated to the tokenizer's
        ``model_max_length``;
      * ``"pixel_values"``: the image resized to ``size`` x ``size``, randomly flipped
        horizontally, rescaled from [0, 255] to [-1, 1] and laid out channels-first.

    Args:
        data_root: Directory whose regular, non-hidden files are used as training images.
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        learnable_property: ``"style"`` selects the style caption templates; any other value
            selects the object templates.
        size: Side length, in pixels, of the square output image.
        repeats: How many times the image list is repeated when ``set == "train"``.
        interpolation: One of "linear", "bilinear", "bicubic", "lanczos".
        flip_p: Probability of the random horizontal flip.
        set: Split name; only "train" applies ``repeats``.
        placeholder_token: Text inserted into the caption templates.
        center_crop: Whether to crop the largest centered square before resizing.

    Raises:
        KeyError: If ``interpolation`` is not one of the supported names.
    """

    def __init__(
            self,
            data_root,
            tokenizer,
            learnable_property="object",  # [object, style]
            size=512,
            repeats=100,
            interpolation="bicubic",
            flip_p=0.5,
            set="train",
            placeholder_token="*",
            center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = self._list_image_paths(self.data_root)

        self.num_images = len(self.image_paths)
        # Only the training split is artificially lengthened by repeating the images.
        self._length = self.num_images * repeats if set == "train" else self.num_images

        self.interpolation = {
            name: PIL_INTERPOLATION[name]
            for name in ("linear", "bilinear", "bicubic", "lanczos")
        }[interpolation]

        self.templates = imagenet_style_templates_small \
            if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    @staticmethod
    def _list_image_paths(data_root):
        """Return the sorted paths of the regular, non-hidden files directly in ``data_root``.

        Sub-directories and dot-entries (for example the ``.ipynb_checkpoints`` folder Jupyter
        creates, or ``.DS_Store``) are skipped because they cannot be opened as images.
        Sorting makes the index -> image mapping independent of the file system's listing
        order.
        """
        image_paths = []
        for file_name in sorted(os.listdir(data_root)):
            file_path = os.path.join(data_root, file_name)
            if file_name.startswith(".") or not os.path.isfile(file_path):
                continue
            image_paths.append(file_path)
        return image_paths

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}

        # Indices wrap around the image list so that `repeats` can extend the dataset.
        # The context manager closes the underlying file as soon as the pixels are copied.
        with Image.open(self.image_paths[i % self.num_images]) as opened:
            image = opened.convert("RGB")

        text = random.choice(self.templates).format(self.placeholder_token)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            # Keep the largest centered square.
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[(h - crop) // 2: (h + crop) // 2, (w - crop) // 2: (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Rescale [0, 255] -> [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example

In [10]:
def get_full_repo_name(model_id: str, organization: Optional[str] = None,
                       token: Optional[str] = None):
    """Return the Hub repository id ``"<namespace>/<model_id>"``.

    Args:
        model_id: Repository name without namespace.
        organization: Namespace to use. When ``None``, the name of the user that ``token``
            belongs to (as reported by ``whoami``) is used instead.
        token: Hub token used for the ``whoami`` lookup. When ``None`` and a lookup is
            needed, the token returned by ``HfFolder.get_token()`` is used.

    Returns:
        The namespaced repository id as a string.
    """
    # With an explicit organization nothing has to be looked up, so no token is needed.
    if organization is not None:
        return f"{organization}/{model_id}"

    if token is None:
        token = HfFolder.get_token()
    username = whoami(token)["name"]
    return f"{username}/{model_id}"

In [14]:
# Tracker logs are written to a sub-folder of the output directory.
logging_dir = os.path.join(args_output_dir, args_logging_dir)

accelerator_project_config = ProjectConfiguration(total_limit=args_checkpoints_total_limit)
accelerator = Accelerator(
    gradient_accumulation_steps=args_gradient_accumulation_steps,
    mixed_precision=args_mixed_precision,
    log_with=args_report_to,
    logging_dir=logging_dir,
    project_config=accelerator_project_config,
)

# Fail early when wandb reporting was requested but the package is missing.
if args_report_to == "wandb" and not is_wandb_available():
    raise ImportError(
        "Make sure to install wandb if you want to use it for logging during training.")

# Every process emits one log line with its accelerator state, for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)

# Only the local main process keeps the libraries' warning/info output; the others are
# restricted to errors.
if not accelerator.is_local_main_process:
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()
else:
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()

# Seed the run only when a seed was configured.
if args_seed is not None:
    set_seed(args_seed)

03/26/2023 11:41:01 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no



In [15]:
# Handle the repository creation (main process only)
if accelerator.is_main_process:
    if args_push_to_hub:
        # Default the repo name to "<user>/<output folder name>" when no id was configured.
        if args_hub_model_id is None:
            repo_name = get_full_repo_name(Path(args_output_dir).name, token=args_hub_token)
        else:
            repo_name = args_hub_model_id
        create_repo(repo_name, exist_ok=True, token=args_hub_token)
        repo = Repository(args_output_dir, clone_from=repo_name, token=args_hub_token)

        # Make sure intermediate step_*/epoch_* artifacts are git-ignored. The file is opened
        # in "a+" mode so that an existing .gitignore from the cloned repo is preserved, and
        # the patterns are compared against its current lines so they are added only once.
        with open(os.path.join(args_output_dir, ".gitignore"), "a+") as gitignore:
            gitignore.seek(0)
            existing_text = gitignore.read()
            existing_patterns = existing_text.splitlines()
            # Avoid gluing the first new pattern onto an unterminated last line.
            if existing_text and not existing_text.endswith("\n"):
                gitignore.write("\n")
            for pattern in ("step_*", "epoch_*"):
                if pattern not in existing_patterns:
                    gitignore.write(f"{pattern}\n")
    elif args_output_dir is not None:
        os.makedirs(args_output_dir, exist_ok=True)

# Load tokenizer (an explicitly configured tokenizer takes precedence over the model's own)
if args_tokenizer_name:
    tokenizer = CLIPTokenizer.from_pretrained(args_tokenizer_name)
elif args_pretrained_model_name_or_path:
    tokenizer = CLIPTokenizer.from_pretrained(args_pretrained_model_name_or_path,
                                              subfolder="tokenizer")

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args_pretrained_model_name_or_path,
                                                subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(
    args_pretrained_model_name_or_path, subfolder="text_encoder", revision=args_revision
)
vae = AutoencoderKL.from_pretrained(args_pretrained_model_name_or_path, subfolder="vae",
                                    revision=args_revision)
unet = UNet2DConditionModel.from_pretrained(
    args_pretrained_model_name_or_path, subfolder="unet", revision=args_revision
)

# Add the placeholder token in tokenizer
num_added_tokens = tokenizer.add_tokens(args_placeholder_token)
if num_added_tokens == 0:
    # Adjacent literals (instead of backslash continuations inside the string) keep the
    # source indentation out of the message.
    raise ValueError(
        f"The tokenizer already contains the token {args_placeholder_token}. Please pass a"
        " different `placeholder_token` that is not already in the tokenizer."
    )

# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer.encode(args_initializer_token, add_special_tokens=False)
# The initializer must map to exactly one token id: several ids are ambiguous, and an empty
# result would otherwise fail below with an unhelpful IndexError.
if len(token_ids) != 1:
    raise ValueError("The initializer token must be a single token.")

initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(args_placeholder_token)

# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))

# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

{'clip_sample_range', 'thresholding', 'variance_type', 'prediction_type', 'dynamic_thresholding_ratio', 'sample_max_value'} was not found in config. Values will be initialized to default values.
{'scaling_factor'} was not found in config. Values will be initialized to default values.
{'upcast_attention', 'time_cond_proj_dim', 'class_embeddings_concat', 'use_linear_projection', 'class_embed_type', 'num_class_embeds', 'time_embedding_type', 'dual_cross_attention', 'conv_in_kernel', 'conv_out_kernel', 'timestep_post_act', 'mid_block_type', 'only_cross_attention', 'projection_class_embeddings_input_dim', 'resnet_time_scale_shift'} was not found in config. Values will be initialized to default values.


In [16]:
# Freeze vae and unet
vae.requires_grad_(False)
unet.requires_grad_(False)
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

if args_gradient_checkpointing:
    # Keep unet in train mode if we are using gradient checkpointing to save memory.
    # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
    unet.train()
    text_encoder.gradient_checkpointing_enable()
    unet.enable_gradient_checkpointing()

if args_enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            # `warning` replaces the deprecated `warn` alias; adjacent literals keep the
            # source indentation out of the logged message.
            logger.warning(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe"
                " problems during training, please update xFormers to at least 0.0.17. See"
                " https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more"
                " details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args_allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

# Scale the learning rate with the effective batch size.
# NOTE: this overwrites args_learning_rate in place, so re-running this cell without first
# re-running the settings cell scales the value again.
if args_scale_lr:
    args_learning_rate = (
            args_learning_rate * args_gradient_accumulation_steps \
            * args_train_batch_size * accelerator.num_processes
    )

# Initialize the optimizer
optimizer = torch.optim.AdamW(
    text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
    lr=args_learning_rate,
    betas=(args_adam_beta1, args_adam_beta2),
    weight_decay=args_adam_weight_decay,
    eps=args_adam_epsilon,
)

# Dataset and DataLoaders creation:
train_dataset = TextualInversionDataset(
    data_root=args_train_data_dir,
    tokenizer=tokenizer,
    size=args_resolution,
    placeholder_token=args_placeholder_token,
    repeats=args_repeats,
    learnable_property=args_learnable_property,
    center_crop=args_center_crop,
    set="train",
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args_train_batch_size, shuffle=True,
    num_workers=args_dataloader_num_workers
)

# `validation_epochs` is deprecated: when it is set, convert it to a step count and warn.
if args_validation_epochs is not None:
    # Adjacent literals keep the source indentation out of the message and separate the
    # sentences properly.
    warnings.warn(
        f"FutureWarning: You are doing logging with validation_epochs={args_validation_epochs}."
        " Deprecated validation_epochs in favor of `validation_steps`."
        f" Setting `args_validation_steps` to {args_validation_epochs * len(train_dataset)}",
        FutureWarning,
        stacklevel=2,
    )
    args_validation_steps = args_validation_epochs * len(train_dataset)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args_gradient_accumulation_steps)
if args_max_train_steps is None:
    # No explicit step budget: derive it from the epoch count and remember that we did so.
    args_max_train_steps = args_num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

lr_scheduler = get_scheduler(
    args_lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args_lr_warmup_steps * args_gradient_accumulation_steps,
    num_training_steps=args_max_train_steps * args_gradient_accumulation_steps,
)

# Prepare everything with our `accelerator`.
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    text_encoder, optimizer, train_dataloader, lr_scheduler
)

# For mixed precision training we cast the unet and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Move vae and unet to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)

# We need to recalculate our total training steps as the size of the training dataloader may
# have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args_gradient_accumulation_steps)
if overrode_max_train_steps:
    args_max_train_steps = args_num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args_num_train_epochs = math.ceil(args_max_train_steps / num_update_steps_per_epoch)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
    # This notebook keeps its settings in module-level `args_*` variables rather than in an
    # argparse namespace, so the tracker configuration is collected from those variables
    # (with the "args_" prefix stripped). A snapshot of globals() is iterated so the mapping
    # cannot change while it is being read.
    tracker_config = {
        name[len("args_"):]: value
        for name, value in list(globals().items())
        if name.startswith("args_")
    }
    accelerator.init_trackers("textual_inversion", config=tracker_config)

In [11]:
# Train!
total_batch_size = (
    args_train_batch_size * accelerator.num_processes * args_gradient_accumulation_steps
)

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num Epochs = {args_num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {args_train_batch_size}")
# Adjacent literals keep the source indentation out of the logged message.
logger.info(
    "  Total train batch size (w. parallel, distributed & accumulation) ="
    f" {total_batch_size}"
)
logger.info(f"  Gradient Accumulation steps = {args_gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {args_max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args_resume_from_checkpoint:
    if args_resume_from_checkpoint != "latest":
        path = os.path.basename(args_resume_from_checkpoint)
    else:
        # Get the most recent checkpoint ("checkpoint-<step>" with the highest step)
        dirs = os.listdir(args_output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(
            f"Checkpoint '{args_resume_from_checkpoint}' does not exist."
            " Starting a new training run."
        )
        # Clearing the setting makes the training loop skip its resume fast-forward.
        args_resume_from_checkpoint = None
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args_output_dir, path))
        # The optimizer-step count is encoded in the checkpoint folder name.
        global_step = int(path.split("-")[1])

        # Translate optimizer steps back into dataloader steps to find where to resume.
        resume_global_step = global_step * args_gradient_accumulation_steps
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % (num_update_steps_per_epoch \
                                            * args_gradient_accumulation_steps)

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args_max_train_steps),
                    disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")

# keep original embeddings as reference
orig_embeds_params = accelerator.unwrap_model(
    text_encoder).get_input_embeddings().weight.data.clone()

# Boolean mask over the vocabulary that is True for every token except the placeholder. It is
# the same on every step, so it is built once here rather than inside the loop.
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id

for epoch in range(first_epoch, args_num_train_epochs):
    text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        # Skip steps until we reach the resumed step
        if args_resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            if step % args_gradient_accumulation_steps == 0:
                progress_bar.update(1)
            continue

        with accelerator.accumulate(text_encoder):
            # Convert images to latent space
            latents = vae.encode(batch["pixel_values"].to(
                dtype=weight_dtype)).latent_dist.sample().detach()
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,),
                                      device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

            # Predict the noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(
                    f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            accelerator.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Let's make sure we don't update any embedding weights besides the newly added
            # token: every other row is restored from the reference copy.
            with torch.no_grad():
                accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                    index_no_updates
                ] = orig_embeds_params[index_no_updates]

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            # The helpers below read their settings from the module-level `args_*` variables
            # and ignore their `args` parameter; this notebook defines no `args` object, so
            # None is passed in that slot.
            if global_step % args_save_steps == 0:
                save_path = os.path.join(args_output_dir,
                                         f"learned_embeds-steps-{global_step}.bin")
                save_progress(text_encoder, placeholder_token_id, accelerator, None, save_path)

            if global_step % args_checkpointing_steps == 0:
                if accelerator.is_main_process:
                    save_path = os.path.join(args_output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")
            if args_validation_prompt is not None and global_step % args_validation_steps == 0:
                log_validation(text_encoder, tokenizer, unet, vae, None, accelerator,
                               weight_dtype, epoch)

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)

        # Leave the batch loop once the step budget is used up.
        if global_step >= args_max_train_steps:
            break
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    # Pushing to the Hub always saves the full model, even if only embeddings were requested.
    if args_push_to_hub and args_only_save_embeds:
        logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
        save_full_model = True
    else:
        save_full_model = not args_only_save_embeds
    if save_full_model:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args_pretrained_model_name_or_path,
            text_encoder=accelerator.unwrap_model(text_encoder),
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
        )
        pipeline.save_pretrained(args_model_dir)
    # Save the newly trained embeddings. `save_progress` reads its settings from the
    # module-level `args_*` variables and ignores its `args` parameter; this notebook defines
    # no `args` object, so None is passed in that slot.
    save_path = os.path.join(args_output_dir, "learned_embeds.bin")
    save_progress(text_encoder, placeholder_token_id, accelerator, None, save_path)

    if args_push_to_hub:
        repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

accelerator.end_training()