# Stable Diffusion finetuning
Ноутбук с примером обучения StableDiffusionInpaint для использования полученных весов в генераторе синтетических аугментаций

In [None]:
import json
import math
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from IPython.display import clear_output
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from datasets import Dataset, load_from_disk
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionInpaintPipeline,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from PIL import Image, ImageDraw
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from SyntheticAugmentationGenerator import AugmentationGenerator

## Подготовка данных для обучения
Для обучения StableDiffusionInpaint необходимо подготовить датасет, каждый элемент которого хранит исходное изображение, маску и промпт (текстовую подсказку для обучения). Такой датасет собирается из датасета для детекции объектов в формате COCO.

В этом разделе вокруг каждого размеченного bbox'а вырезается фрагмент изображения размером 512x512 пикселей и создаётся маска с белым прямоугольником на месте bbox'а. Для каждого фрагмента нужно написать одну или несколько текстовых подсказок (их количество задаётся переменной `nn_prompts`), по которым будет обучаться модель.
Формат полученного датасета выглядит так:

    {
        'images': list<PIL image>,
        'masks': list<PIL image>,
        'texts': list<str>
    }

In [None]:
# Path to the JSON file with the COCO annotations.
COCO_DIR = "./data/example.json"
# Directory containing the images referenced by the annotations.
IMGS_DIR = "./data/example/"

In [None]:
# Cut a 512x512 attention area (and its mask) around every annotated bbox.
images = []
masks = []

with open(COCO_DIR, 'r', encoding='utf-8') as f:
    coco = json.load(f)

# COCO image ids are arbitrary (often 1-based and not contiguous), so resolve
# each annotation's image through its 'id' field instead of using the id as a
# list index.
file_name_by_id = {
    image_info['id']: image_info['file_name'].split('/')[-1]
    for image_info in coco['images']
}

for ann in coco['annotations']:
    img_name = file_name_by_id.get(ann['image_id'])
    if img_name is None:
        continue
    img_path = os.path.join(IMGS_DIR, img_name)
    if not os.path.exists(img_path):
        continue

    # copy() loads the pixel data, so the file handle can be closed right away.
    with Image.open(img_path) as src:
        img = src.copy()

    # COCO bbox is [x, y, width, height]; convert to [x1, y1, x2, y2] without
    # mutating the annotation stored in `coco`.
    x, y, box_w, box_h = ann['bbox']
    bbox = [x, y, x + box_w, y + box_h]

    att_area, mask, _, _ = AugmentationGenerator.generate_attention_area(img=img, bbox=bbox, aa_size=512)
    images.append(att_area)
    masks.append(mask)

len(images), len(masks)

In [None]:
# Column-oriented storage for the dataset: one entry per (image, mask, prompt).
dataset_dict = {column: [] for column in ('images', 'masks', 'texts')}

In [None]:
nn_prompts = 1  # number of prompts to type in for each image

# Show each crop and read `nn_prompts` prompts for it from stdin; the crop and
# its mask are stored once per prompt so the three columns stay aligned.
for idx, img in enumerate(images):
    display(img)
    paired_mask = masks[idx]
    for _ in range(nn_prompts):
        prompt = input()
        dataset_dict['texts'].append(prompt)
        dataset_dict['images'].append(img)
        dataset_dict['masks'].append(paired_mask)
    clear_output()

In [None]:
# Directory the prepared HuggingFace dataset is written to.
DATASET_DIR = "./dataset_example"

inpaint_dataset = Dataset.from_dict(mapping=dataset_dict)
inpaint_dataset.save_to_disk(dataset_path=DATASET_DIR)

## Подготовка InpaintDataset

In [None]:
class InpaintDataset(torch.utils.data.Dataset):
    """Torch dataset over a HuggingFace dataset saved with ``save_to_disk``.

    Every record holds an image, a mask image and a text prompt. Items are
    returned as dicts with keys ``PIL_images``, ``instance_images``,
    ``instance_prompt_ids`` and ``masks`` (consumed by ``collate_fn``).

    Args:
        dataset_dir: directory produced by ``Dataset.save_to_disk``.
        tokenizer: tokenizer used to encode prompts; it is called with
            ``padding``/``truncation``/``max_length`` and its ``.input_ids``
            are read.
        size: target size passed to ``transforms.Resize`` for the images.

    Raises:
        ValueError: if ``dataset_dir`` does not exist.
        KeyError: if the dataset has neither a ``text`` nor a ``texts`` column.
    """

    # Accepted names of the prompt column, in order of preference: the format
    # description uses 'text', while the dataset-building cells store 'texts'.
    PROMPT_COLUMNS = ('text', 'texts')

    def __init__(self, dataset_dir, tokenizer, size=512):
        self.size = size
        self.tokenizer = tokenizer

        self.dataset_dir = Path(dataset_dir)
        if not self.dataset_dir.exists():
            raise ValueError("Dataset doesn't exists.")

        self.dataset = load_from_disk(dataset_dir)
        self.images = self.dataset['images']
        self.prompts = self.dataset[self._prompt_column(self.dataset.column_names)]
        self.masks = self.dataset['masks']
        self.instance_images_path = list(self.dataset_dir.iterdir())
        self.num_instance_images = len(self.images)
        self._length = self.num_instance_images

        self.image_transforms_resize_and_crop = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            ]
        )
        # PIL image -> float tensor in [0, 1] -> normalized to [-1, 1].
        self.image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    @classmethod
    def _prompt_column(cls, column_names):
        """Return the name of the prompt column present in ``column_names``.

        Raises:
            KeyError: if none of ``PROMPT_COLUMNS`` is present.
        """
        for name in cls.PROMPT_COLUMNS:
            if name in column_names:
                return name
        raise KeyError(
            f"No prompt column found; expected one of {cls.PROMPT_COLUMNS}, got {list(column_names)}"
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        """Return the example at ``index`` (wrapped modulo the dataset size)."""
        idx = index % self.num_instance_images
        example = {}

        instance_image = self.images[idx]
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        instance_image = self.image_transforms_resize_and_crop(instance_image)

        example["PIL_images"] = instance_image
        example["instance_images"] = self.image_transforms(instance_image)

        example["instance_prompt_ids"] = self.tokenizer(
            self.prompts[idx],
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        # The mask is later multiplied pixel-wise with the image, so it must
        # have the same size as the (possibly resized) image. Nearest-neighbour
        # sampling avoids introducing intermediate grey values at the edges.
        mask = self.masks[idx]
        if mask.size != instance_image.size:
            mask = mask.resize(instance_image.size, resample=Image.NEAREST)
        example['masks'] = mask
        return example

In [None]:
def prepare_mask_and_masked_image(image, mask):
    """Convert an image/mask pair into tensors for inpainting training.

    Args:
        image: object with ``convert("RGB")`` returning an HxWx3 array-like.
        mask: object with ``convert("L")`` returning an HxW array-like.

    Returns:
        (mask, masked_image): ``mask`` is a float32 tensor of shape
        (1, 1, H, W) holding 1 where the grey level is >= 0.5 (of 255) and 0
        elsewhere; ``masked_image`` is the image as a (1, 3, H, W) tensor
        scaled to [-1, 1] with the masked pixels zeroed.
    """
    rgb = np.array(image.convert("RGB"))
    # Add a batch axis, go HWC -> CHW, then rescale [0, 255] -> [-1, 1].
    pixels = torch.from_numpy(rgb[np.newaxis].transpose(0, 3, 1, 2)).to(dtype=torch.float32) / 127.5 - 1.0

    gray = np.array(mask.convert("L")).astype(np.float32) / 255.0
    binary = np.where(gray >= 0.5, 1.0, 0.0).astype(np.float32)[np.newaxis, np.newaxis]
    mask_tensor = torch.from_numpy(binary)

    # Keep only the pixels outside the mask.
    return mask_tensor, pixels * (mask_tensor < 0.5)

In [None]:
def collate_fn(examples):
    """Batch ``InpaintDataset`` items for the training loop.

    Uses the module-level ``tokenizer`` to pad the prompt ids and
    ``prepare_mask_and_masked_image`` to build the mask/masked-image tensors.

    Returns:
        dict with ``input_ids``, ``pixel_values``, ``masks`` and
        ``masked_images`` tensors, each stacked along a new leading batch axis.
    """
    prepared = [
        prepare_mask_and_masked_image(example["PIL_images"], example["masks"])
        for example in examples
    ]

    pixel_values = torch.stack([example["instance_images"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    padded = tokenizer.pad(
        {"input_ids": [example["instance_prompt_ids"] for example in examples]},
        padding=True,
        return_tensors="pt",
    )

    return {
        "input_ids": padded.input_ids,
        "pixel_values": pixel_values,
        "masks": torch.stack([mask for mask, _ in prepared]),
        "masked_images": torch.stack([masked for _, masked in prepared]),
    }

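## Подготовка модели
Обратите внимание на порядок запуска ячеек: ячейки, вызывающие `get_scheduler`, используют `max_train_steps` и `accelerator`, а `accelerator.prepare` использует `train_dataloader`. Значения `max_train_steps` и `train_dataloader` задаются в разделе «Обучение» ниже, поэтому соответствующие ячейки из того раздела нужно выполнить раньше.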

In [None]:
# Hub id (or local path) of the inpainting checkpoint to fine-tune.
pretrained_model_name_or_path = 'stabilityai/stable-diffusion-2-inpainting'

# Each component lives in its own subfolder of the checkpoint; load them in order.
_component_specs = (
    (CLIPTokenizer, "tokenizer"),
    (CLIPTextModel, "text_encoder"),
    (AutoencoderKL, "vae"),
    (UNet2DConditionModel, "unet"),
    (DDPMScheduler, "scheduler"),
)
tokenizer, text_encoder, vae, unet, noise_scheduler = [
    component_cls.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder)
    for component_cls, subfolder in _component_specs
]

# Freeze the VAE and the text encoder.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

In [None]:
# AdamW over the UNet parameters only (the VAE and text encoder are frozen).
# The learning-rate scheduler is not created here: it needs `accelerator` and
# `max_train_steps`, which do not exist yet at this point. It is built together
# with the Accelerator in the next cell.
optimizer_class = torch.optim.AdamW
params_to_optimize = unet.parameters()
optimizer = optimizer_class(
    params_to_optimize,
    lr=5e-6,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08,
)

In [None]:
# Output directory: checkpoints, tensorboard logs and the final pipeline go here.
output_dir = './sd_inpaint_finetune'
os.makedirs(output_dir, exist_ok=True)

logging_dir = Path(output_dir) / 'logs'
project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=logging_dir)

# No gradient accumulation, full fp32 precision, tensorboard tracking.
accelerator = Accelerator(
    project_config=project_config,
    log_with="tensorboard",
    mixed_precision="no",
    gradient_accumulation_steps=1,
)

# Constant learning rate without warmup.
# NOTE(review): `max_train_steps` is assigned in the training-parameters cell of
# the "Обучение" section further down; that cell must run before this one.
lr_scheduler = get_scheduler(
    "constant",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps * accelerator.num_processes,
)

In [None]:
# Wrap the trainable objects for the current device / distributed setup.
# NOTE(review): `train_dataloader` is created in the "Обучение" section further
# down; run that cell before this one.
prepared = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
unet, optimizer, train_dataloader, lr_scheduler = prepared
accelerator.register_for_checkpointing(lr_scheduler)

# The frozen models are not prepared; move them to the device manually.
weight_dtype = torch.float32
for frozen_model in (vae, text_encoder):
    frozen_model.to(accelerator.device, dtype=weight_dtype)

accelerator.init_trackers("dreambooth")

## Обучение

In [None]:
# Name of the dataset saved in the first section.
data_dir = 'dataset_example'

# Training hyperparameters.
train_batch_size, max_train_steps, resolution = 1, 400, 512

In [None]:
# Training dataset read from the directory saved in the first section.
# `InpaintDataset` takes the directory as `dataset_dir`; the image size follows
# `resolution`, so it matches the latent mask size (resolution // 8) used in
# the training loop.
train_dataset = InpaintDataset(
    dataset_dir=data_dir,
    tokenizer=tokenizer,
    size=resolution,
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn
)

In [None]:
grad_accum_steps = accelerator.gradient_accumulation_steps

# Optimizer updates per epoch, and the number of epochs needed to reach max_train_steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum_steps)
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
# Effective batch size across all processes and accumulation steps.
total_batch_size = train_batch_size * accelerator.num_processes * grad_accum_steps

# Fresh run: start counting from zero.
global_step = first_epoch = 0

progress_bar = tqdm(range(global_step, max_train_steps))
progress_bar.set_description("Steps")

In [None]:
# Fine-tuning loop: only `unet` is updated by the optimizer.
# Save the full training state every `checkpointing_steps` optimizer steps.
checkpointing_steps = 500

for epoch in range(first_epoch, num_train_epochs):
    unet.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            # Encode the target images into the scaled latent space.
            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Encode the masked images the same way.
            masked_latents = vae.encode(
                batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
            ).latent_dist.sample()
            masked_latents = masked_latents * vae.config.scaling_factor

            # Downsample each mask to latent resolution and flatten to (B, 1, h, w).
            # A local name is used so the global `masks` list is not overwritten.
            latent_size = resolution // 8
            mask = torch.stack(
                [F.interpolate(batch_mask, size=(latent_size, latent_size)) for batch_mask in batch["masks"]]
            ).reshape(-1, 1, latent_size, latent_size)

            # Forward diffusion: add noise at a random timestep per sample.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # UNet input: noisy latents, mask and masked-image latents stacked on the channel axis.
            latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]
            noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

            # The regression target depends on the scheduler's parameterization.
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            if global_step % checkpointing_steps == 0 and accelerator.is_main_process:
                save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)

        if global_step >= max_train_steps:
            break

    accelerator.wait_for_everyone()
    # Also stop the epoch loop once the step budget is spent.
    if global_step >= max_train_steps:
        break

# Export as an inpainting pipeline: the trained UNet consumes mask and
# masked-image latents (see latent_model_input), so a text-to-image
# StableDiffusionPipeline is the wrong container for these weights.
if accelerator.is_main_process:
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        pretrained_model_name_or_path,
        unet=accelerator.unwrap_model(unet),
        text_encoder=accelerator.unwrap_model(text_encoder),
    )
    pipeline.save_pretrained(output_dir)

accelerator.end_training()