<a href="https://colab.research.google.com/github/tpremoli/ADMIRE-DL/blob/master/ft_SD2_1_celeba_manual_textencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# imports & requirements


Hugging Face Hub access token (keep real tokens out of shared notebooks)


In [1]:
# Read the Hugging Face token from the environment instead of hard-coding it:
# a token written into a notebook is exposed to everyone who can read the file.
# NOTE(review): any token that was ever committed in this cell should be revoked.
import os

# None when the variable is unset; presumably HfApi then falls back to a locally
# saved login (e.g. `huggingface-cli login`) — TODO confirm.
HF_HUB_TOKEN = os.environ.get("HF_HUB_TOKEN")

In [2]:
!pip install diffusers[training]==0.20.2 transformers matplotlib torch torchvision torchaudio

Collecting diffusers[training]==0.20.2
  Downloading diffusers-0.20.2.tar.gz (989 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/989.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m276.5/989.1 kB[0m [31m8.2 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━[0m [32m829.4/989.1 kB[0m [31m12.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m989.1/989.1 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.11.0 (from diffusers[training]==0.20.2)
  Downloading accelerate-0.27.2-py3-none-any.whl (279 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m280.0/280.0 kB[0m [31m14.6 MB

In [3]:
import math
import os
import random
import shutil

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from ast import literal_eval as parse_list_from_string
from dataclasses import dataclass
from datasets import load_dataset
from diffusers import (AutoencoderKL, DDPMScheduler, DiffusionPipeline,
                       StableDiffusionPipeline, UNet2DConditionModel)

from diffusers.image_processor import VaeImageProcessor
from diffusers.optimization import get_scheduler
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from diffusers.utils import randn_tensor
from huggingface_hub import HfApi
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm.auto import tqdm
from typing import List, Optional, Tuple, Union


# TODO: add checkpoints?
from torch.utils.checkpoint import checkpoint

# Some utils

In [4]:
def cantor_pairing(index, value):
    """
    Fold an (index, value) pair into a single integer with the Cantor pairing function.

    `value` is first reduced to a bit: -1 becomes 0, anything else becomes 1.
    The returned number is ``(index + bit) * (index + bit + 1) // 2 + bit``.
    """
    # Reduce the attribute value to a non-negative bit.
    if value == -1:
        bit = 0
    else:
        bit = 1

    # Cantor pairing of (index, bit): triangular number of the sum, plus the second element.
    diagonal = index + bit
    triangular = diagonal * (diagonal + 1) // 2
    return triangular + bit


# NOTE: should this be done when preprocessing?
def tokenize_attributes_to_tensor_list(examples, use_custom_prompt_encoder=True):
    """
    Tokenize every entry of ``examples["prompt_string"]``.

    Each entry is passed, together with `use_custom_prompt_encoder`, to
    `tokenize_attribute_string_to_tensor`; the resulting tensors are returned
    as a list in the same order as the input entries.
    """
    return [
        tokenize_attribute_string_to_tensor(attr_string, use_custom_prompt_encoder)
        for attr_string in examples["prompt_string"]
    ]

def tokenize_attribute_string_to_tensor(attr_string, use_custom_prompt_encoder=True):
    """
    Convert the string form of an attribute list into a ``torch.LongTensor``.

    The string is parsed with `parse_list_from_string` (``ast.literal_eval``).
    With `use_custom_prompt_encoder` the parsed values are used unchanged;
    otherwise every (position, value) pair is mapped through `cantor_pairing`.
    """
    parsed = parse_list_from_string(attr_string)

    if not use_custom_prompt_encoder:
        # One pairing-function id per (position, value) pair.
        token_ids = [cantor_pairing(position, attr) for position, attr in enumerate(parsed)]
        return torch.LongTensor(token_ids)

    # Custom encoder path: keep the raw attribute values.
    return torch.LongTensor(list(parsed))

def collate_pixels_and_prompts(examples):
    """
    Collate a list of per-sample dicts into one batch dict.

    Returns a dict with:
        "pixel_values": the samples' "pixel_values" stacked, in contiguous memory format, as float.
        "prompt": the samples' "prompt" tensors stacked.
        "prompt_str": the samples' "prompt_str" values as a plain list.
    """
    def column(key):
        # Gather one field across all samples, preserving sample order.
        return [sample[key] for sample in examples]

    stacked_pixels = torch.stack(column("pixel_values"))
    stacked_pixels = stacked_pixels.to(memory_format=torch.contiguous_format).float()
    stacked_prompts = torch.stack(column("prompt"))

    return {
        "pixel_values": stacked_pixels,
        "prompt": stacked_prompts,
        "prompt_str": column("prompt_str"),
    }


# Our custom DiffusionPipeline


In [5]:
class AttributePromptEncoder():
    """
    Parameter-free prompt encoder that turns attribute flags into constant embeddings.

    Every position whose value equals 1 is embedded as a vector filled with 0.5;
    every other value is embedded as a vector filled with -0.5. The vectors have
    `embedding_size` (1024) entries.

    Note: `dtype` is only stored on the instance. `__call__` builds its constants
    with `torch.full` without a dtype argument, so the stored value is not applied
    to the returned tensor.
    """

    def __init__(self, dtype=torch.float16, device='cuda:0'):
        self.embedding_size = 1024
        self.dtype = dtype
        self.device = device

    def __call__(self, prompts, return_dict=False):
        """
        Embed a 2-D tensor of attribute flags.

        Args:
            prompts: tensor of shape [batch_size, sequence_length]; it is moved to `self.device`.
            return_dict: when true, wrap the result as ``{"hidden_states": embeddings}``.

        Returns:
            A tensor of shape [batch_size, sequence_length, embedding_size], or the dict above.
        """
        prompts = prompts.to(self.device)

        # Unpacking also enforces that the input is exactly two-dimensional.
        batch_size, sequence_length = prompts.shape

        is_positive = prompts.unsqueeze(-1).eq(1)
        positive = torch.full((1,), 0.5, device=self.device)
        negative = torch.full((1,), -0.5, device=self.device)

        # [batch, seq, 1] of +/-0.5, then broadcast along the embedding axis.
        embeddings = torch.where(is_positive, positive, negative)
        embeddings = embeddings.expand(batch_size, sequence_length, self.embedding_size)

        if not return_dict:
            return embeddings
        return {"hidden_states": embeddings}

class ClassConditionalDiffusionPipeline(DiffusionPipeline):
    r"""
    Latent diffusion pipeline conditioned on attribute prompts rather than free text.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vae: Autoencoder whose `decode` method turns the final latents into images. Its
            `config.block_out_channels` and `config.scaling_factor` are read, and its `dtype`
            is used when casting latents before decoding.
        unet: Conditional U-Net called with `encoder_hidden_states` to predict the noise residual.
        scheduler: Scheduler used together with `unet` to denoise the latents.
        prompt_encoder: Callable mapping a prompt tensor to embeddings. It must expose a `dtype`
            attribute; its output is either a tensor or an indexable whose first item is the
            embedding tensor.
    """

    def __init__(self, vae, unet, scheduler, prompt_encoder):
        super().__init__()
        self.register_modules(vae=vae, unet=unet, scheduler=scheduler, prompt_encoder=prompt_encoder)
        # Spatial down-sampling factor between image space and latent space.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Create (or reuse) the initial latents and scale them for the scheduler.

        Raises:
            ValueError: if `generator` is a list whose length differs from `batch_size`.
        """
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt=1,
        use_custom_prompt_encoder=True,
        prompt_embeds=None,
    ):
        """
        Turn a prompt into conditioning embeddings for the U-Net.

        Args:
            prompt: attribute-list string, or an already tokenized tensor. Ignored (and allowed
                to be None) when `prompt_embeds` is given.
            device: device the embeddings are moved to.
            num_images_per_prompt: how many times each prompt embedding is duplicated.
            use_custom_prompt_encoder: forwarded to the tokenizer for string prompts and used to
                decide how to read the encoder output.
            prompt_embeds: optional pre-computed embeddings that bypass the encoder.

        Returns:
            Tensor of shape [batch * num_images_per_prompt, seq_len, dim] in `prompt_encoder.dtype`.
        """
        if prompt_embeds is None:
            # Only touch `prompt` when it is actually needed, so that callers may pass
            # pre-computed embeddings without a prompt.
            if isinstance(prompt, str):
                prompt_tensor = tokenize_attribute_string_to_tensor(prompt, use_custom_prompt_encoder)
            else:  # If not a string, assume it's already a tensor
                prompt_tensor = prompt

            # A single prompt is 1-D after tokenization; add the batch axis the encoder expects.
            if prompt_tensor.dim() == 1:
                prompt_tensor = prompt_tensor.unsqueeze(0)
            prompt_tensor = prompt_tensor.to(device)

            output = self.prompt_encoder(prompt_tensor)
            # The custom encoder returns the embedding tensor directly. Any non-tensor output
            # is indexed for its first element even if the flag was left at its default.
            if use_custom_prompt_encoder and torch.is_tensor(output):
                prompt_embeds = output
            else:
                prompt_embeds = output[0]

        prompt_embeds = prompt_embeds.to(dtype=self.prompt_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        latents: Optional[torch.FloatTensor] = None,
        use_progress_bar: bool = False
    ) -> Union[ImagePipelineOutput, Tuple]:
        """
        Generate images for `prompt`.

        Args:
            prompt: attribute-list string or tokenized prompt tensor.
            batch_size: number of latents to sample. NOTE(review): nothing here checks that it
                matches the prompt batch dimension — callers are expected to keep them equal.
            generator: optional generator(s) for the initial noise.
            num_inference_steps: number of scheduler steps.
            output_type: forwarded to `VaeImageProcessor.postprocess`.
            return_dict: return an `ImagePipelineOutput` when true, else a 1-tuple.
            latents: optional starting latents; sampled when omitted.
            use_progress_bar: show a progress bar over the denoising steps.
        """
        # 1. Default height and width to unet
        height = self.unet.config.sample_size * self.vae_scale_factor
        width = self.unet.config.sample_size * self.vae_scale_factor

        # 2. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            self.device,
            # do_classifier_free_guidance, TODO: look into classifier free guidance?
        )

        # 3. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # 4. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            self.device,
            generator,
            latents,
        )

        # 5. Denoising loop
        progress_bar = self.progress_bar(total=num_inference_steps) if use_progress_bar else None
        try:
            for t in timesteps:
                # here we would expand the latents for classifier free guidance
                latent_model_input = self.scheduler.scale_model_input(latents, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                if progress_bar is not None:
                    progress_bar.update(1)
        finally:
            # Always release the bar, even if a step raised.
            if progress_bar is not None:
                progress_bar.close()

        # 6. Decode. The latents may have changed dtype during the loop, so cast them to whatever
        # dtype the VAE weights use instead of assuming half precision.
        latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        do_denormalize = [True] * image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

def make_grid(images, cols=None):
    """
    Paste a sequence of image tensors into a single RGB PIL image, row by row.

    Args:
        images: non-empty sequence of tensors in C x H x W layout. Every tile is placed
            using the height and width of the first tensor. Values are multiplied by 255
            and converted to bytes, so they are assumed to be normalized.
        cols: number of columns; when None, ``ceil(sqrt(len(images)))`` is used.

    Returns:
        The assembled `PIL.Image` grid.

    Raises:
        ValueError: if `images` is empty.
    """
    count = len(images)
    if count == 0:
        raise ValueError("No images provided to make a grid.")

    # Tile size comes from the first tensor (C x H x W).
    _, tile_h, tile_w = images[0].size()

    # Tensor (C, H, W) -> uint8 array (H, W, C) -> PIL image.
    tiles = []
    for tensor in images:
        pixels = tensor.permute(1, 2, 0).mul(255).byte().numpy()
        tiles.append(Image.fromarray(pixels))

    # Square-ish layout unless the caller fixed the column count.
    if cols is None:
        cols = int(math.ceil(count ** 0.5))
    rows = int(math.ceil(count / float(cols)))

    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for position, tile in enumerate(tiles):
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))

    return canvas

def evaluate(config, epoch, pipeline, val_dataloader, samples_subdir="samples"):
    """
    Generate up to `config.eval_batch_size` sample images and save them as a grid.

    Prompts are taken from `val_dataloader` batches (keys "prompt" and "prompt_str").
    The image grid and a text file listing the prompts are written to
    ``<config.output_dir>/<samples_subdir>/``, both named after ``epoch + 1``.

    Args:
        config: object providing `eval_batch_size`, `seed`, `num_inference_steps` and `output_dir`.
        epoch: zero-based epoch index, used only for the output file names.
        pipeline: callable pipeline returning an object with an `images` list of PIL images.
        val_dataloader: iterable of batches as produced by `collate_pixels_and_prompts`.
        samples_subdir: sub-directory of `config.output_dir` receiving the files.

    Returns:
        Tuple ``(image_grid_path, prompt_text_path)``.
    """
    images = []  # generated PIL images
    processed_prompts = []  # prompt strings, aligned one-to-one with `images`

    # Iterate over the validation dataloader to get prompts for image generation
    for batch in val_dataloader:
        remaining = config.eval_batch_size - len(images)
        current_batch_size = min(len(batch["prompt"]), remaining)
        if current_batch_size <= 0:
            break  # Stop if we have generated enough images

        # Truncate the batch so the prompt count, the latent batch size and the recorded
        # prompt strings all agree.
        prompt = batch["prompt"][:current_batch_size]
        prompt_strs = batch["prompt_str"][:current_batch_size]

        # `is not None` so that a seed of 0 is honoured. torch.manual_seed reseeds and returns
        # the global generator, so every batch starts from the same noise when a seed is set.
        generator = torch.manual_seed(config.seed) if config.seed is not None else None

        with torch.no_grad():  # Ensure no gradients are computed during evaluation
            output = pipeline(
                prompt=prompt,
                batch_size=current_batch_size,
                generator=generator,
                # Read from the `config` argument rather than a notebook-global variable.
                num_inference_steps=config.num_inference_steps,
                output_type="pil",
                return_dict=True,
            )

        images.extend(output.images[:current_batch_size])  # Limit images to batch size
        processed_prompts.extend(prompt_strs)

    # Convert PIL images to PyTorch tensors for use with make_grid
    image_tensors = [torch.tensor(np.array(img)).permute(2, 0, 1) / 255. for img in images]
    image_grid = make_grid(image_tensors)  # Create a grid of images

    # dir creation etc
    samplesdir = os.path.join(config.output_dir, samples_subdir)
    os.makedirs(samplesdir, exist_ok=True)

    img_grid_dir = os.path.join(samplesdir, f"{(epoch+1):03d}.png")
    prompt_txt_dir = os.path.join(samplesdir, f"prompts-{(epoch+1):03d}.txt")

    image_grid.save(img_grid_dir)

    # One prompt per line, in the same order as the tiles of the grid.
    with open(prompt_txt_dir, "w") as f:
        for prompt_str in processed_prompts:
            f.write(str(prompt_str))
            f.write("\n")

    return img_grid_dir, prompt_txt_dir

# Training Loop


In [6]:
def train(conf):
    """
    Fine-tune the U-Net of a pretrained latent-diffusion model on attribute-conditioned images.

    The VAE and the prompt encoder stay frozen; only the U-Net is optimised with an MSE loss
    between its prediction and the sampled noise. After each epoch the main process
    optionally saves a U-Net checkpoint, renders sample images through `evaluate`, and
    uploads both to the Hugging Face Hub.

    Args:
        conf: configuration object. Attributes read here: `output_dir`, `logging_dir`,
            `gradient_accumulation_steps`, `mixed_precision`, `dataset_local_dir`,
            `dataset_hf_dir`, `preprocess_train`, `train_batch_size`, `dataloader_num_workers`,
            `shuffle_val_data`, `pretrained_hf_model`, `num_epochs`, `learning_rate`,
            `lr_scheduler`, `lr_warmup_steps`, `push_to_hub`, `hf_repo_id`,
            `resuming_from_checkpoint`, `save_model_epochs`, `save_image_epochs`,
            `max_checkpoints_saved`, and optionally `use_custom_prompt_encoder`.

    Raises:
        ValueError: if neither `dataset_local_dir` nor `dataset_hf_dir` is set.
    """
    # load conf & prep hf
    hf_api = HfApi(token=HF_HUB_TOKEN)  # set the token

    # Not every config class is guaranteed to define this flag. Default to the custom encoder,
    # which matches the default tokenisation of `tokenize_attributes_to_tensor_list`.
    use_custom_prompt_encoder = getattr(conf, "use_custom_prompt_encoder", True)

    # start accelerator for everything
    logging_dir = os.path.join(conf.output_dir, conf.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=conf.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=conf.gradient_accumulation_steps,
        mixed_precision=conf.mixed_precision,
        log_with="tensorboard",
        project_config=accelerator_project_config,
    )
    # Without initialised trackers `accelerator.log` has nothing to write to.
    accelerator.init_trackers("train")

    # For mixed precision training we cast all non-trainable weights (vae, prompt encoder) to
    # half-precision: they are only used for inference, so full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
        conf.mixed_precision = accelerator.mixed_precision
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
        conf.mixed_precision = accelerator.mixed_precision

    # loading the dataset and preprocessing it
    # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata
    # more info on metadata
    print("loading dataset")
    if conf.dataset_local_dir != "":
        dataset = load_dataset(
            "imagefolder",
            data_dir=conf.dataset_local_dir
        )
    elif conf.dataset_hf_dir != "":
        dataset = load_dataset(conf.dataset_hf_dir)
    else:
        raise ValueError("No valid dataset directory provided")
    # NOTE: can be accessed from hf hub

    with accelerator.main_process_first():
        train_dataset = dataset["train"].with_transform(conf.preprocess_train)
        val_dataset = dataset["validation"].with_transform(conf.preprocess_train)

    print("creating dataloader")
    # Creating the dataloaders
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_pixels_and_prompts,
        batch_size=conf.train_batch_size,
        num_workers=conf.dataloader_num_workers,
    )
    # Built once and reused by every evaluation pass (one prompt per batch).
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        shuffle=conf.shuffle_val_data,
        collate_fn=collate_pixels_and_prompts,
        batch_size=1,
        num_workers=conf.dataloader_num_workers,
    )

    print("loading VAE")
    # 1. Load the autoencoder model which will be used to decode the latents into image space.
    vae = AutoencoderKL.from_pretrained(conf.pretrained_hf_model, subfolder="vae")

    print("loading u-net")
    # 2. getting u-net
    unet = UNet2DConditionModel.from_pretrained(conf.pretrained_hf_model, subfolder="unet")

    if use_custom_prompt_encoder:
        print("Using custom prompt encoder")
        # Follow the accelerator's device and the frozen-weight dtype instead of the class
        # defaults ('cuda:0', float16), so CPU / fp32 / bf16 runs stay consistent.
        prompt_encoder = AttributePromptEncoder(dtype=weight_dtype, device=accelerator.device)
    else:
        print("loading CLIPTextModel")
        # 2.5. getting prompt encoder
        prompt_encoder = CLIPTextModel.from_pretrained(conf.pretrained_hf_model, subfolder="text_encoder")

    print("loading lr scheduler")
    # 3. creating lr scheduler (custom)
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / conf.gradient_accumulation_steps)
    max_train_steps = conf.num_epochs * num_update_steps_per_epoch
    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=conf.learning_rate,
    )
    lr_scheduler = get_scheduler(
        conf.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=conf.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=max_train_steps * accelerator.num_processes,
    )

    print("loading DDPMScheduler")
    # 4. getting noise scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(
        conf.pretrained_hf_model, subfolder="scheduler"
    )

    print("disabling training of vae and prompt_encoder")
    # disabling training of vae and prompt_encoder
    vae.requires_grad_(False)
    unet.train()
    if not use_custom_prompt_encoder:
        prompt_encoder.requires_grad_(False)

    print("prepping accelerator")
    # send everything to the accelerator
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # Moving vae and prompt_encoder to accelerator
    if not use_custom_prompt_encoder:
        prompt_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # NOTE: we don't train these models, so we can just save them once and don't have
    # to upload them every time. Only the main process writes and uploads.
    if accelerator.is_main_process:
        print("saving untrained models")
        os.makedirs(conf.output_dir, exist_ok=True)

        vae.save_pretrained(f"{conf.output_dir}/vae")
        noise_scheduler.save_pretrained(f"{conf.output_dir}/scheduler")
        if not use_custom_prompt_encoder:
            prompt_encoder.save_pretrained(f"{conf.output_dir}/prompt_encoder")

        if conf.push_to_hub:
            print("uploading untrained models to hf")
            hf_api.upload_folder(
                repo_id=conf.hf_repo_id,
                folder_path=f"{conf.output_dir}/vae",
                path_in_repo="vae",
                commit_message="init commit of untrained vae",
            )
            hf_api.upload_folder(
                repo_id=conf.hf_repo_id,
                folder_path=f"{conf.output_dir}/scheduler",
                path_in_repo="scheduler",
                commit_message="init commit of untrained scheduler",
            )
            if not use_custom_prompt_encoder:
                hf_api.upload_folder(
                    repo_id=conf.hf_repo_id,
                    folder_path=f"{conf.output_dir}/prompt_encoder",
                    path_in_repo="prompt_encoder",
                    commit_message="init commit of untrained prompt_encoder",
                )

    first_epoch = 0
    if conf.resuming_from_checkpoint:
        # TODO: resuming is not implemented; first_epoch would become the checkpoint's epoch.
        pass

    print("Starting Training  loop")
    # progress bar definition
    progress_bar = tqdm(
        range(0, conf.num_epochs * num_update_steps_per_epoch),
        initial=0,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    glob_step = 0

    def unet_forward_with_checkpointing(noisy_latents, timesteps, encoder_hidden_states):
        # This function will be called with checkpointing to save memory
        return unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

    # Training loop
    for epoch in range(first_epoch, conf.num_epochs):
        train_loss = 0.0

        # this is for tracking the avg loss per epoch
        loss_this_epoch = 0
        num_steps = 0

        # Update the progress bar's description to include the current epoch
        progress_bar.set_description(f"Epoch {epoch+1}/{conf.num_epochs}. Steps")
        for step, batch in enumerate(train_dataloader):
            num_steps += 1

            # Use the accelerator context manager for the unet model
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                # TODO fix this: presumably needed so that torch.utils.checkpoint sees an input
                # that requires grad — confirm before removing.
                noisy_latents.requires_grad_(True)

                # Target:
                target = noise
                # target can be v_prediction too:
                # target = noise_scheduler.get_velocity(latents, noise, timesteps)

                # Encode the prompt.
                if use_custom_prompt_encoder:
                    encoder_hidden_states = prompt_encoder(batch["prompt"], return_dict=False)
                else:
                    encoder_hidden_states = prompt_encoder(batch["prompt"], return_dict=False)[0]

                # Calculate the model outputs (with activation checkpointing)
                model_pred = checkpoint(unet_forward_with_checkpointing, noisy_latents, timesteps, encoder_hidden_states)

                #  maybe snr loss here instead of mse
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(conf.train_batch_size)).mean()
                train_loss += avg_loss.item() / conf.gradient_accumulation_steps
                loss_this_epoch += avg_loss.item()

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                glob_step += 1
                accelerator.log({"train_loss": train_loss}, step=glob_step)
                train_loss = 0.0

            # Calculate the running average loss at the current step
            current_avg_loss = loss_this_epoch / num_steps if num_steps else 0
            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "avg_loss": current_avg_loss}
            progress_bar.set_postfix(**logs)

            if glob_step >= max_train_steps:
                break

        # HERE WE'RE DONE WITH THE EPOCH

        # avg loss per epoch
        avg_loss_per_epoch = loss_this_epoch / num_steps if num_steps else 0
        print(f"Average loss for epoch {epoch+1}: {avg_loss_per_epoch:.4f}")

        # POST TRAINING LOOP: check if we need to save model or images
        if accelerator.is_main_process:
            is_last_epoch = epoch == conf.num_epochs - 1
            is_save_epoch = (epoch + 1) % conf.save_model_epochs == 0

            # STEP 1: CHECK IF WE NEED TO SAVE THE MODEL NOW
            if is_save_epoch or is_last_epoch:
                print("Saving model")

                # _before_ saving, check if this save would set us over `max_checkpoints_saved`
                if conf.max_checkpoints_saved is not None:
                    checkpoints = [d for d in os.listdir(conf.output_dir) if d.startswith("checkpoint-")]
                    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                    # before we save the new checkpoint, we need to have at _most_
                    # `max_checkpoints_saved - 1` checkpoints
                    if len(checkpoints) >= conf.max_checkpoints_saved:
                        print("clearing unecessary saved checkpoints")
                        num_to_remove = len(checkpoints) - conf.max_checkpoints_saved + 1
                        for removing_checkpoint in checkpoints[:num_to_remove]:
                            shutil.rmtree(os.path.join(conf.output_dir, removing_checkpoint))

                save_path = os.path.join(conf.output_dir, f"checkpoint-{epoch+1}")
                # Unwrap first so the underlying model's save_pretrained is used even when the
                # accelerator wrapped the U-Net.
                accelerator.unwrap_model(unet).save_pretrained(f"{save_path}/unet")

                print("model saved")

                if conf.push_to_hub:
                    print("uploading model to hf")
                    hf_api.upload_folder(
                        repo_id=conf.hf_repo_id,
                        folder_path=save_path,
                        path_in_repo=f"checkpoints/checkpoint-{epoch+1}",
                        commit_message=f"Saving weights and logs of epoch {epoch+1}",
                    )

            # STEP 2: CHECK IF WE NEED TO SAVE IMAGES NOW
            # (also at every saved model, so each checkpoint has a matching sample)
            if (epoch + 1) % conf.save_image_epochs == 0 or is_last_epoch or is_save_epoch:
                print("saving images")

                # we load our pipeline
                pipeline = ClassConditionalDiffusionPipeline(
                    vae=accelerator.unwrap_model(vae),
                    unet=accelerator.unwrap_model(unet),
                    scheduler=noise_scheduler,
                    prompt_encoder=prompt_encoder,
                )

                imgfile, prompttxt = evaluate(conf, epoch, pipeline, val_dataloader)

                if conf.push_to_hub:
                    hf_api.upload_file(
                        repo_id=conf.hf_repo_id,
                        path_or_fileobj=imgfile,
                        path_in_repo=f"samples/{os.path.basename(imgfile)}",
                        commit_message=f"Uploading img sample of epoch {epoch+1}",
                    )
                    hf_api.upload_file(
                        repo_id=conf.hf_repo_id,
                        path_or_fileobj=prompttxt,
                        path_in_repo=f"samples/{os.path.basename(prompttxt)}",
                        commit_message=f"Uploading prompt sample of epoch {epoch+1}",
                    )

    # Close the trackers; the last epoch's checkpoint was already saved inside the loop.
    accelerator.end_training()



# Training Config and Launching training (20k)


In [None]:
@dataclass
class TrainingConfig20k:
    # Settings for fine-tuning on the 20k-image CelebA attribute dataset.
    # None of the attributes below carry type annotations, so `@dataclass` generates no
    # fields: they are plain class attributes shared by every instance.

    # This model is 128x128, which is what we're aiming for
    pretrained_hf_model = "bguisard/stable-diffusion-nano-2-1"

    # the generated image resolution
    image_size = 128
    train_batch_size = 128
    eval_batch_size = 9  # how many images to sample during evaluation
    dataloader_num_workers = 2
    gradient_accumulation_steps = 1

    # train() reads this flag; without it the config raises AttributeError there.
    # True matches the default tokenisation (raw attribute values) used below.
    use_custom_prompt_encoder = True

    num_epochs = 25
    learning_rate = 1e-6
    lr_warmup_steps = 500

    num_inference_steps = 150

    shuffle_val_data = True
    save_image_epochs = 1
    save_model_epochs = 5
    max_checkpoints_saved=5

    mixed_precision = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = 'LDA-CelebA-128-20k'  # the model name locally and on the HF Hub

    dataset_local_dir = ""
    dataset_hf_dir = "tpremoli/CelebA-attrs-20k"

    hf_repo_id="tpremoli/LDA-CelebA-128-20k"
    push_to_hub = True  # whether to upload the saved model to the HF Hub
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook

    resuming_from_checkpoint = False  # set to True to continue training from a saved checkpoint

    # NOTE(review): train() joins this onto `output_dir`, so logs end up under
    # "<output_dir>/LDA-CelebA-128-20k/logs" — confirm that nesting is intended.
    logging_dir="LDA-CelebA-128-20k/logs"

    # Choose between
    # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]'
    lr_scheduler="cosine_with_restarts"

    seed = None

    # Resize + centre-crop to `image_size`, random horizontal flip, then scale to [-1, 1].
    train_transforms = transforms.Compose(
        [
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(self, examples):
        """
        Batch transform for the dataset: add "pixel_values", "prompt_str" and "prompt".

        Reads the "image" and "prompt_string" columns of `examples`, writes the three new
        columns into the same mapping and returns it.
        """
        images = [image.convert("RGB") for image in examples["image"]]
        examples["pixel_values"] = [self.train_transforms(image) for image in images]
        examples["prompt_str"] = [ps for ps in examples["prompt_string"]]
        # Tokenise consistently with the encoder choice made for training.
        examples["prompt"] = tokenize_attributes_to_tensor_list(examples, self.use_custom_prompt_encoder)
        return examples

In [None]:
# Build the 20k configuration and launch training.
# No explicit torch.device call is made here: a bare `torch.device(...)` expression selects
# nothing, and the models are placed on `accelerator.device` inside train().
conf = TrainingConfig20k()
train(conf)

# Training Config and Launching Training (80k)

In [7]:
@dataclass
class TrainingConfig80k:
    """Hyper-parameters and paths for the run on ``tpremoli/CelebA-attrs-80k``.

    None of the names below carries a type annotation, so they are ordinary
    class attributes rather than dataclass fields; the generated ``__init__``
    therefore takes no arguments.
    """

    # Base checkpoint; it is a 128x128 model, which is the resolution we target.
    pretrained_hf_model = "bguisard/stable-diffusion-nano-2-1"

    # Resolution of the generated images.
    image_size = 128
    train_batch_size = 256
    eval_batch_size = 9  # number of images sampled during evaluation
    dataloader_num_workers = 2
    gradient_accumulation_steps = 1

    use_custom_prompt_encoder = True

    num_epochs = 50
    learning_rate = 1e-5
    lr_warmup_steps = 500

    num_inference_steps = 50

    shuffle_val_data = False
    save_image_epochs = 2
    save_model_epochs = 5
    max_checkpoints_saved = 5

    # "no" selects float32, "fp16" selects automatic mixed precision.
    mixed_precision = "fp16"
    # Model name, both locally and on the HF Hub.
    output_dir = "LDA-CelebA-128-80k-customPE"

    dataset_local_dir = ""
    dataset_hf_dir = "tpremoli/CelebA-attrs-80k"

    hf_repo_id = "tpremoli/LDA-CelebA-128-80k-customPE"
    push_to_hub = True  # upload the saved model to the HF Hub
    overwrite_output_dir = True  # replace the old model when the notebook is re-run

    resuming_from_checkpoint = False  # True continues training from a saved checkpoint

    logging_dir = "LDA-CelebA-128-80k-customPE/logs"

    # One of: "linear", "cosine", "cosine_with_restarts", "polynomial",
    # "constant", "constant_with_warmup".
    lr_scheduler = "cosine_with_restarts"

    seed = None

    # Resize and centre-crop to `image_size`, flip horizontally at random,
    # convert to a tensor, then normalise with mean 0.5 and std 0.5.
    train_transforms = transforms.Compose([
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def preprocess_train(self, examples):
        """Add model-ready entries to a dataset batch and return that batch.

        Writes ``pixel_values`` (RGB-converted images passed through
        ``self.train_transforms``), ``prompt_str`` (a list copy of
        ``examples["prompt_string"]``) and ``prompt`` (the result of
        ``tokenize_attributes_to_tensor_list`` for the batch).
        """
        rgb_images = [img.convert("RGB") for img in examples["image"]]
        examples["pixel_values"] = [self.train_transforms(img) for img in rgb_images]
        examples["prompt_str"] = list(examples["prompt_string"])
        examples["prompt"] = tokenize_attributes_to_tensor_list(examples)
        return examples

In [None]:
# Instantiate the 80k configuration and start training with it.
# No device is chosen here: constructing a `torch.device` object and discarding
# it has no effect, so device placement is left to `train` itself.
conf = TrainingConfig80k()
train(conf)

loading dataset


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/2.68k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/281M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/281M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/69.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/69.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/79999 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/9810 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9763 [00:00<?, ? examples/s]

creating dataloader
loading VAE


vae/config.json:   0%|          | 0.00/582 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

loading u-net


unet/config.json:   0%|          | 0.00/912 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

Using custom prompt encoder
loading lr scheduler
loading DDPMScheduler


scheduler/scheduler_config.json:   0%|          | 0.00/346 [00:00<?, ?B/s]

disabling training of vae and prompt_encoder
prepping accelerator
saving untrained models
uploading untrained models to hf
Starting Training  loop


Steps:   0%|          | 0/15650 [00:00<?, ?it/s]



Average loss for epoch 1: 0.1794
Average loss for epoch 2: 0.1639
saving images
Average loss for epoch 3: 0.1622
Average loss for epoch 4: 0.1620
saving images
Average loss for epoch 5: 0.1608
Saving model
model saved
uploading model to hf


diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

saving images


# Training Config and Launching Training (160k)

In [None]:
@dataclass
class TrainingConfig160k:
    """Hyper-parameters and paths for the run on ``tpremoli/CelebA-attrs-160k``.

    None of the names below carries a type annotation, so they are ordinary
    class attributes rather than dataclass fields; the generated ``__init__``
    therefore takes no arguments.
    """

    # Base checkpoint; it is a 128x128 model, which is the resolution we target.
    pretrained_hf_model = "bguisard/stable-diffusion-nano-2-1"

    # Resolution of the generated images.
    image_size = 128
    train_batch_size = 256
    eval_batch_size = 9  # number of images sampled during evaluation
    dataloader_num_workers = 2
    gradient_accumulation_steps = 1

    num_epochs = 75
    learning_rate = 25e-7
    lr_warmup_steps = 500

    num_inference_steps = 30

    shuffle_val_data = True
    save_image_epochs = 2
    save_model_epochs = 5
    max_checkpoints_saved = 5

    # "no" selects float32, "fp16" selects automatic mixed precision.
    mixed_precision = "fp16"
    # Model name, both locally and on the HF Hub.
    output_dir = "LDA-CelebA-128-160k"

    dataset_local_dir = ""
    dataset_hf_dir = "tpremoli/CelebA-attrs-160k"

    hf_repo_id = "tpremoli/LDA-CelebA-128-160k"
    push_to_hub = True  # upload the saved model to the HF Hub
    overwrite_output_dir = True  # replace the old model when the notebook is re-run

    resuming_from_checkpoint = False  # True continues training from a saved checkpoint

    logging_dir = "LDA-CelebA-128-160k/logs"

    # One of: "linear", "cosine", "cosine_with_restarts", "polynomial",
    # "constant", "constant_with_warmup".
    lr_scheduler = "cosine_with_restarts"

    seed = None

    # Resize and centre-crop to `image_size`, flip horizontally at random,
    # convert to a tensor, then normalise with mean 0.5 and std 0.5.
    train_transforms = transforms.Compose([
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def preprocess_train(self, examples):
        """Add model-ready entries to a dataset batch and return that batch.

        Writes ``pixel_values`` (RGB-converted images passed through
        ``self.train_transforms``), ``prompt_str`` (a list copy of
        ``examples["prompt_string"]``) and ``prompt`` (the result of
        ``tokenize_attributes_to_tensor_list`` for the batch).
        """
        rgb_images = [img.convert("RGB") for img in examples["image"]]
        examples["pixel_values"] = [self.train_transforms(img) for img in rgb_images]
        examples["prompt_str"] = list(examples["prompt_string"])
        examples["prompt"] = tokenize_attributes_to_tensor_list(examples)
        return examples

In [None]:
# Instantiate the 160k configuration and start training with it.
# No device is chosen here: constructing a `torch.device` object and discarding
# it has no effect, so device placement is left to `train` itself.
conf = TrainingConfig160k()
train(conf)

# Getting attrs from prompt

In [None]:
def get_attrs_from_prompt(prompt_string):
    """Translate an attribute vector into readable ``YES``/``NO`` labels.

    Args:
        prompt_string: String form of a Python list of integers, e.g.
            ``"[-1, 1, -1, ...]"``. The value at position ``i`` refers to
            ``attribute_names[i]`` below.

    Returns:
        A list with one string per parsed value: ``"YES <name>"`` when the
        value equals ``1``, ``"NO <name>"`` for anything else.

    Raises:
        ValueError, SyntaxError: if ``prompt_string`` is not a plain Python
            literal (arbitrary expressions are rejected, never executed).
        IndexError: if more values are supplied than there are attribute names.
    """
    # List of attribute names as per the order mentioned in the dataset
    attribute_names = [
        "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes", "Bald", "Bangs",
        "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair", "Blurry", "Brown_Hair", "Bushy_Eyebrows",
        "Chubby", "Double_Chin", "Eyeglasses", "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones",
        "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", "No_Beard", "Oval_Face", "Pale_Skin",
        "Pointy_Nose", "Receding_Hairline", "Rosy_Cheeks", "Sideburns", "Smiling", "Straight_Hair",
        "Wavy_Hair", "Wearing_Earrings", "Wearing_Hat", "Wearing_Lipstick", "Wearing_Necklace",
        "Wearing_Necktie", "Young"
    ]

    # Parse with ast.literal_eval (imported at the top of the notebook as
    # `parse_list_from_string`) instead of eval(): only literals are accepted,
    # so a crafted string cannot run code.
    attr_values = parse_list_from_string(prompt_string)

    # Prefix each attribute name according to its value; indexing (rather than
    # zip) keeps the IndexError for over-long inputs.
    return [
        f"{'YES' if value == 1 else 'NO'} {attribute_names[position]}"
        for position, value in enumerate(attr_values)
    ]

# Example usage: decode one 40-value attribute vector and print the labels.
# `prompt_string` and `attrs_output` remain defined at notebook level afterwards.
prompt_string = "[-1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1]"
attrs_output = get_attrs_from_prompt(prompt_string)
print(attrs_output)

# Closing colab session


In [None]:
# Calls `runtime.unassign()` from `google.colab`.
# NOTE(review): presumably this releases the Colab VM so an unattended training
# run stops consuming the runtime; if so, the cells below cannot run in the same
# session once this executes — confirm, and consider moving it to the end.
from google.colab import runtime
runtime.unassign()

# Testing pipeline

In [None]:
@dataclass
class TestConfig:
    """Settings used by the pipeline test on ``tpremoli/CelebA-attrs-80k``.

    None of the names below carries a type annotation, so they are ordinary
    class attributes rather than dataclass fields; the generated ``__init__``
    therefore takes no arguments.
    """

    # Base checkpoint; it is a 128x128 model, which is the resolution we target.
    pretrained_hf_model = "bguisard/stable-diffusion-nano-2-1"

    # Resolution of the generated images.
    image_size = 128
    train_batch_size = 128
    eval_batch_size = 9  # number of images sampled during evaluation
    dataloader_num_workers = 2
    gradient_accumulation_steps = 1

    num_epochs = 25
    learning_rate = 1e-6
    lr_warmup_steps = 500

    num_inference_steps = 150

    shuffle_val_data = True
    save_image_epochs = 1
    save_model_epochs = 5
    max_checkpoints_saved = 5

    # "no" selects float32, "fp16" selects automatic mixed precision.
    mixed_precision = "fp16"
    # Model name, both locally and on the HF Hub.
    output_dir = "LDA-CelebA-128-80k"

    dataset_local_dir = ""
    dataset_hf_dir = "tpremoli/CelebA-attrs-80k"

    hf_repo_id = "tpremoli/LDA-CelebA-128-80k"
    push_to_hub = True  # upload the saved model to the HF Hub
    overwrite_output_dir = True  # replace the old model when the notebook is re-run

    resuming_from_checkpoint = False  # True continues training from a saved checkpoint

    logging_dir = "LDA-CelebA-128-80k/logs"

    # One of: "linear", "cosine", "cosine_with_restarts", "polynomial",
    # "constant", "constant_with_warmup".
    lr_scheduler = "cosine_with_restarts"

    seed = None

    # Resize and centre-crop to `image_size`, flip horizontally at random,
    # convert to a tensor, then normalise with mean 0.5 and std 0.5.
    train_transforms = transforms.Compose([
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def preprocess_train(self, examples, use_custom_prompt_encoder=True):
        """Add model-ready entries to a dataset batch and return that batch.

        Writes ``pixel_values`` (RGB-converted images passed through
        ``self.train_transforms``), ``prompt_str`` (a list copy of
        ``examples["prompt_string"]``), ``prompt`` (the result of
        ``tokenize_attributes_to_tensor_list``, which also receives
        ``use_custom_prompt_encoder``) and ``prompt_unscaled`` (a
        ``torch.LongTensor`` built from the literal-parsed prompt strings).
        """
        rgb_images = [img.convert("RGB") for img in examples["image"]]
        examples["pixel_values"] = [self.train_transforms(img) for img in rgb_images]
        examples["prompt_str"] = list(examples["prompt_string"])
        examples["prompt"] = tokenize_attributes_to_tensor_list(examples, use_custom_prompt_encoder)
        parsed_attrs = [parse_list_from_string(raw) for raw in examples["prompt_string"]]
        examples["prompt_unscaled"] = torch.LongTensor(parsed_attrs)
        return examples

# test the pipeline:
def test_pipeline_eval(conf, num_of_images=9, use_hf_model=True, checkpoint=14):
    """Load a trained checkpoint and print prompt-encoder outputs for a few validation samples.

    Args:
        conf: Configuration object. It is read for the dataset location,
            output directory / Hub repo id, logging directory, mixed-precision
            mode, dataloader workers and ``preprocess_train``. Its
            ``eval_batch_size`` is overwritten with ``num_of_images`` and, under
            fp16/bf16, ``mixed_precision`` is re-assigned from the accelerator.
        num_of_images: Value stored into ``conf.eval_batch_size``.
        use_hf_model: When true, weights are loaded from ``conf.hf_repo_id`` on
            the Hugging Face Hub; otherwise from ``conf.output_dir`` on disk.
        checkpoint: Number used in the ``checkpoints/checkpoint-<n>/unet`` path.

    Raises:
        ValueError: if both ``conf.dataset_local_dir`` and
            ``conf.dataset_hf_dir`` are empty strings.
    """
    from itertools import islice

    conf.eval_batch_size = num_of_images

    # start accelerator for everything
    logging_dir = os.path.join(conf.output_dir, conf.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=conf.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=conf.gradient_accumulation_steps,
        mixed_precision=conf.mixed_precision,
        log_with="tensorboard",
        project_config=accelerator_project_config,
    )

    # For mixed precision we cast all non-trainable weights (vae, prompt encoder and unet) to half precision:
    # these weights are only used for inference, so keeping them in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
        conf.mixed_precision = accelerator.mixed_precision
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
        conf.mixed_precision = accelerator.mixed_precision

    # loading the dataset and preprocessing it
    # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata
    # more info on metadata
    print("loading dataset")
    if conf.dataset_local_dir != "":
        val_dataset = load_dataset(
            "imagefolder",
            data_dir=conf.dataset_local_dir, split="validation"
        )
    elif conf.dataset_hf_dir != "":
        val_dataset = load_dataset(conf.dataset_hf_dir, split="validation")
    else:
        raise ValueError("No valid dataset directory provided")
    # NOTE: can be accessed from hf hub

    with accelerator.main_process_first():
        val_dataset = val_dataset.with_transform(conf.preprocess_train)

    print("creating dataloader")
    # Creating the dataloader
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        shuffle=False,
        collate_fn=collate_pixels_and_prompts_test,
        batch_size=1,
        num_workers=conf.dataloader_num_workers,
    )

    print("loading VAE")
    # 1. Load the autoencoder model which will be used to decode the latents into image space.
    if use_hf_model:
        vae = AutoencoderKL.from_pretrained(conf.hf_repo_id, subfolder="vae", use_auth_token=HF_HUB_TOKEN)
    else:
        vae = AutoencoderKL.from_pretrained(f"{conf.output_dir}/vae")

    print("loading u-net")
    # 2. getting u-net
    if use_hf_model:
        unet = UNet2DConditionModel.from_pretrained(conf.hf_repo_id, subfolder=f"checkpoints/checkpoint-{checkpoint}/unet", use_auth_token=HF_HUB_TOKEN)
    else:
        unet = UNet2DConditionModel.from_pretrained(f"{conf.output_dir}/checkpoints/checkpoint-{checkpoint}/unet")

    print("loading CLIPTextModel")
    # 2.5. getting prompt encoder
    # prompt_encoder = AttributePromptEncoder.from_pretrained(conf.APE_dir)
    if use_hf_model:
        prompt_encoder = CLIPTextModel.from_pretrained(conf.hf_repo_id, subfolder="prompt_encoder", use_auth_token=HF_HUB_TOKEN)
    else:
        prompt_encoder = CLIPTextModel.from_pretrained(f"{conf.output_dir}/prompt_encoder")

    print("loading DDPMScheduler")
    # 4. getting noise scheduler
    if use_hf_model:
        noise_scheduler = DDPMScheduler.from_pretrained(conf.hf_repo_id, subfolder="scheduler", use_auth_token=HF_HUB_TOKEN)
    else:
        noise_scheduler = DDPMScheduler.from_pretrained(f"{conf.output_dir}/scheduler")

    print("disabling training of everything")
    # nothing is trained here, so gradients are switched off for every model
    vae.requires_grad_(False)
    prompt_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # Moving unet, prompt_encoder and vae to the accelerator device
    unet.to(accelerator.device, dtype=weight_dtype)
    prompt_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    print("prepping accelerator")
    # send everything to the accelerator
    unet, val_dataloader = accelerator.prepare(
        unet, val_dataloader
    )

    print("loading pipeline to save images")
    # we load our pipeline (currently only constructed; the evaluate call below is disabled)
    pipeline = ClassConditionalDiffusionPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=noise_scheduler,
        prompt_encoder=prompt_encoder
    )

    #evaluate(conf, 0, pipeline, val_dataloader)

    # Disabled experiment, kept for reference: round-tripping prompt ids through the CLIP tokenizer.
    # print("testing reverse clip tokenizer")
    # tokenizer = CLIPTokenizer.from_pretrained(conf.pretrained_hf_model, subfolder="tokenizer", use_auth_token=HF_HUB_TOKEN)
    #
    # lim = 5
    # for batch in val_dataloader:
    #     prompt_str = ''.join(map(str, (tokenizer._convert_id_to_token(i) for i in batch["prompt"][0].tolist())))
    #     #print(batch["prompt"][0])
    #     #print(prompt_str)
    #     #print(tokenizer.convert_tokens_to_string(prompt_str))
    #     lim -= 1
    #     if lim == 0:
    #         break

    # NOTE: test with each individual token
    preview_batches = 5  # how many validation batches each pass below inspects

    # First pass: raw attribute vector and the shape of the encoder output.
    for batch in islice(val_dataloader, preview_batches):
        print(batch["prompt_unscaled"])
        print(prompt_encoder(batch["prompt"]).last_hidden_state.size())

    # Second pass: same batches, additionally printing the embedding values.
    # NOTE(review): `embedded_attributes` was previously printed without ever
    # being assigned (NameError); it is taken to be the encoder's last hidden
    # state, matching what the first pass inspects — confirm this is the intent.
    for batch in islice(val_dataloader, preview_batches):
        embedded_attributes = prompt_encoder(batch["prompt"]).last_hidden_state
        print(batch["prompt_unscaled"])
        print(embedded_attributes.size())
        print(embedded_attributes)


# Build the test configuration and run `test_pipeline_eval` with its defaults.
# `conf` stays defined at notebook level; the following cell reads it.
# NOTE(review): that cell accesses `conf.use_custom_prompt_encoder`, which
# TestConfig does not define — confirm which config is meant to be active there.
conf = TestConfig()
test_pipeline_eval(conf)

In [10]:
# Rebuilds the training objects at notebook level from the global `conf`, saves
# the frozen components to `conf.output_dir`, and runs one `evaluate` pass.
# NOTE(review): this cell needs `conf.use_custom_prompt_encoder`; TestConfig
# (assigned to `conf` in the cell above) does not define it — confirm which
# config is expected to be active before running.
accelerator_project_config = ProjectConfiguration(project_dir=conf.output_dir)
accelerator = Accelerator(
    gradient_accumulation_steps=conf.gradient_accumulation_steps,
    mixed_precision=conf.mixed_precision,
    log_with="tensorboard",
    project_config=accelerator_project_config,
)

print("loading dataset")
if conf.dataset_local_dir != "":
    dataset = load_dataset(
        "imagefolder",
        data_dir=conf.dataset_local_dir
    )
elif conf.dataset_hf_dir != "":
    dataset = load_dataset(conf.dataset_hf_dir)
else:
    raise ValueError("No valid dataset directory provided")
# NOTE: can be accessed from hf hub

with accelerator.main_process_first():
    train_dataset = dataset["train"].with_transform(conf.preprocess_train)

print("creating dataloader")
# Creating the training dataloader (the validation one is built once, further down,
# right before it is used).
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_pixels_and_prompts,
    batch_size=conf.train_batch_size,
    num_workers=conf.dataloader_num_workers,
)

print("loading VAE")
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained(conf.pretrained_hf_model, subfolder="vae")

print("loading u-net")
# 2. getting u-net
unet = UNet2DConditionModel.from_pretrained(conf.pretrained_hf_model, subfolder="unet")

if conf.use_custom_prompt_encoder:
    print("Using custom prompt encoder")
    prompt_encoder = AttributePromptEncoder()
else:
    print("loading CLIPTextModel")
    # 2.5. getting prompt encoder
    # prompt_encoder = AttributePromptEncoder.from_pretrained(conf.APE_dir)
    prompt_encoder = CLIPTextModel.from_pretrained(conf.pretrained_hf_model, subfolder="text_encoder")

print("loading lr scheduler")
# 3. creating lr scheduler (custom)
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / conf.gradient_accumulation_steps)
max_train_steps = conf.num_epochs * num_update_steps_per_epoch
optimizer = torch.optim.AdamW(
    unet.parameters(),
    lr=conf.learning_rate,
    # betas=(conf.adam_beta1, conf.adam_beta2),
    # weight_decay=conf.adam_weight_decay,
    # eps=conf.adam_epsilon,
)
lr_scheduler = get_scheduler(
    conf.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=conf.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=max_train_steps * accelerator.num_processes,
)

print("loading DDPMScheduler")
# 4. getting noise scheduler
noise_scheduler = DDPMScheduler.from_pretrained(
    conf.pretrained_hf_model, subfolder="scheduler"
)

print("disabling training of vae and prompt_encoder")
# The vae is always frozen; the prompt encoder is frozen only when it is the
# pretrained CLIP text model (the custom encoder is left untouched here).
vae.requires_grad_(False)
unet.train()
if not conf.use_custom_prompt_encoder:
    prompt_encoder.requires_grad_(False)

print("prepping accelerator")
# send everything to the accelerator
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)

# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
    conf.mixed_precision = accelerator.mixed_precision
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
    conf.mixed_precision = accelerator.mixed_precision

# Moving vae and (pretrained) prompt_encoder to the accelerator device
if not conf.use_custom_prompt_encoder:
    prompt_encoder.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)

print("saving untrained models")
# Make sure the output directory exists (no-op when it is already there)
os.makedirs(conf.output_dir, exist_ok=True)

# NOTE: we don't train these models, so we can just save them and don't have
# to upload them every time
vae.save_pretrained(f"{conf.output_dir}/vae")
noise_scheduler.save_pretrained(f"{conf.output_dir}/scheduler")
if not conf.use_custom_prompt_encoder:
    prompt_encoder.save_pretrained(f"{conf.output_dir}/prompt_encoder")


with accelerator.main_process_first():
    val_dataset = dataset["validation"].with_transform(conf.preprocess_train)

# Creating the validation dataloader
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=conf.shuffle_val_data,
    collate_fn=collate_pixels_and_prompts,
    batch_size=1,
    num_workers=conf.dataloader_num_workers,
)

# we load our pipeline
pipeline = ClassConditionalDiffusionPipeline(
    vae=accelerator.unwrap_model(vae),
    unet=accelerator.unwrap_model(unet),
    scheduler=noise_scheduler,
    prompt_encoder=prompt_encoder,
)

imgfile, prompttxt = evaluate(conf, 0, pipeline, val_dataloader)


loading dataset
creating dataloader
loading VAE
loading u-net
Using custom prompt encoder
loading lr scheduler
loading DDPMScheduler
disabling training of vae and prompt_encoder
prepping accelerator
saving untrained models


In [None]:
!nvidia-smi

# Clearing Torch memory


In [None]:
!nvidia-smi

In [None]:
!pip install numba



clear python memory

In [None]:
# Run a garbage-collection pass, then ask torch to release its cached CUDA memory.
import gc

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Divides 1.0 by 0.0 with both operands created on the CUDA device and keeps the quotient in `result`.
# NOTE(review): the purpose is not evident from the notebook — presumably a quick
# check that the GPU is still usable; confirm or remove.
result = torch.tensor(1.0, device='cuda') / torch.tensor(0.0, device='cuda')

clear gpu memory

In [None]:
# Release torch's cached CUDA memory, reset the current CUDA device through
# numba, then run a garbage-collection pass.
# `gc` is imported here so the cell also works when it is run on its own.
# NOTE(review): `device.reset()` presumably tears down the CUDA context torch is
# using, so GPU work may require a kernel restart afterwards — confirm.
import gc

with torch.no_grad():
    torch.cuda.empty_cache()
from numba import cuda
device = cuda.get_current_device()
device.reset()
gc.collect()

In [None]:
!nvidia-smi

Fri Feb 23 00:31:55 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              49W / 400W |  28695MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
!v