# Download data

In [None]:
!gdown 1muc9TPG3CJJ1xL7Gkryd7dnJkuZBpWvZ # Training data
!gdown 1PpKJpP35ESzi5En9Mvro5DuEy2QR5VK_ # Public testing data

In [None]:
!mkdir data
!unzip train.zip -d data
!unzip test.zip -d data

In [None]:
!pip install googletrans==3.1.0a0
import googletrans
import pandas as pd
from tqdm import tqdm

class Translation():
    """Callable wrapper around ``googletrans.Translator``.

    Calling an instance lowercases the input text and returns its translation.

    :param from_lang: Language code of the input text (sent to googletrans as ``src``).
    :param to_lang: Language code to translate into (sent to googletrans as ``dest``).
    """

    def __init__(self, from_lang='vi', to_lang='en'):
        # `from_lang` used to be accepted and then dropped, so the source
        # language was always auto-detected regardless of what the caller asked for.
        self.__from_lang = from_lang
        self.__to_lang = to_lang
        self.translator = googletrans.Translator()

    def preprocessing(self, text):
        """
        Normalize the text before it is sent for translation.

        :param text: The text to be processed
        :return: The text converted to lowercase.
        """
        return text.lower()

    def __call__(self, text):
        """
        Preprocess ``text`` and translate it from the source to the target language.

        :param text: The text to be translated
        :return: The translated text.
        """
        text = self.preprocessing(text)
        result = self.translator.translate(
            text, src=self.__from_lang, dest=self.__to_lang
        )
        return result.text

if __name__ == '__main__':
    # Translate the free-text columns of both splits and write the result
    # next to the inputs as info.csv (the file the training code reads).
    translator = Translation()

    # NOTE(review): the two input names differ ("info_trans.csv" vs
    # "info_trains.csv"); confirm both match the files inside the archives.
    train_data = pd.read_csv("/content/data/train/info_trans.csv")
    test_data = pd.read_csv("/content/data/test/info_trains.csv")

    text_columns = ("caption", "description", "moreInfo")

    # One loop for both splits instead of two copy-pasted ones; the order of
    # translation calls (train rows first, then test rows) is unchanged.
    for frame in (train_data, test_data):
        for i in tqdm(range(len(frame))):
            for column in text_columns:
                # str() so non-string cells (e.g. missing values) can still be lowercased.
                frame.loc[i, column] = translator(str(frame.loc[i, column]))

    # index=False: otherwise the positional index is written as an extra
    # unnamed first column that shows up again on every later read_csv.
    train_data.to_csv('/content/data/train/info.csv', index=False)
    test_data.to_csv('/content/data/test/info.csv', index=False)

# Data structure

```
|- data/
    |- train/
        |- images/
        |- info.csv
    |- test/
        |- info.csv
```

# Install dependency

In [None]:
!pip install diffusers
!pip install transformers
!pip install wandb
!pip install yacs
!pip install accelerate

Collecting diffusers
  Downloading diffusers-0.26.3-py3-none-any.whl (1.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: diffusers
Successfully installed diffusers-0.26.3
Collecting wandb
  Downloading wandb-0.16.4-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.4/195.4 kB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.41.0-py2.py3-none-any.whl (258 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.8/258.8 kB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycr

# Import Library

In [None]:
import os
import torch
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from yacs.config import CfgNode as CN

import wandb
import argparse
import logging
import math
import os
import shutil
from pathlib import Path

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torchvision import transforms

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from tqdm.auto import tqdm
import transformers
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils.import_utils import is_xformers_available

In [None]:
# Notebook-wide logger created through accelerate's `get_logger` (imported above),
# which is why later calls can pass `main_process_only=...`.
logger = get_logger(__name__, log_level="INFO")

# Dataloader

In [None]:
def default_loader(path):
    """Open the image stored at ``path`` and return it converted to RGB."""
    image = Image.open(path)
    rgb_image = image.convert('RGB')
    return rgb_image

def tokenize_caption(caption, tokenizer):
    """Tokenize ``caption`` and return the resulting ``input_ids``.

    The caption is padded and truncated to ``tokenizer.model_max_length`` and
    the tokenizer is asked for PyTorch tensors.

    :param caption: Text to tokenize.
    :param tokenizer: Callable tokenizer exposing ``model_max_length``.
    :return: The ``input_ids`` attribute of the tokenizer output.
    """
    encode_kwargs = {
        "max_length": tokenizer.model_max_length,
        "padding": "max_length",
        "truncation": True,
        "return_tensors": "pt",
    }
    return tokenizer(caption, **encode_kwargs).input_ids

def train_collate_fn(samples):
    """Collate training samples into one batch.

    :param samples: Sequence of dicts with ``"pixel_values"`` and ``"input_ids"`` tensors.
    :return: Dict with the images stacked along a new first dimension (as
        contiguous float tensors) and the token ids concatenated along dim 0.
    """
    images = [item["pixel_values"] for item in samples]
    token_ids = [item["input_ids"] for item in samples]

    batch_pixels = torch.stack(images)
    batch_pixels = batch_pixels.to(memory_format=torch.contiguous_format).float()

    return {
        "pixel_values": batch_pixels,
        "input_ids": torch.cat(token_ids, dim=0),
    }

def test_collate_fn(samples):
    captions = []
    paths = []
    for sample in samples:
        captions.append(sample["captions"])
        paths.append(sample["paths"])
    return {"captions": captions,
            "paths": paths}

class BannerDataset(Dataset):
    """Dataset backed by a CSV of banner captions and image file names.

    In ``"train"`` mode an item is the transformed image plus the tokenized
    caption; in ``"test"`` mode it is the raw caption plus the image file name.
    """

    def __init__(self, data_cfg, tokenizer, transform=None, mode='train') -> None:
        """Read the CSV for the requested split.

        :param data_cfg: Config node providing ``DATA_DIR``, ``TRAIN_CSV_PATH`` and ``TEST_CSV_PATH``.
        :param tokenizer: Tokenizer forwarded to ``tokenize_caption`` in train mode.
        :param transform: Optional callable applied to each loaded image (train mode only).
        :param mode: Either ``"train"`` or ``"test"``.
        """
        super().__init__()
        assert (mode in ["train", "test"]), "Please specify correct data mode !"

        self.data_cfg = data_cfg
        self.transform = transform
        self.tokenizer = tokenizer
        self.mode = mode
        self.data_dir = data_cfg.DATA_DIR

        # The CSV path in the config is relative to the data directory.
        csv_name = data_cfg.TEST_CSV_PATH if mode != "train" else data_cfg.TRAIN_CSV_PATH
        self.data_csv_path = os.path.join(self.data_dir, csv_name)
        self.data = pd.read_csv(self.data_csv_path)

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at positional ``index``.

        Columns read from the row: ``caption`` and ``bannerImage`` (image file name).
        """
        row = self.data.iloc[index]
        caption = row["caption"]

        # Test mode needs no image or tokenization: hand back the raw fields.
        if self.mode != "train":
            return {"captions": caption,
                    "paths": row["bannerImage"]}

        image_path = os.path.join(self.data_dir, self.mode, "images/", row["bannerImage"])
        image = default_loader(image_path)
        if self.transform is not None:
            image = self.transform(image)

        return {"pixel_values": image,
                "input_ids": tokenize_caption(caption, self.tokenizer)}

In [None]:
def build_dataloader(cfg, tokenizer):
    """Build the train, test and validation dataloaders.

    :param cfg: Full config; reads ``cfg.DATA`` (resolution, crop/flip flags,
        paths) and ``cfg.TRAIN`` (batch size, worker count).
    :param tokenizer: Tokenizer handed to the datasets.
    :return: ``(train_dataloader, test_dataloader, val_dataloader)``. The
        validation loader iterates over the first (at most) 10 test samples.
    """
    train_transform = transforms.Compose(
        [
            transforms.Resize(cfg.DATA.RESOLUTION, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(cfg.DATA.RESOLUTION) if cfg.DATA.CENTER_CROP else transforms.RandomCrop(cfg.DATA.RESOLUTION),
            # Identity Lambda keeps the pipeline shape the same when flipping is disabled.
            transforms.RandomHorizontalFlip() if cfg.DATA.RANDOM_FLIP else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    train_dataset = BannerDataset(cfg.DATA, tokenizer, transform=train_transform, mode="train")
    test_dataset = BannerDataset(cfg.DATA, tokenizer, transform=None, mode="test")

    # Clamp to the dataset size: a fixed range(0, 10) raised IndexError during
    # validation whenever the test split had fewer than 10 rows.
    num_val_samples = min(10, len(test_dataset))
    val_dataset = torch.utils.data.Subset(test_dataset, list(range(num_val_samples)))

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=train_collate_fn,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKERS
    )

    # Test and validation share the same settings: one caption per batch, fixed order.
    eval_loader_kwargs = dict(
        shuffle=False,
        collate_fn=test_collate_fn,
        batch_size=1,
        num_workers=cfg.TRAIN.NUM_WORKERS,
    )
    test_dataloader = torch.utils.data.DataLoader(test_dataset, **eval_loader_kwargs)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, **eval_loader_kwargs)

    return train_dataloader, test_dataloader, val_dataloader

# Config file

In [None]:
# Default configuration tree. Entries marked "set your own value here" are the
# ones the notebook author expects the user to tune.
_C = CN()

dataset_path = "data/"

# DATA
_C.DATA = CN()
_C.DATA.DATA_DIR = dataset_path
# CSV paths are relative to DATA_DIR.
_C.DATA.TRAIN_CSV_PATH = 'train/info.csv'
_C.DATA.TEST_CSV_PATH = 'test/info.csv'
_C.DATA.RESOLUTION = 512 # Set your own value here.
_C.DATA.CENTER_CROP = False
_C.DATA.RANDOM_FLIP = False

# Model specific configurations.
_C.MODEL = CN()
_C.MODEL.NAME = 'stabilityai/stable-diffusion-2-1'
_C.MODEL.XFORMERS = False # Whether or not to use xformers for memory-efficient attention.
_C.MODEL.NOISE_OFFSET = 0 # https://www.crosslabs.org//blog/diffusion-with-offset-noise
_C.MODEL.RANK = 4 # LoRA rank.

# Training configurations
_C.TRAIN = CN()
_C.TRAIN.SEED = 1337
_C.TRAIN.EPOCH = 1 # Set your own value here.
_C.TRAIN.BATCH_SIZE = 1 # Set your own value here.
_C.TRAIN.NUM_WORKERS = 1 # Set your own value here.
_C.TRAIN.MAX_NORM = 1.0 # Set your own value here.
_C.TRAIN.GRADIENT_ACCUMULATION_STEP = 4

## Learning rate setting
_C.TRAIN.LR = CN()
# Choose between ["linear", "cosine", "cosine_with_restarts",
# "polynomial", "constant", "constant_with_warmup"]
_C.TRAIN.LR.MODE = "cosine" # Set your own value here.
_C.TRAIN.LR.BASE_LR = 1e-4 # Set your own value here.
_C.TRAIN.LR.WARMUP_EPOCH = 0 # Set your own value here.
# Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.
_C.TRAIN.LR.SCALE_LR = False # True/False # Set your own value here.

## Optimizer setting
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
_C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 1e-2
_C.TRAIN.OPTIMIZER.EPSILON = 1e-08

# Validation configurations
# NOTE(review): EVAL.EPOCH is not read anywhere in the visible notebook;
# validation currently runs after every training epoch.
_C.EVAL = CN()
_C.EVAL.EPOCH = 1  # Set your own value here.

# Testing configurations
_C.TEST = CN()
_C.TEST.RESTORE_FROM = ""

def get_default_config():
    """Return a clone of the default config so callers can modify it freely."""
    defaults = _C.clone()
    return defaults

# Training script

## Set up accelerator

In [None]:
# Run-level settings used by the cells below.
output_dir = "checkpoints/"
logging_dir = "logs/"
report_to = "wandb" # Log to Weights & Biases
mixed_precision = "fp16" # Mixed precision training

# Whether training should be resumed from a previous checkpoint.
# Use a path saved every `checkpointing_steps` updates, or "latest" to
# automatically select the last available checkpoint.
resume_from_checkpoint = None

# The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`.
# If left to `None` the default prediction type of the scheduler, `noise_scheduler.config.prediction_type`, is chosen.
prediction_type = None

# SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0.
# More details here: https://arxiv.org/abs/2303.09556.
snr_gamma = None

checkpoints_total_limit = 5 # Max number of checkpoints to store.
checkpointing_steps = 500 # Save a checkpoint of the training state every X updates.

In [None]:
# Build the config and the Accelerator that drives mixed precision,
# gradient accumulation and experiment tracking.
cfg = get_default_config()
# NOTE: this rebinds `logging_dir` to a path under `output_dir`; re-running the
# cell without re-running the settings cell prefixes `output_dir` a second time.
logging_dir = Path(output_dir, logging_dir)

accelerator_project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
    gradient_accumulation_steps=cfg.TRAIN.GRADIENT_ACCUMULATION_STEP,
    mixed_precision=mixed_precision,
    log_with=report_to,
    project_config=accelerator_project_config,
)

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

logger.info(accelerator.state, main_process_only=False)
# Keep library logging verbose only on the local main process.
if accelerator.is_local_main_process:
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()

# If passed along, set the training seed now.
if cfg.TRAIN.SEED is not None:
    set_seed(cfg.TRAIN.SEED)

# Create the output directory (main process only).
if accelerator.is_main_process:
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

## Set up model

In [None]:
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(cfg.MODEL.NAME, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
    cfg.MODEL.NAME, subfolder="tokenizer", revision=None
)
text_encoder = CLIPTextModel.from_pretrained(
    cfg.MODEL.NAME, subfolder="text_encoder", revision=None
)
vae = AutoencoderKL.from_pretrained(cfg.MODEL.NAME, subfolder="vae", revision=None)
unet = UNet2DConditionModel.from_pretrained(
    cfg.MODEL.NAME, subfolder="unet", revision=None
)

# Freeze the base models; only the LoRA layers added below are trained.
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

# Now we add new LoRA weights to the attention layers.
# It's important to realize here how many attention weights will be added and of which sizes.
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
# => 32 layers

# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Processors whose name ends in "attn1.processor" get no cross-attention dim.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    # hidden_size is looked up from block_out_channels by block position; it is
    # only assigned for names starting with mid_block / up_blocks / down_blocks.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        # The single character right after the prefix is taken as the block index.
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
        rank=cfg.MODEL.RANK,
    )

unet.set_attn_processor(lora_attn_procs)

# Wrap the processors so their parameters can be handed to the optimizer.
lora_layers = AttnProcsLayers(unet.attn_processors)

# Optionally scale the base LR by accumulation steps, batch size and process count.
if cfg.TRAIN.LR.SCALE_LR:
    cfg.TRAIN.LR.BASE_LR = (
        cfg.TRAIN.LR.BASE_LR * cfg.TRAIN.GRADIENT_ACCUMULATION_STEP * cfg.TRAIN.BATCH_SIZE * accelerator.num_processes
    )

## Set up optimizer, dataloader and scheduler

In [None]:
# Optimizer over the LoRA parameters only (the base models are frozen).
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
    lora_layers.parameters(),
    lr=cfg.TRAIN.LR.BASE_LR,
    betas=cfg.TRAIN.OPTIMIZER.BETAS,
    weight_decay=cfg.TRAIN.OPTIMIZER.WEIGHT_DECAY,
    eps=cfg.TRAIN.OPTIMIZER.EPSILON,
)

# DataLoaders creation:
train_dataloader, test_dataloader, val_dataloader = build_dataloader(cfg, tokenizer)

# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.TRAIN.GRADIENT_ACCUMULATION_STEP)
max_train_steps = cfg.TRAIN.EPOCH * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    cfg.TRAIN.LR.MODE,
    optimizer=optimizer,
    num_warmup_steps=cfg.TRAIN.LR.WARMUP_EPOCH * num_update_steps_per_epoch * accelerator.num_processes,
    num_training_steps=max_train_steps * accelerator.num_processes,
)

# Prepare everything with our `accelerator`.
lora_layers, optimizer, train_dataloader, test_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
    lora_layers, optimizer, train_dataloader, test_dataloader, val_dataloader, lr_scheduler
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.TRAIN.GRADIENT_ACCUMULATION_STEP)
max_train_steps = cfg.TRAIN.EPOCH * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
cfg.TRAIN.EPOCH = math.ceil(max_train_steps / num_update_steps_per_epoch)


def _cfg_to_plain_dict(node):
    """Recursively copy a nested config mapping into plain dicts."""
    return {
        key: _cfg_to_plain_dict(value) if isinstance(value, dict) else value
        for key, value in node.items()
    }


# We need to initialize the trackers we use, and also store our configuration.
# The trackers initialize automatically on the main process.
# `vars(cfg)` must not be used here: a yacs CfgNode is a dict subclass whose
# instance `__dict__` only holds its internal bookkeeping flags, so the actual
# hyperparameters were never sent to the tracker.
if accelerator.is_main_process:
    accelerator.init_trackers("text2image-fine-tune", config=_cfg_to_plain_dict(cfg))

## Start training !!

In [None]:
# Train!
# Effective batch size per optimizer update across devices and accumulation steps.
total_batch_size = cfg.TRAIN.BATCH_SIZE * accelerator.num_processes * cfg.TRAIN.GRADIENT_ACCUMULATION_STEP

logger.info("***** Running training *****")
# len(train_dataloader) is the number of batches; the sample count lives on the dataset.
logger.info(f"  Num train samples = {len(train_dataloader.dataset)}")
logger.info(f"  Num Epochs = {cfg.TRAIN.EPOCH}")
logger.info(f"  Instantaneous batch size per device = {cfg.TRAIN.BATCH_SIZE}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {cfg.TRAIN.GRADIENT_ACCUMULATION_STEP}")
logger.info(f"  Total optimization steps = {max_train_steps}")
global_step = 0
first_epoch = 0

# Potentially load in the weights and states from a previous save
if resume_from_checkpoint:
    if resume_from_checkpoint != "latest":
        # Only the directory name is kept; it is resolved against output_dir below.
        path = os.path.basename(resume_from_checkpoint)
    else:
        # Get the most recent checkpoint (directories are named "checkpoint-<step>").
        dirs = os.listdir(output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(
            f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        resume_from_checkpoint = None
        initial_global_step = 0
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(output_dir, path))
        # The step count is recovered from the checkpoint directory name.
        global_step = int(path.split("-")[1])

        initial_global_step = global_step
        first_epoch = global_step // num_update_steps_per_epoch
else:
    initial_global_step = 0

progress_bar = tqdm(
    range(0, max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)

# Main loop: one pass over the training data per epoch, followed by a
# validation round (image generation for the validation captions) on the main process.
for epoch in range(first_epoch, cfg.TRAIN.EPOCH):
    unet.train()
    train_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            # Convert images to latent space
            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            if cfg.MODEL.NOISE_OFFSET:
                # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                noise += cfg.MODEL.NOISE_OFFSET * torch.randn(
                    (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
                )

            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Get the target for loss depending on the prediction type
            if prediction_type is not None:
                # set prediction_type of scheduler if defined
                noise_scheduler.register_to_config(prediction_type=prediction_type)

            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            if snr_gamma is None:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            else:
                # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                # This is discussed in Section 4.2 of the same paper.
                snr = compute_snr(noise_scheduler, timesteps)
                if noise_scheduler.config.prediction_type == "v_prediction":
                    # Velocity objective requires that we add one to SNR values before we divide by them.
                    snr = snr + 1
                mse_loss_weights = (
                    torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                )

                # Per-sample loss first, so each sample can get its own weight.
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                loss = loss.mean()

            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(cfg.TRAIN.BATCH_SIZE)).mean()
            train_loss += avg_loss.item() / cfg.TRAIN.GRADIENT_ACCUMULATION_STEP

            # Backpropagate
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = lora_layers.parameters()
                accelerator.clip_grad_norm_(params_to_clip, cfg.TRAIN.MAX_NORM)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            accelerator.log({"train_loss": train_loss}, step=global_step)
            train_loss = 0.0

            if global_step % checkpointing_steps == 0:
                if accelerator.is_main_process:
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                    if checkpoints_total_limit is not None:
                        checkpoints = os.listdir(output_dir)
                        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                        if len(checkpoints) >= checkpoints_total_limit:
                            num_to_remove = len(checkpoints) - checkpoints_total_limit + 1
                            removing_checkpoints = checkpoints[0:num_to_remove]

                            logger.info(
                                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                            )
                            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                            for removing_checkpoint in removing_checkpoints:
                                removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
                                shutil.rmtree(removing_checkpoint)

                    save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")

        logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)

        if global_step >= max_train_steps:
            break

    # Validation: generate one image per validation caption and log them.
    if accelerator.is_main_process:
        logger.info(
            "Running validation..."
        )
        # create pipeline
        pipeline = DiffusionPipeline.from_pretrained(
            cfg.MODEL.NAME,
            unet=accelerator.unwrap_model(unet),
            revision=None,
            torch_dtype=weight_dtype,
        )
        pipeline = pipeline.to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)

        # run inference
        generator = torch.Generator(device=accelerator.device)
        if cfg.TRAIN.SEED is not None:
            generator = generator.manual_seed(cfg.TRAIN.SEED)

        images = []
        captions = []
        for sample in val_dataloader:
            # The validation loader uses batch_size=1, so take the only caption.
            images.append(
                pipeline(sample["captions"][0], num_inference_steps=30, generator=generator).images[0]
            )
            captions.append(sample["captions"][0])

        for tracker in accelerator.trackers:
            if tracker.name == "wandb":
                tracker.log(
                    {
                        "validation": [
                            # `caption` must be passed by keyword: the second
                            # positional parameter of wandb.Image is `mode`, so
                            # the caption was silently dropped before.
                            wandb.Image(image, caption=caption)
                            for image, caption in zip(images, captions)
                        ]
                    }
                )
        # Free the pipeline before the next epoch's training steps.
        del pipeline
        torch.cuda.empty_cache()

# Save the LoRA layers once every process has finished training.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    # Cast back to float32 before exporting the attention processors to output_dir.
    unet = unet.to(torch.float32)
    unet.save_attn_procs(output_dir)

    # Additionally store the accelerator state under "final-checkpoint".
    save_path = os.path.join(output_dir, f"final-checkpoint")
    accelerator.save_state(save_path)
    logger.info(f"Saved model to {save_path}")

accelerator.end_training()