In [1]:
import os
import shutil
from pathlib import Path

from dataclasses import dataclass

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from datasets import Dataset

from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UNet2DConditionModel,
)
from diffusers.training_utils import EMAModel

In [2]:
@dataclass
class TrainingArgs:
    """Hyperparameters and paths for fine-tuning the Stable Diffusion UNet.

    The attributes below have no type annotations, so ``@dataclass`` does not
    register them as fields: they are plain class-level defaults and
    ``TrainingArgs()`` accepts no constructor arguments. Override a value by
    assigning it on the instance (e.g. ``args.learning_rate = 5e-5``).
    """

    # --- data / model ---
    image_size = 512  # side length (pixels) that training images are resized to
    data_dir = Path("../dataset/pokemon/Generation_1")
    pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1"

    # --- optimisation ---
    train_batch_size = 1
    num_train_epochs = 1000
    max_train_steps = 1000
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    mixed_precision = "fp16"
    weight_dtype = torch.float16  # dtype the frozen VAE / text encoder are cast to
    use_ema = False
    adam_beta1 = 0.9
    adam_beta2 = 0.999
    adam_weight_decay = 1e-2
    adam_epsilon = 1e-8
    max_grad_norm = 1

    # --- checkpointing ---
    output_dir = "ddpm-pokemon-64"
    resume_from_checkpoint = "latest"
    # Save the accelerator state every N optimizer steps. The training loop
    # reads this attribute, so it must exist.
    checkpointing_steps = 500
    # Maximum number of checkpoints kept on disk; None disables pruning.
    checkpoints_total_limit = None

    # --- validation / logging ---
    guidance_scale = 7.5  # classifier-free guidance weight used at validation time
    validation_prompts = ["a pokemon with a fire tail"]
    validation_epochs = 100
    logging_dir = "ddpm-pokemon-64/logs"
    report_to = "tensorboard"
    seed = 0

args = TrainingArgs()

In [3]:
# Process-aware logger from accelerate.
logger = get_logger(__name__, log_level="INFO")

# Tell accelerate where the run's outputs and tracker logs live.
accelerator_project_config = ProjectConfiguration(
    project_dir=args.output_dir,
    logging_dir=args.logging_dir,
)

# The accelerator handles device placement, mixed precision, gradient
# accumulation and experiment tracking for the rest of the notebook.
accelerator = Accelerator(
    project_config=accelerator_project_config,
    log_with=args.report_to,
    mixed_precision=args.mixed_precision,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
)

In [4]:
def save_model_hook(models, weights, output_dir):
    """Pre-hook for ``accelerator.save_state``: write the UNet in diffusers format.

    Args:
        models: models registered with the accelerator.
        weights: list of pending weights handed over by accelerate; one entry
            is popped per model so accelerate does not serialize it again.
        output_dir: checkpoint directory being written.
    """
    if args.use_ema:
        ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

    for model in models:
        unwrapped = accelerator.unwrap_model(model)
        # Only the UNet is trained. Writing every prepared model into the same
        # "unet" folder would mix files from different model types.
        if isinstance(unwrapped, UNet2DConditionModel):
            unwrapped.save_pretrained(os.path.join(output_dir, "unet"))

        # make sure to pop weight so that corresponding model is not saved again
        if weights:
            weights.pop()

def load_model_hook(models, input_dir):
    """Pre-hook for ``accelerator.load_state``: restore the UNet (and EMA) weights.

    Args:
        models: models registered with the accelerator; emptied here so that
            accelerate does not try to load them a second time.
        input_dir: checkpoint directory being read.
    """
    if args.use_ema:
        load_model = EMAModel.from_pretrained(
            os.path.join(input_dir, "unet_ema"), UNet2DConditionModel
        )
        ema_unet.load_state_dict(load_model.state_dict())
        ema_unet.to(accelerator.device)
        del load_model

    while models:
        # pop models so that they are not loaded again
        model = accelerator.unwrap_model(models.pop())

        # The checkpoint only contains UNet weights; any other prepared model
        # is frozen and keeps the weights it already has.
        if not isinstance(model, UNet2DConditionModel):
            continue

        # load diffusers style into model
        load_model = UNet2DConditionModel.from_pretrained(
            input_dir, subfolder="unet"
        )
        model.register_to_config(**load_model.config)

        model.load_state_dict(load_model.state_dict())
        del load_model

# Run the custom hooks before accelerator.save_state / accelerator.load_state so
# checkpoints are written and read through save_model_hook / load_model_hook.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

<torch.utils.hooks.RemovableHandle at 0x7f26811dbdd0>

In [5]:
# Seed the global RNGs, and keep a dedicated seeded generator on the training device.
set_seed(args.seed)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

# Output folder for checkpoints, plus a sub-folder for validation images.
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(f"{args.output_dir}/png", exist_ok=True)

In [6]:
# Load each pretrained component from its sub-folder of the model repository.
noise_scheduler = DDPMScheduler.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="scheduler",
)
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer",
)
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="text_encoder",
)
vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="vae",
)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="unet",
)

# The text encoder and VAE stay frozen.
text_encoder.requires_grad_(False)
vae.requires_grad_(False)

if args.use_ema:
    # Load a second UNet and wrap its parameters in an EMAModel.
    ema_unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
    )
    ema_unet = EMAModel(
        ema_unet.parameters(),
        model_cls=UNet2DConditionModel,
        model_config=ema_unet.config,
    )


In [12]:
def tokenize_captions(captions, tokenizer):
    """Tokenize ``captions`` and return the resulting ``input_ids``.

    Captions are padded and truncated to ``tokenizer.model_max_length`` and
    returned as PyTorch tensors (``return_tensors="pt"``).
    """
    encoding_options = dict(
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return tokenizer(captions, **encoding_options).input_ids

In [13]:
# Image pipeline: resize to a square of side args.image_size, convert to a
# tensor, then normalize with mean 0.5 and std 0.5.
preprocess = transforms.Compose([
    transforms.Resize((args.image_size, args.image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

In [14]:
# Build the training set from the image folder: one row per file, captioned
# with the part of the file name before the first underscore.
dataset_dict = {"image": [], "text": []}
try:
    # os.listdir returns entries in arbitrary order; sort so the dataset order
    # (and therefore the seeded shuffle) is reproducible across runs.
    for file in sorted(os.listdir(args.data_dir)):
        if not (args.data_dir / file).is_file():
            continue  # skip sub-directories
        dataset_dict["image"].append(Image.open(args.data_dir / file))
        dataset_dict["text"].append(file.split("_")[0])

    dataset = Dataset.from_dict(dataset_dict)
finally:
    # Image.open is lazy and keeps every file handle open. Release them once
    # the Dataset has been built.
    # NOTE(review): assumes Dataset.from_dict encodes images eagerly and does
    # not keep references to these PIL objects -- confirm for the installed
    # `datasets` version.
    for opened_image in dataset_dict["image"]:
        opened_image.close()

In [15]:
def transform(examples):
    """On-the-fly batch transform: add ``pixel_values`` and ``input_ids``.

    Each image is converted to RGB and run through ``preprocess``; each caption
    is tokenized on its own with ``tokenize_captions``.
    """
    rgb_images = (image.convert("RGB") for image in examples["image"])
    examples["pixel_values"] = [preprocess(rgb) for rgb in rgb_images]
    examples["input_ids"] = [
        tokenize_captions(caption, tokenizer) for caption in examples["text"]
    ]
    return examples

# Apply the transform lazily whenever rows are accessed.
dataset.set_transform(transform)

In [16]:
def collate_fn(examples):
    """Stack per-example tensors into one batch dictionary.

    Returns a dict with ``pixel_values`` (contiguous, float32) and
    ``input_ids``, each stacked along a new leading batch dimension.
    """
    images, token_ids = [], []
    for example in examples:
        images.append(example["pixel_values"])
        token_ids.append(example["input_ids"])

    pixel_values = torch.stack(images)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    return {"pixel_values": pixel_values, "input_ids": torch.stack(token_ids)}

# Shuffled training batches, assembled with collate_fn.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [17]:
def make_image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into one ``rows`` x ``cols`` RGB grid image.

    The cell size is taken from the first image; ``len(imgs)`` must equal
    ``rows * cols``.
    """
    assert len(imgs) == rows * cols

    cell_w, cell_h = imgs[0].size
    grid = Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for index, img in enumerate(imgs):
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid

In [18]:
@torch.no_grad()
def decode_noise(noise, vae):
    """Decode ``noise`` with ``vae`` after dividing by its scaling factor (no gradients)."""
    unscaled = noise / vae.config.scaling_factor
    decoded = vae.decode(unscaled, return_dict=False)
    return decoded[0]

def denormalize(images):
    """
    Halve ``images``, add 0.5 and clamp the result to [0, 1].
    """
    shifted = images / 2 + 0.5
    return shifted.clamp(0, 1)

def numpy_to_pil(image):
    """Scale ``image`` by 255, round, cast to uint8 and wrap it in a PIL image."""
    pixels = (image * 255).round().astype("uint8")
    return Image.fromarray(pixels)

def noise_to_pil(noise, vae):
    """Decode ``noise`` with ``vae`` and return the first image of the batch as PIL."""
    decoded = denormalize(decode_noise(noise, vae))
    # Move to CPU and permute (0, 2, 3, 1) before taking the first batch item.
    first_image = decoded.cpu().permute(0, 2, 3, 1)[0]
    return numpy_to_pil(first_image.numpy())

In [19]:
# Turn on gradient checkpointing for the UNet before it is handed to the
# optimizer and accelerator.
unet.enable_gradient_checkpointing()

In [20]:
# AdamW over the UNet parameters only; hyperparameters come from `args`.
optimizer = torch.optim.AdamW(
    unet.parameters(),
    lr=args.learning_rate,
    weight_decay=args.adam_weight_decay,
    betas=(args.adam_beta1, args.adam_beta2),
    eps=args.adam_epsilon,
)

In [21]:
# Prepare only what takes part in optimisation with our `accelerator`.
# The VAE and text encoder are frozen, so they are deliberately NOT prepared:
# prepared models are passed to the save/load state hooks, and models with no
# trainable parameters cannot be wrapped for distributed training. They are
# moved to the device (and cast) explicitly in the training cell instead.
unet, optimizer, train_dataloader = accelerator.prepare(
    unet, optimizer, train_dataloader
)
if args.use_ema:
    ema_unet.to(accelerator.device)

In [17]:
def inference(
    prompt,
    tokenizer,
    text_encoder,
    unet,
    scheduler,
    device,
    num_inference_steps=None,
):
    """Sample one latent for ``prompt`` using classifier-free guidance.

    Args:
        prompt: text prompt to condition on.
        tokenizer: tokenizer passed to ``tokenize_captions``.
        text_encoder: model that maps token ids to text embeddings.
        unet: denoising model; its config supplies the latent shape.
        scheduler: diffusion scheduler providing ``timesteps`` and ``step``.
        device: device on which sampling runs.
        num_inference_steps: optional number of denoising steps. When ``None``
            (default) the scheduler's current ``timesteps`` are used unchanged.

    Returns:
        The final latent tensor, to be decoded by the VAE.
    """
    if num_inference_steps is not None:
        scheduler.set_timesteps(num_inference_steps, device=device)

    # Row 0 is the empty (unconditional) prompt, row 1 the actual prompt.
    # The tokenizer returns CPU tensors, so move the ids next to the encoder.
    text_input_ids = tokenize_captions(["", prompt], tokenizer).to(device)

    # Sample in eval mode and restore the previous mode afterwards, so calling
    # this in the middle of training does not leave the UNet in the wrong mode.
    was_training = unet.training
    unet.eval()
    try:
        with torch.no_grad():
            prompt_embeds = text_encoder(text_input_ids)[0]

            noise = torch.randn(
                (
                    1,
                    unet.config.in_channels,
                    unet.config.sample_size,
                    unet.config.sample_size,
                )
            ).to(device)

            progress_bar = tqdm(scheduler.timesteps)
            for t in progress_bar:
                # Duplicate the latent so both prompts are evaluated in one pass.
                noise_input = torch.cat([noise] * 2)

                model_output = unet(
                    noise_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                )[0]
                noise_pred_uncond, noise_pred_text = model_output.chunk(2)
                noise_pred = noise_pred_uncond + args.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

                noise = scheduler.step(noise_pred, t, noise)[0]
    finally:
        unet.train(was_training)
    return noise

In [18]:
# Start the experiment trackers, using the output directory as the project name.
accelerator.init_trackers(args.output_dir)

# Log a short summary of the run configuration.
for _banner_line in (
    "***** Running training *****",
    f"  Num examples = {len(dataset)}",
    f"  Num Epochs = {args.num_train_epochs}",
    f"  Instantaneous batch size per device = {args.train_batch_size}",
    f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}",
    f"  Total optimization steps = {args.max_train_steps}",
):
    logger.info(_banner_line)

# Counters; overwritten below when resuming from a checkpoint.
global_step, first_epoch = 0, 0

# Place the frozen models on the training device in the configured weight dtype.
text_encoder.to(accelerator.device, dtype=args.weight_dtype)
vae.to(accelerator.device, dtype=args.weight_dtype)

# Optimizer updates in one pass over the dataloader (ceil division).
num_update_steps_per_epoch = -(
    -len(train_dataloader) // args.gradient_accumulation_steps
)
# Batches to skip in the first epoch; stays 0 unless a checkpoint is loaded.
resume_step = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)
        # A named checkpoint that is missing falls through to the
        # "does not exist" message instead of failing inside load_state.
        if not os.path.isdir(os.path.join(args.output_dir, path)):
            path = None
    else:
        # Get the most recent checkpoint. Only "checkpoint-<int>" entries are
        # considered so unrelated names cannot break the numeric sort.
        dirs = [
            d
            for d in os.listdir(args.output_dir)
            if d.startswith("checkpoint-") and d.split("-")[1].isdigit()
        ]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        args.resume_from_checkpoint = None
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])

        # Translate the optimizer-step counter into (epoch, batch offset).
        resume_global_step = global_step * args.gradient_accumulation_steps
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % (
            num_update_steps_per_epoch * args.gradient_accumulation_steps
        )

# One bar per machine: it is disabled on every process except the local main one.
remaining_steps = range(global_step, args.max_train_steps)
progress_bar = tqdm(
    remaining_steps,
    disable=not accelerator.is_local_main_process,
)
progress_bar.set_description("Steps")

# Checkpoint cadence and retention. Fall back to defaults when the config
# object does not define them, so the loop cannot fail on a missing attribute.
checkpointing_steps = getattr(args, "checkpointing_steps", 500)
checkpoints_total_limit = getattr(args, "checkpoints_total_limit", None)

for epoch in range(first_epoch, args.num_train_epochs):
    # The inner `break` only leaves the batch loop; without this check every
    # remaining epoch would still run one extra optimisation step.
    if global_step >= args.max_train_steps:
        break

    unet.train()
    train_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        # Skip steps until we reach the resumed step
        if (
            args.resume_from_checkpoint
            and epoch == first_epoch
            and step < resume_step
        ):
            if step % args.gradient_accumulation_steps == 0:
                progress_bar.update(1)
            continue

        with accelerator.accumulate(unet):
            # Convert images to latent space
            latents = vae.encode(
                batch["pixel_values"].to(args.weight_dtype)
            ).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=latents.device,
            )
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"].squeeze(1))[0]

            # The regression target depends on how the scheduler is parameterized.
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(
                    f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                )

            # Predict the noise residual and compute loss
            model_pred = unet(
                noisy_latents, timesteps, encoder_hidden_states
            ).sample

            loss = F.mse_loss(
                model_pred.float(), target.float(), reduction="mean"
            )

            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
            train_loss += avg_loss.item() / args.gradient_accumulation_steps

            # Backpropagate
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            if args.use_ema:
                ema_unet.step(unet.parameters())
            progress_bar.update(1)
            global_step += 1
            accelerator.log({"train_loss": train_loss}, step=global_step)
            train_loss = 0.0

            if global_step % checkpointing_steps == 0:
                if accelerator.is_main_process:
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                    if checkpoints_total_limit is not None:
                        checkpoints = os.listdir(args.output_dir)
                        checkpoints = [
                            d for d in checkpoints if d.startswith("checkpoint")
                        ]
                        checkpoints = sorted(
                            checkpoints, key=lambda x: int(x.split("-")[1])
                        )

                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                        if len(checkpoints) >= checkpoints_total_limit:
                            num_to_remove = (
                                len(checkpoints) - checkpoints_total_limit + 1
                            )
                            removing_checkpoints = checkpoints[0:num_to_remove]

                            logger.info(
                                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                            )
                            logger.info(
                                f"removing checkpoints: {', '.join(removing_checkpoints)}"
                            )

                            for removing_checkpoint in removing_checkpoints:
                                removing_checkpoint = os.path.join(
                                    args.output_dir, removing_checkpoint
                                )
                                shutil.rmtree(removing_checkpoint)

                    save_path = os.path.join(
                        args.output_dir, f"checkpoint-{global_step}"
                    )
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")

        logs = {
            "step_loss": loss.detach().item(),
            # No LR scheduler is defined in this notebook, so report the
            # learning rate straight from the optimizer.
            "lr": optimizer.param_groups[0]["lr"],
        }
        progress_bar.set_postfix(**logs)

        if global_step >= args.max_train_steps:
            break

    if (
        args.validation_prompts is not None
        and epoch % args.validation_epochs == 0
    ):
        if args.use_ema:
            # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
            ema_unet.store(unet.parameters())
            ema_unet.copy_to(unet.parameters())

        noise = inference(
            args.validation_prompts[0],
            tokenizer,
            text_encoder,
            unet,
            noise_scheduler,
            accelerator.device,
        )
        # noise_to_pil needs the VAE for decoding, and the latent must match
        # the dtype the VAE weights were cast to.
        im = noise_to_pil(noise.to(args.weight_dtype), vae)
        im.save(f"{args.output_dir}/png/{epoch}.png")

        if args.use_ema:
            # Switch back to the original UNet parameters.
            ema_unet.restore(unet.parameters())

# Unwrap the UNet from the accelerator so its state dict can be saved.
unet = accelerator.unwrap_model(unet)

# When EMA is enabled, copy the EMA weights into the UNet before saving.
if args.use_ema:
    ema_unet.copy_to(unet.parameters())

final_state_dict = unet.state_dict()
torch.save(final_state_dict, f"{args.output_dir}/unet.pth")

accelerator.end_training()

Checkpoint 'latest' does not exist. Starting a new training run.


  0%|          | 0/1000 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 50.00 MiB (GPU 0; 15.74 GiB total capacity; 14.19 GiB already allocated; 39.69 MiB free; 14.33 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF