In [None]:
import logging
import math
import os
import random
import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, List

import numpy as np
import PIL
import safetensors
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder

from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers import UNet2DConditionModel

from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available

if is_wandb_available():
    import wandb

In [None]:
# Authenticate with Weights & Biases without embedding the secret in the notebook.
# The API key is read from the WANDB_API_KEY environment variable; when it is unset,
# `key=None` lets wandb fall back to its own credential lookup / interactive prompt.
# SECURITY: an API key used to be hardcoded in this cell. Treat that key as
# compromised and revoke it in the W&B account settings.
# `wandb` is only imported when it is installed, so guard the call the same way.
if is_wandb_available():
    wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Map interpolation names to Pillow resampling filters.
# Pillow >= 9.1.0 is addressed through `PIL.Image.Resampling`; older releases
# are addressed through the constants that live directly on `PIL.Image`.
if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
    PIL_INTERPOLATION = dict(
        linear=PIL.Image.LINEAR,
        bilinear=PIL.Image.BILINEAR,
        bicubic=PIL.Image.BICUBIC,
        lanczos=PIL.Image.LANCZOS,
        nearest=PIL.Image.NEAREST,
    )
else:
    PIL_INTERPOLATION = dict(
        # "linear" is an alias of bilinear on this branch.
        linear=PIL.Image.Resampling.BILINEAR,
        bilinear=PIL.Image.Resampling.BILINEAR,
        bicubic=PIL.Image.Resampling.BICUBIC,
        lanczos=PIL.Image.Resampling.LANCZOS,
        nearest=PIL.Image.Resampling.NEAREST,
    )

# Check the installed diffusers version against the minimum this notebook targets.
check_min_version("0.36.0.dev0")

# Module-level logger shared by the helper functions and the trainer below.
logger = get_logger(__name__)


In [None]:
@dataclass
class TrainingConfig:
    """Configuration for textual inversion training.

    A plain container of settings that `TextualInversionTrainer` and the helper
    functions in this notebook read through attribute access. Every field has a
    default, so `TrainingConfig()` is constructible without arguments.
    """
    # Model settings
    pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"  # passed to every `from_pretrained` call
    revision: Optional[str] = None
    variant: Optional[str] = None
    tokenizer_name: Optional[str] = None  # when set, the tokenizer is loaded from here instead of the base model

    # Data settings
    train_data_dir: str = "./data/train"  # folder scanned for images and same-named .txt captions
    placeholder_token: str = "<my-token>"  # new token added to the tokenizer; must not already exist in it
    initializer_token: str = "a"  # must encode to a single token; its embedding seeds the placeholder embedding(s)
    learnable_property: str = "object"  # "object" or "style"
    repeats: int = 100  # the training dataset length is (number of images) * repeats
    resolution: int = 512  # images are resized to resolution x resolution
    center_crop: bool = False  # crop the central square before resizing

    # Training settings
    train_batch_size: int = 4  # per-device batch size
    num_train_epochs: int = 100  # recomputed from max_train_steps during setup
    max_train_steps: Optional[int] = 5000  # None -> derived from num_train_epochs during setup
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = False
    learning_rate: float = 5e-4
    scale_lr: bool = False  # multiply learning_rate by accumulation steps * batch size * process count
    lr_scheduler: str = "constant"  # scheduler name handed to diffusers' `get_scheduler`
    lr_warmup_steps: int = 500
    lr_num_cycles: int = 1
    dataloader_num_workers: int = 0

    # Optimizer settings
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    adam_epsilon: float = 1e-08

    # Saving settings
    output_dir: str = "text-inversion-model"
    save_steps: int = 500  # write the learned embeddings every N optimisation steps
    save_as_full_pipeline: bool = False  # also save a complete StableDiffusionPipeline at the end
    checkpointing_steps: int = 500  # write an accelerator state checkpoint every N optimisation steps
    checkpoints_total_limit: Optional[int] = None  # when set, oldest checkpoint-* folders are deleted to stay within the limit
    resume_from_checkpoint: Optional[str] = None  # "latest" or a checkpoint folder name/path located under output_dir
    no_safe_serialization: bool = False  # True -> save embeddings as .bin via torch.save instead of .safetensors

    # Validation settings
    validation_prompt: Optional[str] = None  # None disables validation image generation
    num_validation_images: int = 4
    validation_steps: int = 100  # run validation every N optimisation steps

    # Logging settings
    logging_dir: str = "logs"  # joined onto output_dir
    report_to: str = "wandb"  # forwarded to `Accelerator(log_with=...)`
    wandb_project_name: str = "textual-inversion"  # project name given to `init_trackers`

    # Other settings
    seed: Optional[int] = 42  # None skips seeding
    mixed_precision: str = "no"  # "no", "fp16", "bf16"
    allow_tf32: bool = False
    enable_xformers_memory_efficient_attention: bool = False
    num_vectors: int = 1  # number of placeholder tokens (embedding vectors) to learn; must be >= 1

    # Hub settings
    push_to_hub: bool = False
    hub_token: Optional[str] = None  # secret; do not commit a real value
    hub_model_id: Optional[str] = None  # defaults to the output_dir folder name when None

In [None]:
class TextualInversionDataset(Dataset):
    """Image/prompt dataset for textual inversion training.

    Every image file found directly inside `data_root` becomes one sample. If a
    text file with the same stem and a `.txt` extension exists next to an image,
    its stripped content is appended to the prompt as a caption.

    Each item is a dict with:
        - "input_ids": token ids of the prompt, padded/truncated to
          `tokenizer.model_max_length`.
        - "pixel_values": float32 tensor in channel-first layout, scaled to
          [-1, 1], of spatial size `size` x `size`.

    Args:
        data_root: Directory containing the training images (and optional captions).
        tokenizer: Tokenizer used to encode the prompt.
        size: Side length the images are resized to.
        repeats: When `set == "train"`, the dataset length is
            `number_of_images * repeats`.
        interpolation: Key of `PIL_INTERPOLATION` selecting the resize filter.
        flip_p: Probability of a random horizontal flip.
        set: "train" enables the `repeats` multiplier; any other value does not.
        placeholder_token: Text inserted into the prompt template.
        learnable_property: "object" uses "A photo of ..."; anything else uses
            "A photo in the style of ...".
        center_crop: Crop the central square of the image before resizing.

    Raises:
        ValueError: If `data_root` contains no supported image file.
        KeyError: If `interpolation` is not a key of `PIL_INTERPOLATION`.
    """

    # File extensions (lower-case) that are treated as training images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")

    def __init__(
        self,
        data_root,
        tokenizer,
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        learnable_property="object",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.learnable_property = learnable_property
        self.center_crop = center_crop
        self.flip_p = flip_p

        # Load images and captions
        self.image_paths = []
        self.captions = []

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> image mapping stable so seeded runs are reproducible.
        for file_name in sorted(os.listdir(self.data_root)):
            if not file_name.lower().endswith(self.IMAGE_EXTENSIONS):
                continue

            img_path = os.path.join(self.data_root, file_name)
            self.image_paths.append(img_path)

            # Look for caption file (same name with .txt extension)
            caption_path = os.path.splitext(img_path)[0] + ".txt"
            if os.path.exists(caption_path):
                # Explicit encoding so captions decode the same on every platform.
                with open(caption_path, "r", encoding="utf-8") as f:
                    caption = f.read().strip()
            else:
                caption = ""  # Default empty caption
            self.captions.append(caption)

        self.num_images = len(self.image_paths)
        if self.num_images == 0:
            # Fail here with a clear message instead of a ZeroDivisionError
            # later in __getitem__ (`i % self.num_images`).
            raise ValueError(
                f"No image files {self.IMAGE_EXTENSIONS} found in '{self.data_root}'."
            )

        self._length = self.num_images
        if set == "train":
            self._length = self.num_images * repeats

        # Use the shared table directly so every filter it defines
        # (including "nearest") is accepted.
        self.interpolation = PIL_INTERPOLATION[interpolation]

        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        """Return the number of samples (images times repeats for the train set)."""
        return self._length

    def __getitem__(self, i):
        """Build the tokenized prompt and the preprocessed image for sample `i`."""
        example = {}
        idx = i % self.num_images

        # The context manager closes the file handle; convert() returns a fully
        # loaded RGB image that no longer depends on the open file.
        with Image.open(self.image_paths[idx]) as opened:
            image = opened.convert("RGB")

        # Use template with caption
        caption = self.captions[idx]
        if self.learnable_property == "object":
            text = f"A photo of {self.placeholder_token}"
        else:
            text = f"A photo in the style of {self.placeholder_token}"
        if caption:
            text = f"{text}, {caption}"

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example

In [None]:

def save_model_card(repo_id: str, images: list = None, base_model: str = None, repo_folder: str = None):
    """Write example images and a README.md model card into `repo_folder`.

    Args:
        repo_id: Repository id shown in the card title and used to load/create the card.
        images: Optional images; each one is saved as `image_<i>.png` in
            `repo_folder` and embedded in the card.
        base_model: Name of the base model the embeddings were trained for.
        repo_folder: Directory receiving the images and the README.md.
    """
    gallery_entries = []
    for idx, picture in enumerate(images if images is not None else ()):
        picture.save(os.path.join(repo_folder, f"image_{idx}.png"))
        gallery_entries.append(f"![img_{idx}](./image_{idx}.png)\n")
    img_str = "".join(gallery_entries)

    model_description = f"""
# Textual inversion text2image fine-tuning - {repo_id}
These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n
{img_str}
"""

    card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="creativeml-openrail-m",
        base_model=base_model,
        model_description=model_description,
        inference=True,
    )
    card = populate_model_card(
        card,
        tags=[
            "stable-diffusion",
            "stable-diffusion-diffusers",
            "text-to-image",
            "diffusers",
            "textual_inversion",
            "diffusers-training",
        ],
    )

    readme_path = os.path.join(repo_folder, "README.md")
    card.save(readme_path)


def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
    """Generate validation images with the current embeddings and log them.

    Args:
        text_encoder: Text encoder being trained (unwrapped before use).
        tokenizer: Tokenizer that contains the placeholder token(s).
        unet: UNet handed to the pipeline as-is.
        vae: VAE handed to the pipeline as-is.
        args: Config providing `validation_prompt`, `num_validation_images`,
            `seed`, `pretrained_model_name_or_path`, `revision` and `variant`.
        accelerator: Accelerator supplying the device and the active trackers.
        weight_dtype: dtype passed to the pipeline as `torch_dtype`.
        epoch: Value used as the step when logging to TensorBoard.

    Returns:
        The list of generated images.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    # Assemble an inference pipeline around the components passed in by the caller.
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=accelerator.unwrap_model(text_encoder),
        tokenizer=tokenizer,
        unet=unet,
        vae=vae,
        safety_checker=None,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Seeded generator when a seed is configured, otherwise let the pipeline pick.
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    images = []
    for _ in range(args.num_validation_images):
        # Autocast is skipped whenever the MPS backend is available.
        skip_autocast = torch.backends.mps.is_available()
        autocast_ctx = nullcontext() if skip_autocast else torch.autocast(accelerator.device.type)

        with autocast_ctx:
            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
        images.append(image)

    # Send the images to whichever supported trackers are active.
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            stacked = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("validation", stacked, epoch, dataformats="NHWC")
        elif tracker.name == "wandb":
            captioned = [
                wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
            ]
            tracker.log({"validation": captioned})

    # Drop the pipeline and release cached GPU memory before returning to training.
    del pipeline
    torch.cuda.empty_cache()
    return images


def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
    """Save the learned placeholder embedding rows to `save_path`.

    The rows from the smallest to the largest placeholder token id (inclusive)
    are taken from the text encoder's input embedding matrix and stored under
    the key `args.placeholder_token`.

    Args:
        text_encoder: Text encoder holding the embeddings (unwrapped before use).
        placeholder_token_ids: Token ids of the placeholder token(s).
        accelerator: Accelerator used to unwrap the model.
        args: Config providing `placeholder_token`.
        save_path: Destination file path.
        safe_serialization: True writes a safetensors file, False uses `torch.save`.
    """
    logger.info("Saving embeddings")

    embedding_matrix = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    first_id = min(placeholder_token_ids)
    last_id = max(placeholder_token_ids)
    learned_rows = embedding_matrix[first_id : last_id + 1]
    payload = {args.placeholder_token: learned_rows.detach().cpu()}

    if not safe_serialization:
        torch.save(payload, save_path)
        return

    safetensors.torch.save_file(payload, save_path, metadata={"format": "pt"})

In [None]:
class TextualInversionTrainer:
    """Learns new token embedding(s) for a Stable Diffusion text encoder.

    Only the text encoder's token embedding matrix is optimised; after every
    optimiser step all rows except the placeholder rows are restored to their
    original values, so effectively only the placeholder embeddings change.
    The VAE, the UNet and the rest of the text encoder are frozen.

    Typical use: `TextualInversionTrainer(config).train()`. `train` calls
    `setup` itself.
    """

    def __init__(self, config: TrainingConfig):
        """Store the configuration; all heavy objects are created by `setup`."""
        self.config = config
        self.accelerator = None
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.noise_scheduler = None
        self.optimizer = None
        self.lr_scheduler = None
        self.train_dataloader = None
        self.placeholder_token_ids = None
        # Set by `setup`; declared here so the attributes always exist.
        self.weight_dtype = None
        self.repo_id = None

    def setup(self):
        """Initialize all components.

        Creates the accelerator, tokenizer, models, optimizer, dataset,
        dataloader and LR scheduler, registers the placeholder token(s) and
        starts the experiment trackers. May update `learning_rate`,
        `max_train_steps` and `num_train_epochs` on the config.

        Returns:
            Number of optimiser update steps per epoch.

        Raises:
            ValueError: If `num_vectors < 1`, the placeholder token already
                exists in the tokenizer, the initializer token does not encode
                to exactly one token, or xformers is requested but unavailable.
        """
        args = self.config

        # Setup logging
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO,
        )

        # Setup accelerator
        logging_dir = os.path.join(args.output_dir, args.logging_dir)
        accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
        self.accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision=args.mixed_precision,
            log_with=args.report_to,
            project_config=accelerator_project_config,
        )

        # Disable AMP for MPS.
        if torch.backends.mps.is_available():
            self.accelerator.native_amp = False

        logger.info(self.accelerator.state, main_process_only=False)
        if self.accelerator.is_local_main_process:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        # Set seed
        if args.seed is not None:
            set_seed(args.seed)

        # Create output directory
        if self.accelerator.is_main_process:
            if args.output_dir is not None:
                os.makedirs(args.output_dir, exist_ok=True)

            if args.push_to_hub:
                self.repo_id = create_repo(
                    repo_id=args.hub_model_id or Path(args.output_dir).name,
                    exist_ok=True,
                    token=args.hub_token,
                ).repo_id

        # Load tokenizer
        if args.tokenizer_name:
            self.tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
        elif args.pretrained_model_name_or_path:
            self.tokenizer = CLIPTokenizer.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="tokenizer"
            )

        # Load models
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="scheduler"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
        )
        self.vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
        )

        # Add placeholder tokens
        placeholder_tokens = [args.placeholder_token]
        if args.num_vectors < 1:
            raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")

        # Extra vectors get derived token names: "<token>_1", "<token>_2", ...
        additional_tokens = []
        for i in range(1, args.num_vectors):
            additional_tokens.append(f"{args.placeholder_token}_{i}")
        placeholder_tokens += additional_tokens

        num_added_tokens = self.tokenizer.add_tokens(placeholder_tokens)
        if num_added_tokens != args.num_vectors:
            raise ValueError(
                f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        # Get initializer token ID and initialize placeholder embeddings
        token_ids = self.tokenizer.encode(args.initializer_token, add_special_tokens=False)
        # `!= 1` also rejects an initializer that encodes to nothing, which would
        # otherwise surface as an IndexError on `token_ids[0]`.
        if len(token_ids) != 1:
            raise ValueError("The initializer token must be a single token.")

        initializer_token_id = token_ids[0]
        self.placeholder_token_ids = self.tokenizer.convert_tokens_to_ids(placeholder_tokens)

        # Resize token embeddings
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))

        # Initialize placeholder token embeddings with initializer token embeddings
        token_embeds = self.text_encoder.get_input_embeddings().weight.data
        with torch.no_grad():
            for token_id in self.placeholder_token_ids:
                token_embeds[token_id] = token_embeds[initializer_token_id].clone()

        # Freeze models
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.text_encoder.text_model.encoder.requires_grad_(False)
        self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
        self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

        if args.gradient_checkpointing:
            self.unet.train()
            self.text_encoder.gradient_checkpointing_enable()
            self.unet.enable_gradient_checkpointing()

        if args.enable_xformers_memory_efficient_attention:
            if is_xformers_available():
                import xformers

                xformers_version = version.parse(xformers.__version__)
                if xformers_version == version.parse("0.0.16"):
                    logger.warning(
                        "xFormers 0.0.16 cannot be used for training in some GPUs. "
                        "If you observe problems during training, please update xFormers to at least 0.0.17."
                    )
                self.unet.enable_xformers_memory_efficient_attention()
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        if args.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        if args.scale_lr:
            args.learning_rate = (
                args.learning_rate
                * args.gradient_accumulation_steps
                * args.train_batch_size
                * self.accelerator.num_processes
            )

        # Initialize optimizer (only the token embedding parameters are optimised)
        self.optimizer = torch.optim.AdamW(
            self.text_encoder.get_input_embeddings().parameters(),
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )

        # Create dataset
        train_dataset = TextualInversionDataset(
            data_root=args.train_data_dir,
            tokenizer=self.tokenizer,
            size=args.resolution,
            placeholder_token=(" ".join(self.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))),
            repeats=args.repeats,
            # Forward the configured property so "style" selects the style prompt
            # template instead of silently falling back to the "object" default.
            learnable_property=args.learnable_property,
            center_crop=args.center_crop,
            set="train",
        )

        self.train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,
            shuffle=True,
            num_workers=args.dataloader_num_workers,
        )

        # Calculate training steps
        overrode_max_train_steps = False
        num_update_steps_per_epoch = math.ceil(
            len(self.train_dataloader) / args.gradient_accumulation_steps
        )
        if args.max_train_steps is None:
            args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
            overrode_max_train_steps = True

        # Create scheduler
        self.lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=args.lr_warmup_steps * self.accelerator.num_processes,
            num_training_steps=args.max_train_steps * self.accelerator.num_processes,
            num_cycles=args.lr_num_cycles,
        )

        # Prepare with accelerator
        self.text_encoder.train()
        (
            self.text_encoder,
            self.optimizer,
            self.train_dataloader,
            self.lr_scheduler,
        ) = self.accelerator.prepare(
            self.text_encoder, self.optimizer, self.train_dataloader, self.lr_scheduler
        )

        # Set weight dtype for the frozen models according to the mixed-precision mode
        weight_dtype = torch.float32
        if self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        self.weight_dtype = weight_dtype
        self.unet.to(self.accelerator.device, dtype=weight_dtype)
        self.vae.to(self.accelerator.device, dtype=weight_dtype)

        # Recalculate training steps (the dataloader length may change after `prepare`)
        num_update_steps_per_epoch = math.ceil(
            len(self.train_dataloader) / args.gradient_accumulation_steps
        )
        if overrode_max_train_steps:
            args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

        # Initialize trackers
        if self.accelerator.is_main_process:
            # The config is sent to the experiment tracker; keep the Hub
            # credential out of it so the secret is never logged.
            tracker_config = {key: value for key, value in vars(args).items() if key != "hub_token"}
            self.accelerator.init_trackers(args.wandb_project_name, config=tracker_config)

        return num_update_steps_per_epoch

    def train(self):
        """Main training loop.

        Runs `setup`, optionally resumes from a checkpoint, trains until
        `max_train_steps` optimisation steps are reached, periodically saves
        embeddings / checkpoints and runs validation, then writes the final
        embeddings (and optionally the full pipeline and a Hub upload).

        Raises:
            ValueError: If the noise scheduler uses an unsupported prediction type
                (plus anything raised by `setup`).
        """
        args = self.config
        num_update_steps_per_epoch = self.setup()

        total_batch_size = (
            args.train_batch_size
            * self.accelerator.num_processes
            * args.gradient_accumulation_steps
        )

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(self.train_dataloader.dataset)}")
        logger.info(f"  Num Epochs = {args.num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {args.max_train_steps}")

        global_step = 0
        first_epoch = 0
        initial_global_step = 0

        # Resume from checkpoint if needed
        if args.resume_from_checkpoint:
            if args.resume_from_checkpoint != "latest":
                path = os.path.basename(args.resume_from_checkpoint)
            else:
                # Pick the checkpoint-<step> folder with the highest step number.
                dirs = os.listdir(args.output_dir)
                dirs = [d for d in dirs if d.startswith("checkpoint")]
                dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
                path = dirs[-1] if len(dirs) > 0 else None

            if path is None:
                self.accelerator.print(
                    f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                )
                args.resume_from_checkpoint = None
                initial_global_step = 0
            else:
                self.accelerator.print(f"Resuming from checkpoint {path}")
                self.accelerator.load_state(os.path.join(args.output_dir, path))
                global_step = int(path.split("-")[1])
                initial_global_step = global_step
                first_epoch = global_step // num_update_steps_per_epoch

        progress_bar = tqdm(
            range(0, args.max_train_steps),
            initial=initial_global_step,
            desc="Steps",
            disable=not self.accelerator.is_local_main_process,
        )

        # Keep original embeddings as reference
        orig_embeds_params = (
            self.accelerator.unwrap_model(self.text_encoder)
            .get_input_embeddings()
            .weight.data.clone()
        )

        # Mask of embedding rows that must NOT change (everything except the
        # placeholder rows). It never changes during training, so build it once.
        index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
        index_no_updates[
            min(self.placeholder_token_ids) : max(self.placeholder_token_ids) + 1
        ] = False

        # Most recent validation images, used for the model card. Initialised
        # here so the name exists even if no optimisation step or validation runs.
        images = []

        # Training loop
        for epoch in range(first_epoch, args.num_train_epochs):
            self.text_encoder.train()
            train_loss = 0.0

            for step, batch in enumerate(self.train_dataloader):
                with self.accelerator.accumulate(self.text_encoder):
                    # Convert images to latent space
                    latents = self.vae.encode(
                        batch["pixel_values"].to(dtype=self.weight_dtype)
                    ).latent_dist.sample().detach()
                    latents = latents * self.vae.config.scaling_factor

                    # Sample noise
                    noise = torch.randn_like(latents)
                    bsz = latents.shape[0]
                    timesteps = torch.randint(
                        0,
                        self.noise_scheduler.config.num_train_timesteps,
                        (bsz,),
                        device=latents.device,
                    )
                    timesteps = timesteps.long()

                    # Add noise to latents
                    noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

                    # Get text embeddings
                    encoder_hidden_states = self.text_encoder(batch["input_ids"])[0].to(
                        dtype=self.weight_dtype
                    )

                    # Predict noise
                    model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    # Get target
                    if self.noise_scheduler.config.prediction_type == "epsilon":
                        target = noise
                    elif self.noise_scheduler.config.prediction_type == "v_prediction":
                        target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
                    else:
                        raise ValueError(
                            f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
                        )

                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    self.accelerator.backward(loss)
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                    # Reset embeddings for non-placeholder tokens
                    with torch.no_grad():
                        self.accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
                            index_no_updates
                        ] = orig_embeds_params[index_no_updates]

                # Accumulate loss for logging
                train_loss += loss.detach().item()

                # Check if optimization step occurred
                if self.accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                    # Save embeddings
                    if global_step % args.save_steps == 0:
                        weight_name = (
                            f"learned_embeds-steps-{global_step}.bin"
                            if args.no_safe_serialization
                            else f"learned_embeds-steps-{global_step}.safetensors"
                        )
                        save_path = os.path.join(args.output_dir, weight_name)
                        save_progress(
                            self.text_encoder,
                            self.placeholder_token_ids,
                            self.accelerator,
                            args,
                            save_path,
                            safe_serialization=not args.no_safe_serialization,
                        )

                    # Save checkpoint
                    if self.accelerator.is_main_process:
                        if global_step % args.checkpointing_steps == 0:
                            if args.checkpoints_total_limit is not None:
                                checkpoints = os.listdir(args.output_dir)
                                checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                                checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                                # Make room so that, after saving the new checkpoint,
                                # at most `checkpoints_total_limit` remain.
                                if len(checkpoints) >= args.checkpoints_total_limit:
                                    num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                    removing_checkpoints = checkpoints[0:num_to_remove]

                                    logger.info(
                                        f"{len(checkpoints)} checkpoints already exist, "
                                        f"removing {len(removing_checkpoints)} checkpoints"
                                    )

                                    for removing_checkpoint in removing_checkpoints:
                                        removing_checkpoint = os.path.join(
                                            args.output_dir, removing_checkpoint
                                        )
                                        shutil.rmtree(removing_checkpoint)

                            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                            self.accelerator.save_state(save_path)
                            logger.info(f"Saved state to {save_path}")

                        # Run validation
                        if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                            images = log_validation(
                                self.text_encoder,
                                self.tokenizer,
                                self.unet,
                                self.vae,
                                args,
                                self.accelerator,
                                self.weight_dtype,
                                epoch,
                            )

                    # Log metrics
                    avg_train_loss = train_loss / args.gradient_accumulation_steps
                    logs = {
                        "train_loss": avg_train_loss,
                        "lr": self.lr_scheduler.get_last_lr()[0],
                        "epoch": epoch,
                    }
                    progress_bar.set_postfix(**logs)
                    self.accelerator.log(logs, step=global_step)
                    train_loss = 0.0

                    if global_step >= args.max_train_steps:
                        break

            if global_step >= args.max_train_steps:
                break

        # Save final model
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            if args.push_to_hub and not args.save_as_full_pipeline:
                logger.warning("Enabling full model saving because --push_to_hub=True was specified.")
                save_full_model = True
            else:
                save_full_model = args.save_as_full_pipeline

            if save_full_model:
                pipeline = StableDiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    text_encoder=self.accelerator.unwrap_model(self.text_encoder),
                    vae=self.vae,
                    unet=self.unet,
                    tokenizer=self.tokenizer,
                )
                pipeline.save_pretrained(args.output_dir)

            # Save final embeddings
            weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors"
            save_path = os.path.join(args.output_dir, weight_name)
            save_progress(
                self.text_encoder,
                self.placeholder_token_ids,
                self.accelerator,
                args,
                save_path,
                safe_serialization=not args.no_safe_serialization,
            )

            if args.push_to_hub:
                save_model_card(
                    self.repo_id,
                    images=images,
                    base_model=args.pretrained_model_name_or_path,
                    repo_folder=args.output_dir,
                )
                upload_folder(
                    repo_id=self.repo_id,
                    folder_path=args.output_dir,
                    commit_message="End of training",
                    ignore_patterns=["step_*", "epoch_*"],
                )

        self.accelerator.end_training()
        logger.info("Training completed!")

In [None]:
def train_textual_inversion(
    train_data_dir: str,
    placeholder_token: str = "<my-token>",
    initializer_token: str = "a",
    learnable_property: str = "object",
    output_dir: str = "textual-inversion-output",
    pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5",
    resolution: int = 512,
    train_batch_size: int = 4,
    learning_rate: float = 5e-4,
    max_train_steps: int = 3000,
    save_steps: int = 500,
    validation_prompt: Optional[str] = None,
    validation_steps: int = 100,
    num_validation_images: int = 4,
    wandb_project_name: str = "textual-inversion",
    seed: int = 42,
    mixed_precision: str = "no",
    gradient_accumulation_steps: int = 1,
    num_vectors: int = 1,
    repeats: int = 100,
    center_crop: bool = False,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 500,
    checkpointing_steps: int = 500,
    resume_from_checkpoint: Optional[str] = None,
    report_to: str = "wandb",
) -> "TextualInversionTrainer":
    """
    Train a textual inversion model.

    Convenience wrapper: packs the keyword arguments into a ``TrainingConfig``,
    builds a ``TextualInversionTrainer`` from it, runs ``trainer.train()`` and
    returns the trainer.

    Args:
        train_data_dir: Path to training data directory containing images and .txt caption files
        placeholder_token: Token to learn (e.g., "<my-cat>")
        initializer_token: Token to initialize embedding (e.g., "cat", "dog", "object", "style")
        learnable_property: Forwarded unchanged to ``TrainingConfig.learnable_property``.
            This notebook passes "object"; presumably "object" or "style" as in the
            diffusers textual-inversion example -- confirm against the dataset class.
        output_dir: Where to save the trained embeddings
        pretrained_model_name_or_path: Base Stable Diffusion model
        resolution: Training image resolution
        train_batch_size: Batch size per device
        learning_rate: Learning rate
        max_train_steps: Maximum training steps
        save_steps: Save embeddings every N steps
        validation_prompt: Prompt for validation (should include placeholder_token).
            ``None`` is forwarded as-is to the config.
        validation_steps: Run validation every N steps
        num_validation_images: Number of validation images to generate
        wandb_project_name: Weights & Biases project name
        seed: Random seed
        mixed_precision: "no", "fp16", or "bf16"
        gradient_accumulation_steps: Gradient accumulation steps
        num_vectors: Number of vectors to learn
        repeats: How many times to repeat the training data
        center_crop: Whether to center crop images
        lr_scheduler: Learning rate scheduler type
        lr_warmup_steps: Warmup steps for learning rate
        checkpointing_steps: Save checkpoint every N steps
        resume_from_checkpoint: Path to checkpoint to resume from, or ``None``
            (forwarded as-is to the config).
        report_to: Experiment-tracker name forwarded to ``TrainingConfig.report_to``.
            Defaults to "wandb", the value that used to be hard-coded here.
            NOTE(review): other backends only work if the trainer's logging and
            validation code supports them -- verify before changing.

    Returns:
        The ``TextualInversionTrainer`` instance after ``train()`` has returned.
    """

    config = TrainingConfig(
        train_data_dir=train_data_dir,
        placeholder_token=placeholder_token,
        initializer_token=initializer_token,
        learnable_property=learnable_property,
        output_dir=output_dir,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        resolution=resolution,
        train_batch_size=train_batch_size,
        learning_rate=learning_rate,
        max_train_steps=max_train_steps,
        save_steps=save_steps,
        validation_prompt=validation_prompt,
        validation_steps=validation_steps,
        num_validation_images=num_validation_images,
        wandb_project_name=wandb_project_name,
        seed=seed,
        mixed_precision=mixed_precision,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_vectors=num_vectors,
        repeats=repeats,
        center_crop=center_crop,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        checkpointing_steps=checkpointing_steps,
        resume_from_checkpoint=resume_from_checkpoint,
        # Tracker backend is now caller-selectable; the default keeps the old behavior.
        report_to=report_to,
    )

    trainer = TextualInversionTrainer(config)
    trainer.train()

    return trainer


In [None]:
# Launch the textual-inversion run that learns the "<Bo>" concept.
# Paths point at the Kaggle input dataset and the Kaggle working directory.
trainer = train_textual_inversion(
    # --- data and concept being learned ---
    train_data_dir="/kaggle/input/bovagau-poses/images/Bo",
    output_dir="/kaggle/working/bo_text_inver_output",
    placeholder_token="<Bo>",
    initializer_token="anthro",
    learnable_property="object",
    num_vectors=6,
    # --- optimisation (2 samples per step, accumulated over 4 steps) ---
    resolution=512,
    train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-4,
    lr_scheduler="constant",
    max_train_steps=5000,
    mixed_precision="fp16",
    seed=36,
    # --- checkpointing, validation and logging ---
    checkpointing_steps=500,
    validation_steps=500,
    validation_prompt="a photo of <Bo> in the garden",
    num_validation_images=4,
    wandb_project_name="textual-inversion",
)

In [None]:
"""
# Install required packages first (%pip installs into the running kernel's environment):
# %pip install diffusers transformers accelerate safetensors torch torchvision wandb

# Log in to wandb (prompts for the API key or reads WANDB_API_KEY; never hard-code the key)
import wandb
wandb.login()

# Prepare your data structure:
# data/
#   ├── image1.jpg
#   ├── image1.txt  (contains caption like "wearing sunglasses")
#   ├── image2.jpg
#   ├── image2.txt  (contains caption like "sitting on a chair")
#   └── ...

# Train the model
trainer = train_textual_inversion(
    train_data_dir="./data",
    placeholder_token="<my-cat>",
    initializer_token="cat",
    output_dir="./output/my-cat",
    validation_prompt="A photo of <my-cat> in a garden",
    max_train_steps=3000,
    learning_rate=5e-4,
    train_batch_size=4,
    validation_steps=250,
    save_steps=250,
    wandb_project_name="my-textual-inversion",
)

# After training, load and use the embeddings:
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16
).to("cuda")

# Load the learned embeddings
pipe.load_textual_inversion("./output/my-cat/learned_embeds.safetensors")

# Generate images
image = pipe("A photo of <my-cat> playing with a ball", num_inference_steps=50).images[0]
image.save("output.png")
"""