#Imports

In [None]:
!pip install diffusers accelerate transformers bitsandbytes einops xformers==0.0.25 nvidia-cutlass huggingface_hub torchmetrics torch-fidelity
!pip install carvekit --no-deps

In [None]:
import argparse
import copy
import gc
import hashlib
import importlib
import itertools
import logging
import math
import os
import shutil
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
from einops import rearrange
import random

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import DataLoader
import torchvision.models
import torchvision.models.segmentation as segmentation
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, model_info, upload_folder
from packaging import version
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from transformers import CLIPVisionModel, CLIPTextModel, CLIPProcessor
from transformers import ViTImageProcessor, ViTModel
import bitsandbytes as bnb
import xformers


import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
#from diffusers.utils.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.embeddings import get_timestep_embedding

from carvekit.api.high import HiInterface

from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image import StructuralSimilarityIndexMeasure


# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mount Google Drive so the /content/drive/... paths configured below resolve.
from google.colab import drive
drive.mount('/content/drive')

#Credits and Acknowledgements

The main structure of this training loop is inspired by the official Hugging Face implementation of DreamBooth: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py

#Params (Please fill these params)

In [None]:
# User-facing configuration (Colab form fields). Paths under /content/drive
# require Google Drive to be mounted first.
MODEL_PATH = "runwayml/stable-diffusion-v1-5"  # @param {"type":"string"}
INSTANCE_DIR = "/content/drive/MyDrive/Hypnos/RatanChair1" # @param {"type":"string"}
PRESERVE_DIR = "/content/drive/MyDrive/Hypnos/preserve_images_chair" # @param {"type":"string"}
# Was "/content/drive/Hypnos/hypnos_ld.pt": every other Drive path in this
# form goes through "MyDrive", and the discriminator checkpoint is written to
# this location, so it is aligned with the rest.
LATENT_DISC_DIR = "/content/drive/MyDrive/Hypnos/hypnos_ld.pt" # @param {"type":"string"}
OUTPUT_DIR = "/content/Hypnos-Output" # @param {"type":"string"}
LOGGING_DIR = "/content/Hypnos-logging" # @param {"type":"string"}
# Must be >= 1: this value is handed to Accelerator(gradient_accumulation_steps=...),
# where 0 leads to a modulo-by-zero as soon as accumulate() is entered.
# Keep it equal to GRADIENT_ACCUM below, which the step bookkeeping uses.
GRADIENT_ACCUMULATION_STEPS = 1 # @param {"type":"integer"}
NUM_PRESERVE_IMAGES = 200 # @param {"type":"integer"}
CLASS_NAME = "chair" # @param {"type":"string"}

LR = 2e-6 # @param {"type":"number"}
BATCH_SIZE = 1 # @param {"type":"integer"}
WARMUP_STEPS = 0 # @param {"type":"integer"}
TRAIN_STEPS = 800 # @param {"type":"integer"}
GRADIENT_ACCUM = 1 # @param {"type":"integer"}

# Fraction of instance samples whose background is replaced by a flat colour.
CHANGE_BG_RATIO = 0.66 # @param {"type":"slider","min":0,"max":1,"step":0.01}
RESIZED_RATIO = 0.15 # @param {"type":"slider","min":0,"max":1,"step":0.01}
# Scale by CHANGE_BG_RATIO so RESIZED_RATIO becomes a fraction of all samples.
RESIZED_RATIO *= CHANGE_BG_RATIO
# Standard deviation used by the exponential instance loss.
STD_DEV = 0.8 # @param {"type":"slider","min":0,"max":5,"step":0.05}

generate_preservation = True # @param {"type":"boolean","placeholder":"Generate Preservation Images"}
train_latent_disc = True # @param {"type":"boolean","placeholder":"Train Latent Discriminator"}

In [None]:
#Other params
# Derived / fixed settings that are not exposed as Colab form fields.
torch_dtype = torch.float32
# Model revision passed to the from_pretrained() calls (None = default).
REVISION = None
# "sks" is the identifier token attached to the instance being learned.
INSTANCE_PROMPT = f"a photo of a sks {CLASS_NAME}"
# Class-only prompt used for the preservation images.
PRESERVE_PROMPT = f"a photo of a {CLASS_NAME}"

# Weights of the classifier stages used by the perceptual loss; the keys are
# the sub-module names looked up in the classifier's `_modules` mapping.
LAYER_WEIGHT = {
    '2': 0.35,
    '3': 0.4,
    '4': 0.25
} #perceptual loss composition

#Utils

In [None]:
class PromptDataset(Dataset):
  """Dataset that serves one fixed prompt `num_samples` times.

  Each item is a dict holding the prompt under "prompt" and the requested
  position under "index".
  """

  def __init__(self, prompt, num_samples):
    self.num_samples = num_samples
    self.prompt = prompt

  def __len__(self):
    return self.num_samples

  def __getitem__(self, index):
    return {"prompt": self.prompt, "index": index}

In [None]:
# LOGGING_DIR is an absolute path, so Path(OUTPUT_DIR, LOGGING_DIR) evaluates
# to LOGGING_DIR itself rather than a sub-directory of OUTPUT_DIR.
logging_dir = Path(OUTPUT_DIR, LOGGING_DIR)
accelerator_project_config = ProjectConfiguration(project_dir=OUTPUT_DIR, logging_dir=logging_dir)

# Single Accelerator shared by image generation, training and model saving.
# NOTE(review): gradient_accumulation_steps comes from GRADIENT_ACCUMULATION_STEPS
# while the step bookkeeping later uses GRADIENT_ACCUM — keep the two in sync.
accelerator = Accelerator(
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    project_config=accelerator_project_config,
    mixed_precision="fp16"
)

In [None]:
def collate_fn(examples):
    """Collate dataset items into one batch, instance samples first.

    Instance and preservation samples are concatenated along the batch axis
    (all instance entries, then all preservation entries) so a single forward
    pass covers both halves.

    Returns a dict with "input_ids", "pixel_values" (float, contiguous),
    "no_bg" (list of the per-example "no_background" flags) and, when the
    first example carries one, "attention_mask".
    """
    with_mask = "instance_attention_mask" in examples[0]

    def gather(instance_key, preserve_key):
        # Instance entries first, preservation entries second.
        return [ex[instance_key] for ex in examples] + [ex[preserve_key] for ex in examples]

    no_bg = [ex["no_background"] for ex in examples]
    token_ids = torch.cat(gather("instance_prompt_ids", "preserve_prompt_ids"), dim=0)
    images = torch.stack(gather("instance_image", "preserve_image"))
    images = images.to(memory_format=torch.contiguous_format).float()

    batch = {
        "input_ids": token_ids,
        "pixel_values": images,
        "no_bg": no_bg,
    }
    if with_mask:
        batch["attention_mask"] = torch.cat(
            gather("instance_attention_mask", "preserve_attention_mask"), dim=0
        )
    return batch

In [None]:
def tokenize_prompt(prompt, tokenizer, max_length=None):
  """Tokenize `prompt`, padded and truncated to a fixed length.

  When `max_length` is None the tokenizer's `model_max_length` is used.
  Returns whatever the tokenizer call returns (requested as "pt" tensors).
  """
  if max_length is None:
    max_length = tokenizer.model_max_length
  tokenizer_kwargs = dict(
      truncation=True,
      padding="max_length",
      max_length=max_length,
      return_tensors="pt",
  )
  return tokenizer(prompt, **tokenizer_kwargs)

def encode_prompt(text_encoder, input_ids, attention_mask):
  """Run the text encoder and return the first element of its output.

  Args:
    text_encoder: callable exposing a `.device` attribute.
    input_ids: token id tensor.
    attention_mask: attention mask tensor, or None to pass no mask.

  Returns:
    `text_encoder(input_ids, attention_mask)[0]`.
  """
  # Tensor.to() returns a new tensor instead of moving the tensor in place,
  # so the result has to be re-bound (the previous code discarded it).
  target_device = text_encoder.device
  input_ids = input_ids.to(target_device)
  if attention_mask is not None:
    attention_mask = attention_mask.to(target_device)
  return text_encoder(input_ids, attention_mask)[0]

def get_z0(z, pred_noise, t, scheduler, device):
  """Estimate the clean latent from a noisy latent and the predicted noise.

  Computes (z - sqrt(1 - a_t) * pred_noise) / sqrt(a_t), where a_t is
  `scheduler.alphas_cumprod` looked up for each timestep in `t` and
  broadcast over the last three axes.
  """
  per_sample = [scheduler.alphas_cumprod[step] for step in t]
  a_bar = torch.Tensor(per_sample).view(-1, 1, 1, 1).to(device)
  noise_part = (1 - a_bar).sqrt() * pred_noise
  return (z - noise_part) / a_bar.sqrt()



# Constants of the exponential instance loss used in the training loop.
std_dev = STD_DEV
# NOTE(review): this evaluates sqrt(2 * pi * std_dev). If the intent is the
# Gaussian normaliser std_dev * sqrt(2 * pi), std_dev should be squared under
# the root — confirm before changing, since it rescales the instance loss.
sqrt_2_pi = (torch.tensor(math.pi) * 2 * std_dev).sqrt().to(device)

In [None]:
def get_features(image, model, layers):
    """Run `image` through the model's direct sub-modules and capture outputs.

    Sub-modules are applied in the order of `model._modules`. The output of
    every sub-module whose name is in `layers` is stored under that name, and
    iteration stops right after the sub-module named `layers[-1]`.
    """
    captured = {}
    activation = image
    for module_name, module in model._modules.items():
        activation = module(activation)
        if module_name in layers:
            captured[module_name] = activation
        if module_name == layers[-1]:
            # Nothing after the last requested layer is needed.
            break
    return captured

def tracerb7_wrapper(interface):
  """Adapt a PIL-based background-removal interface to batched tensors.

  Args:
    interface: callable taking a list of PIL images and returning a list of
      images; each returned image is read here as 4 channels per pixel.

  Returns:
    A function mapping an iterable of image tensors to a (batch, 4, h, w)
    tensor holding the raw pixel values the interface produced, or None when
    the iterable is empty.
  """
  to_pil = transforms.ToPILImage(mode="RGB")

  def inner(images):
    results = []
    for im in images:
      with torch.no_grad():
        im = to_pil(im)
      segmented = interface([im])[0]
      # PIL reports size as (width, height) while getdata() is row-major,
      # so the flat pixel list has to be reshaped as (height, width, 4).
      width, height = segmented.size
      res = torch.Tensor(segmented.getdata()).reshape(height, width, 4)
      results.append(rearrange(res, "h w c -> c h w").unsqueeze(0))
    if not results:
      return None
    # One concatenation at the end instead of re-copying the batch per image.
    return torch.cat(results, 0)
  return inner


def get_object(model):
  """Build a function that isolates the foreground and recolours the background.

  Args:
    model: callable that takes a batch of images scaled to [0, 1] and returns
      a tensor whose channels from index 3 onward are read as the foreground
      mask (zero = background).

  Returns:
    `inner(images, color="black", down=False, negative=False, maximize=False)`,
    which expects images normalised with mean 0.5 / std 0.5 and returns images
    in that same normalisation.
  """
  # The same resize serves the segmentation input and the final composite
  # (previously declared twice under two names).
  resize = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
  ])
  half_resize = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
  ])
  colors = {
      "black":  torch.tensor([0, 0, 0]).to(device),
      "white":  torch.tensor([1, 1, 1]).to(device),
      "red":  torch.tensor([1, 0, 0]).to(device),
      "green":  torch.tensor([0, 1, 0]).to(device),
      "blue":  torch.tensor([0, 0, 1]).to(device),
      "cyan":  torch.tensor([0, 1, 1]).to(device),
      "magenta":  torch.tensor([1, 0, 1]).to(device),
      "yellow":  torch.tensor([1, 1, 0]).to(device),
      "purple":  torch.tensor([.5, 0, 1]).to(device),
      "pink":  torch.tensor([1, 0, .5]).to(device),
      "brown":  torch.tensor([.25, .1, .05]).to(device),
      "gray":  torch.tensor([.5, .5, .5]).to(device),
  }
  def inner(images, color="black", down=False, negative=False, maximize=False):
    """Replace the background of `images` with `color`.

    Args:
      images: batch of images normalised with mean 0.5 / std 0.5.
      color: a key of the colour table above, or an RGB tensor.
      down: when truthy, shrink image and mask to 256 px and paste them into
        rows/columns 127..382 of an otherwise empty canvas of the input
        shape (so the input must be large enough to hold that window).
      negative: when truthy, invert the de-normalised image before masking.
      maximize: accepted but not used.
    """
    # isinstance also accepts str subclasses, unlike the exact type check.
    if isinstance(color, str):
      color = colors[color]
    in_images = resize(0.5*images + 0.5).to(device)
    # Channels from index 3 on carry the mask returned by the segmenter.
    output = model(in_images)[:, 3:, :, :]
    mask = torch.where(output == 0, torch.tensor(0.0), torch.tensor(1.0)).to(device)

    if down:
      empty = (torch.zeros(images.shape) - 0.5).to(device)
      empty[:, :, 127:383, 127:383] = half_resize(images)
      images = empty

      empty = torch.zeros(mask.shape).to(device)
      empty[:, :, 127:383, 127:383] = half_resize(mask)
      mask = empty

    # Background pixels take the requested colour; foreground pixels add 0.
    colored_bg = torch.where(rearrange(mask, "b c h w -> b h w c") == 0, color, colors["black"])
    colored_bg = rearrange(colored_bg, "b h w c -> b c h w").to(device)
    unorm_image = (0.5*resize(images) + 0.5).to(device)
    if negative:
      unorm_image = 1 - unorm_image
    return transforms.Normalize([0.5], [0.5])((unorm_image * mask) + colored_bg)
  return inner

def foreground_perceptual_loss(classifier, segmentator, layer_weight):
  """Create the perceptual loss together with its background-replacement helper.

  Args:
    classifier: feature extractor walked by `get_features`.
    segmentator: segmentation callable handed to `get_object`.
    layer_weight: mapping of classifier sub-module name to loss weight.

  Returns:
    `(loss_fn, get_ob)`, where `loss_fn(real_images, fake_images)` is the
    weighted sum of per-layer mean squared feature differences.
  """
  get_ob = get_object(segmentator)
  imagenet_norm = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])

  def preprocess(image):
    # Map from [-1, 1] to [0, 1] and then apply the classifier normalisation.
    return imagenet_norm(image/2 + 0.5)

  def inner(real_images, fake_images):
    """Calculate the perceptual loss between real and fake images."""
    layer_names = list(layer_weight.keys())

    # Random grey level for the background-replacement step, which is
    # currently disabled; the draw is kept so random-number consumption
    # stays exactly as before.
    back_color = (torch.rand(size=(1,)) / 1.25).tile(3).to(device)

    feats_real = get_features(preprocess(real_images), classifier, layer_names)
    feats_fake = get_features(preprocess(fake_images), classifier, layer_names)

    per_layer = (
      layer_weight[name] * torch.mean((feats_real[name] - feats_fake[name]) ** 2)
      for name in layer_names
    )
    return sum(per_layer, 0.0)
  return inner, get_ob

In [None]:
#make dataset
class DreamDataset(Dataset):
  """Paired instance / preservation dataset for DreamBooth-style training.

  Every item holds one instance image with its tokenized prompt and, when
  preservation images exist, one preservation image with its tokenized
  prompt. With probability `change_bg_ratio` the instance background is
  replaced by a random flat colour (via the module-level `get_ob`) and the
  prompt gets ", <color> background" appended; otherwise ", vfx background"
  is appended.

  Args:
    instance_path: directory holding the instance images.
    instance_prompt: base prompt for the instance images.
    preserve_path: directory holding the preservation images (created if
      missing).
    preserve_prompt: prompt for the preservation images.
    tokenizer: tokenizer forwarded to `tokenize_prompt`.
    size: side length of the resize / centre crop.
    change_bg_ratio: probability of replacing the background.
    resized_ratio: multiplied by `change_bg_ratio` to obtain the probability
      threshold of the shrink flag handed to `get_ob`.

  Raises:
    ValueError: if `instance_path` contains no entries.
  """

  def __init__(self, instance_path, instance_prompt, preserve_path, preserve_prompt, tokenizer, size=512, change_bg_ratio=0.75, resized_ratio=8/15):
    self.preserve_path = Path(preserve_path)
    self.preserve_path.mkdir(parents=True, exist_ok=True)
    self.preserve_images_path = list(self.preserve_path.iterdir())
    self.num_preserve_images = len(self.preserve_images_path)
    self.preserve_prompt = preserve_prompt

    self.size = size
    self.tokenizer = tokenizer
    self.instance_path = Path(instance_path)
    self.instance_images_path = list(self.instance_path.iterdir())
    # Kept as an alias for code that still reads `images_path`.
    self.images_path = list(self.instance_images_path)
    self.instance_prompt = instance_prompt
    self.num_instance_images = len(self.instance_images_path)
    if self.num_instance_images == 0:
      # Fail here with a clear message instead of a ZeroDivisionError later.
      raise ValueError(f"No instance images found in {self.instance_path}")
    self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    self.colors = [
        "black",
        "white",
        "red",
        "green",
        "blue",
        "cyan",
        "magenta",
        "yellow",
        "purple",
        "pink",
        "brown"
    ]
    self.normal_ratio = 1.0 - change_bg_ratio
    self.resize_thresh = 1.0 - (change_bg_ratio * resized_ratio)
    self.no_background = False

  def __len__(self):
    return max(self.num_instance_images, self.num_preserve_images)

  def _build_instance_prompt(self, color):
    """Return the per-sample prompt without touching `self.instance_prompt`."""
    if color is None:
      return self.instance_prompt + ", vfx background"
    return self.instance_prompt + f", {color} background"

  def _load_image(self, path):
    """Open `path`, apply EXIF orientation, force RGB and run the transforms."""
    image = exif_transpose(Image.open(path))
    if image.mode != "RGB":
      image = image.convert("RGB")
    return self.image_transforms(image)

  def __getitem__(self, index):
    items = {}
    items["instance_image"] = self._load_image(
        self.instance_images_path[index % self.num_instance_images]
    )

    no_background = torch.rand((1,))[0] > self.normal_ratio
    # Still mirrored on the instance for code that reads the attribute.
    self.no_background = no_background
    items["no_background"] = no_background

    color = None
    if no_background:
      color = random.choice(self.colors)
      shrink = torch.rand((1,))[0] > self.resize_thresh
      items["instance_image"] = get_ob(items["instance_image"].unsqueeze(0), color, shrink)[0].cpu()

    # The suffix is built on a local string. Appending to
    # `self.instance_prompt` made the prompt grow with every sample until
    # only the truncated head of an ever longer string was tokenized.
    instance_token = tokenize_prompt(self._build_instance_prompt(color), self.tokenizer)
    items['instance_prompt_ids'] = instance_token.input_ids
    items['instance_attention_mask'] = instance_token.attention_mask

    # A Path object is always truthy, so test for actual images instead.
    if self.num_preserve_images:
      items["preserve_image"] = self._load_image(
          self.preserve_images_path[index % self.num_preserve_images]
      )

      preserve_token = tokenize_prompt(self.preserve_prompt, self.tokenizer)
      items['preserve_prompt_ids'] = preserve_token.input_ids
      items['preserve_attention_mask'] = preserve_token.attention_mask

    return items


In [None]:
class LDreamDataset(Dataset):
  """Image / label pairs for pre-training the latent discriminator.

  Odd indices yield an instance image labelled 1, optionally with a recoloured
  background. Even indices yield a sample labelled 0: a preservation image, a
  colour-inverted instance, or only the background of an instance image.
  Background handling goes through the module-level `get_ob`.
  """

  def __init__(self,instance_path, preserve_path, size=512):
    self.instance_path = Path(instance_path)
    self.instance_images_path = list(self.instance_path.iterdir())
    self.num_instance_images = len(self.instance_images_path)

    self.preserve_path = Path(preserve_path)
    self.preserve_images_path = list(self.preserve_path.iterdir())
    self.num_preserve_images = len(self.preserve_images_path)

    self.image_transforms = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    self.colors = [
        "black", "white", "red", "green", "blue", "cyan",
        "magenta", "yellow", "purple", "pink", "brown",
    ]

  def __len__(self):
    return max(self.num_instance_images, self.num_preserve_images)*2

  def _load_image(self, paths, index):
    """Load `paths[index % len(paths)]` as a transformed RGB tensor."""
    image = exif_transpose(Image.open(paths[index % len(paths)]))
    if not image.mode =="RGB":
      image = image.convert("RGB")
    return self.image_transforms(image)

  @staticmethod
  def _background_only(image):
    """De-normalised image minus its black-background foreground cut-out."""
    foreground = 0.5*get_ob(image.unsqueeze(0), "black", False)[0].cpu() + 0.5
    return (0.5*image +0.5) - foreground

  def __getitem__(self, index):
    mode = random.uniform(0, 1)
    renormalise = transforms.Normalize([0.5], [0.5])

    if index%2:
      # Positive sample: the instance itself, sometimes on a flat background.
      positive = self._load_image(self.instance_images_path, index)
      if mode > 0.6:
        positive = get_ob(positive.unsqueeze(0), random.choice(self.colors), mode > 0.9)[0].cpu()
      return positive, torch.Tensor([1])

    if mode < 0.35:
      # Negative sample taken straight from the preservation images.
      negative = self._load_image(self.preserve_images_path, index)
      return negative, torch.Tensor([0])

    # Negative sample derived from an instance image.
    is_negative = random.uniform(0, 1) < 0.25
    instance_image = self._load_image(self.instance_images_path, index)
    if mode > 0.6:
      instance_image = get_ob(instance_image.unsqueeze(0), random.choice(self.colors), False, is_negative)[0].cpu()
      if is_negative:
        negative = instance_image
      else:
        negative = renormalise(self._background_only(instance_image))
    elif is_negative:
      background = self._background_only(instance_image)
      inverted = (0.5*get_ob(instance_image.unsqueeze(0), "black", False, True)[0].cpu() + 0.5)
      negative = renormalise(inverted + background)
    else:
      negative = renormalise(self._background_only(instance_image))
    return negative, torch.Tensor([0])



In [None]:
class TransformerBlock(nn.Module):
  """Pre-norm self-attention with a residual, followed by a projecting MLP.

  `forward` returns `(features, attention_weights)`; the weights are per head
  because averaging is switched off.
  """

  def __init__(self, in_channel, out_channel):
    super().__init__()
    self.norm = nn.LayerNorm(in_channel)
    self.mha = nn.MultiheadAttention(in_channel, 4, batch_first=True)
    self.mlp = nn.Sequential(
        nn.LayerNorm(in_channel),
        nn.Linear(in_channel, out_channel),
        nn.SiLU(),
    )

  def forward(self, x):
    normed = self.norm(x)
    attended, attn_weights = self.mha(normed, normed, normed, average_attn_weights=False)
    residual = x + attended
    return self.mlp(residual), attn_weights

class ExtractPatch(nn.Module):
  """Split a feature map into patch tokens and prepend a learned class token.

  Args:
    dim: spatial side length covered by the patch grid.
    patch_dim: side length of one square patch.
    in_channel: channel count of the incoming feature map.

  `forward` returns a (batch, 1 + (dim // patch_dim) ** 2, embed) tensor with
  embed = patch_dim * patch_dim * in_channel. Token 0 is the class token; the
  remaining tokens follow the patches in row-major order.
  """

  def __init__(self, dim, patch_dim, in_channel):
    super().__init__()
    self.dim = dim
    self.patch_dim = patch_dim
    # Token width follows the patch geometry. It was hard-coded to 128, which
    # equals 4 * 4 * 8, the only configuration used so far.
    embed_dim = patch_dim*patch_dim*in_channel
    self.lin = nn.Linear(embed_dim, embed_dim)
    self.cls = nn.Parameter(torch.randn(1, 1, embed_dim))
    self.out = None

  def forward(self, x):
    p = self.patch_dim
    grid = self.dim // p
    embed_dim = self.lin.in_features

    # Only the top-left (grid * p) square is tokenized, as before.
    x = x[:, :, :grid*p, :grid*p]
    # All patches in one step: token index is row * grid + col and features
    # are ordered (channel, row, col), matching a per-patch flatten.
    patches = rearrange(x, "b c (i h) (j w) -> b (i j) (c h w)", h=p, w=p)
    tokens = self.lin(patches)

    # Sinusoidal encoding of the running patch index, moved to the tokens'
    # device instead of relying on a global device variable.
    positions = torch.arange(grid*grid, dtype=torch.float32)
    pos_encode = get_timestep_embedding(positions, embed_dim).to(tokens.device)
    tokens = tokens + pos_encode.unsqueeze(0)

    cls = torch.tile(self.cls, (tokens.shape[0], 1, 1))
    # `out` stays exposed as an attribute, as in earlier versions.
    self.out = torch.cat([cls, tokens], axis=1)
    return self.out



class MergePatch(nn.Module):
  """Reassemble patch tokens into a feature map.

  `forward` takes (batch, 1 + n, c*h*w) tokens and returns
  `(class_token, feature_map)`. Patches are laid side by side along the width
  until `dim // patch_dim` of them form a strip; strips are then stacked
  along the height.
  """

  def __init__(self, dim, patch_dim, in_channel):
    super().__init__()
    self.dim = dim
    self.patch_dim = patch_dim
    self.num_patch_per_col = dim//patch_dim
    self.in_channel = in_channel
    self.column = None
    self.out = None

  def forward(self, x):
    cls_token, tokens = x[:, 0, :], x[:, 1:, :]
    strip, assembled = None, None
    for idx in range(tokens.shape[1]):
      patch = rearrange(
          tokens[:, idx, :], "b (c h w) -> b c h w",
          c=self.in_channel, h=self.patch_dim, w=self.patch_dim,
      )
      strip = patch if strip is None else torch.cat([strip, patch], axis=3)
      if (idx+1) % self.num_patch_per_col == 0:
        # Strip complete: append it below what has been assembled so far.
        assembled = strip if assembled is None else torch.cat([assembled, strip], axis=2)
        strip = None
    # Leave the instance attributes in the state the loop ends in.
    self.column, self.out = strip, assembled
    return cls_token, assembled


class LatentDiscOld(nn.Module):
  """Earlier discriminator variant that flattens all patch tokens.

  The LayerNorm(32) after the stride-2 convolution fixes the down-sampled
  width at 32, i.e. 64-wide inputs. ExtractPatch(32, 4, 8) then yields
  8 * 8 = 64 patch tokens plus one class token.
  """

  def __init__(self):
    super().__init__()
    self.downconv = nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=1)
    self.conv1 = nn.Sequential(
        nn.Conv2d(8, 8, kernel_size=3, padding=1),
        nn.LayerNorm(32),
        nn.SiLU()
    )
    self.extract = ExtractPatch(32, 4, 8)
    self.block1 = TransformerBlock(128, 64)
    self.block2 = TransformerBlock(64, 32)
    self.block3 = TransformerBlock(32, 16)
    self.head = nn.Sequential(
        nn.Linear(1024, 128),
        nn.SiLU(),
        nn.Linear(128, 1),
        nn.Sigmoid()
    )

  def forward(self, x):
    x = self.downconv(x)
    x = self.conv1(x)
    x = self.extract(x)
    x, _ = self.block1(x)
    x, _ = self.block2(x)
    x, _ = self.block3(x)
    # The head expects 1024 = 64 tokens * 16 features. Flattening all 65
    # tokens gave 1040 and a shape error, so the class token (index 0) is
    # left out of the flattened vector.
    # NOTE(review): widening the head to 1040 inputs is the alternative fix —
    # confirm which layout existing checkpoints use.
    x = rearrange(x[:, 1:, :], "b s e -> b (s e)")
    return self.head(x)

class LatentDisc(nn.Module):
  """Latent discriminator: conv features, patch tokens, three transformer
  blocks and an MLP on the class token that ends in a sigmoid.

  The LayerNorm(64) layers normalise the last axis, so inputs must be 64 wide.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Sequential(
        nn.Conv2d(4, 8, kernel_size=3, padding=1),
        nn.LayerNorm(64),
        nn.SiLU(),
        nn.Conv2d(8, 4, kernel_size=3, padding=1),
        nn.LayerNorm(64),
        nn.SiLU(),
    )
    # 4 raw latent channels + 4 conv channels = 8 channels per patch.
    self.extract = ExtractPatch(64, 4, 8)
    self.block1 = TransformerBlock(128, 128)
    self.block2 = TransformerBlock(128, 128)
    self.block3 = TransformerBlock(128, 128)
    self.head = nn.Sequential(
        nn.Linear(128, 32),
        nn.SiLU(),
        nn.Linear(32, 16),
        nn.SiLU(),
        nn.Linear(16, 1),
        nn.Sigmoid(),
    )

  def forward(self, x):
    conv_features = self.conv1(x)
    tokens = self.extract(torch.cat([x, conv_features], axis=1))
    for block in (self.block1, self.block2, self.block3):
      tokens, _ = block(tokens)
    # Only the class token (position 0) feeds the classification head.
    return self.head(tokens[:, 0, :])

class LatentDiscWithConv(nn.Module):
  """Discriminator variant with a stride-2 stem and a patch-merge stage.

  The LayerNorm(32) after the stem fixes the down-sampled width at 32, i.e.
  64-wide inputs. The prediction comes from the 128-dim class token, which is
  the only tensor that matches the head's `Linear(128, ...)` input.
  """

  def __init__(self):
    super().__init__()
    self.downconv = nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=1)
    self.conv1 = nn.Sequential(
        nn.Conv2d(8, 8, kernel_size=3, padding=1),
        nn.LayerNorm(32),
        nn.SiLU()
    )
    self.extract = ExtractPatch(32, 4, 8)
    self.block1 = TransformerBlock(128, 128)
    self.block2 = TransformerBlock(128, 128)
    self.block3 = TransformerBlock(128, 128)
    # No trailing comma: it turned this attribute into a one-element tuple,
    # which is neither a registered sub-module nor callable.
    self.merge = MergePatch(32, 4, 8)
    # Empty placeholder; kept so the module layout stays unchanged.
    self.conv2 = nn.Sequential(

    )
    self.head = nn.Sequential(
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Linear(64, 32),
        nn.SiLU(),
        nn.Linear(32, 1),
        nn.Sigmoid()
    )

  def forward(self, x):
    x = self.downconv(x)
    x = self.conv1(x)
    x = self.extract(x)
    # TransformerBlock returns (features, attention_weights); only the
    # features are passed on.
    for block in (self.block1, self.block2, self.block3):
      x, _ = block(x)
    # MergePatch returns (class_token, feature_map). The head takes 128-dim
    # vectors, so it is applied to the class token.
    cls_token, _ = self.merge(x)
    return self.head(cls_token)

class LatentBGRemover(nn.Module):
  """Encoder / transformer / decoder that maps a 4-channel latent to a
  4-channel latent of the same spatial size.

  The LayerNorm(32) layers fix the down-sampled width at 32, i.e. 64-wide
  inputs.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(4, 8, kernel_size=3, padding=1)
    self.downconv = nn.Sequential(
        nn.Conv2d(8, 8, kernel_size=4, stride=2, padding=1),
        nn.LayerNorm(32),
        nn.SiLU()
    )
    self.upconv = nn.Sequential(
        nn.LayerNorm(32),
        nn.ConvTranspose2d(16, 4, kernel_size=4, stride=2, padding=1, bias=False),
    )
    self.extract = ExtractPatch(32, 4, 8)
    self.block1 = TransformerBlock(128, 64)
    self.block2 = TransformerBlock(64, 64)
    self.block3 = TransformerBlock(64, 64)
    self.block4 = TransformerBlock(64, 128)
    self.merge = MergePatch(32, 4, 8)

  def forward(self, x):
    x0 = self.conv1(x)
    # Was `self.dowwnconv`, which raised AttributeError on the first call.
    x_down = self.downconv(x0)
    x1 = self.extract(x_down)
    # TransformerBlock returns (features, attention_weights).
    x2, _ = self.block1(x1)
    x3, _ = self.block2(x2)

    h, _ = self.block3(x3)
    # Token-level skip connection from the first block.
    h, _ = self.block4(h + x2)
    # MergePatch returns (class_token, feature_map); only the map is decoded.
    _, merged = self.merge(h)
    # The merged map lives at the down-sampled resolution (8 channels), so the
    # skip has to come from the down-sampled features: 8 + 8 = 16 channels,
    # which is what the transposed convolution expects. Concatenating the
    # full-resolution `x0` fails on the spatial size.
    return self.upconv(torch.cat([merged, x_down], axis=1))

#Make Preservation Image

In [None]:
# Generate class ("preservation") images until PRESERVE_DIR holds
# NUM_PRESERVE_IMAGES files. `return` is not allowed at cell / module level
# (it is a SyntaxError), so the switch is a plain if / else.
if not generate_preservation:
  print("Preservation Images Generation turned off")
else:
  preserve_dir = Path(PRESERVE_DIR)
  preserve_dir.mkdir(parents=True, exist_ok=True)
  existing_num_preserve = len(list(preserve_dir.iterdir()))
  num_generate = max(NUM_PRESERVE_IMAGES - existing_num_preserve, 0)

  # Load the pipeline only when there is something left to generate.
  if num_generate > 0:
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float16,
        safety_checker=None,
        revision=REVISION,
    )
    pipeline.set_progress_bar_config(disable=True)

    preserve_prompt_ds = PromptDataset(PRESERVE_PROMPT, num_generate)
    preserve_prompt_loader = DataLoader(preserve_prompt_ds, batch_size=4)
    preserve_prompt_loader = accelerator.prepare(preserve_prompt_loader)

    pipeline.to(accelerator.device)

    for batch in tqdm(preserve_prompt_loader, desc="Generating Images", disable=not accelerator.is_local_main_process):
      images = pipeline(batch["prompt"]).images

      for i, image in enumerate(images):
        hashed = hashlib.sha1(image.tobytes()).hexdigest()
        # The collated index is a tensor; convert it so the file name does
        # not contain the tensor's repr.
        image_index = int(batch["index"][i]) + existing_num_preserve
        image_filename = preserve_dir / f"{image_index}_{hashed}.jpg"
        image.save(image_filename)

    # Release the fp16 pipeline before the training models are loaded.
    del pipeline
    gc.collect()
    if torch.cuda.is_available():
      torch.cuda.empty_cache()

#Initialization

In [None]:
# Load the tokenizer, noise scheduler, text encoder, VAE and UNet from the
# sub-folders of the MODEL_PATH checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
            MODEL_PATH,
            subfolder="tokenizer",
            revision=REVISION,
            use_fast=False,
      )

# NOTE(review): unlike the other components, the scheduler is loaded without
# `revision=REVISION` — harmless while REVISION is None.
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_PATH, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(MODEL_PATH, subfolder="text_encoder", revision=REVISION)
vae = AutoencoderKL.from_pretrained(
            MODEL_PATH, subfolder="vae", revision=REVISION
      )
unet = UNet2DConditionModel.from_pretrained(
        MODEL_PATH, subfolder="unet", revision=REVISION
      )

In [None]:
# The UNet and the text encoder are fine-tuned jointly; the VAE is not part
# of the optimised parameters.
params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()))

# 8-bit AdamW from bitsandbytes.
optimizer = bnb.optim.AdamW8bit(
    params_to_optimize,
    lr = LR,
    betas=(.9, .999),
    weight_decay=1e-2,
    eps=1e-8
)


In [None]:
# Feature extractor for the perceptual loss: the `.features` part of an
# EfficientNet-B1, switched to eval mode.
# NOTE(review): newer torchvision releases prefer `weights=` over
# `pretrained=True` — confirm against the installed version.
enetb1 = torchvision.models.efficientnet_b1(pretrained=True).features.to(device)
enetb1.eval()

# Background-removal interface from carvekit, used to separate the object
# from its background.
# Check doc strings for more information
interface = HiInterface(object_type="object",  # Can be "object" or "hairs-like".
                        batch_size_seg=5,
                        batch_size_matting=1,
                        device='cuda' if torch.cuda.is_available() else 'cpu',
                        seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
                        matting_mask_size=2048,
                        trimap_prob_threshold=231,
                        trimap_dilation=30,
                        trimap_erosion_iters=5,
                        fp16=True)

In [None]:
# Wrap the PIL-based interface so it accepts and returns tensors.
tracerb7 = tracerb7_wrapper(interface)

# `foreground_perceptual` is the perceptual loss; `get_ob` is the
# background-replacement helper that the dataset classes call as a global.
foreground_perceptual, get_ob = foreground_perceptual_loss(
    classifier = enetb1,
    segmentator = tracerb7,
    layer_weight = LAYER_WEIGHT
)

In [None]:
# Build the training dataset / loader, the LR scheduler and hand everything
# trainable to the accelerator.
# NOTE(review): resized_ratio is hard-coded to 0 here, so the RESIZED_RATIO
# form field has no effect on training — confirm this is intended.
train_ds = DreamDataset(
    instance_path=INSTANCE_DIR,
    instance_prompt=INSTANCE_PROMPT,
    preserve_path=PRESERVE_DIR,
    preserve_prompt=PRESERVE_PROMPT,
    tokenizer=tokenizer,
    change_bg_ratio=CHANGE_BG_RATIO,
    resized_ratio=0
)

# collate_fn is passed directly; the lambda wrapper added nothing.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn
)

lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=TRAIN_STEPS,
        num_cycles=1,
    )

# Dtype matching the accelerator's mixed-precision mode.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

unet, text_encoder, optimizer, train_loader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_loader, lr_scheduler
        )

# The VAE is never optimised. Freezing it stops backward passes that run
# through encode / decode from building and accumulating gradients on its
# parameters; gradients still flow through it to its inputs.
vae.requires_grad_(False)
vae.to(accelerator.device, dtype=torch.float16)
text_encoder.to(accelerator.device, dtype=torch.float32)

# Optimiser updates per epoch and the number of epochs needed for TRAIN_STEPS.
num_updating = math.ceil(len(train_loader) / GRADIENT_ACCUM)
num_epoch = math.ceil(TRAIN_STEPS / num_updating)
step_counter = 0

#Latent Discriminator Training

In [None]:
# Pre-train the latent discriminator, or load a saved one. `return` is not
# allowed at cell / module level (it is a SyntaxError), so the switch is a
# plain if / else.
if not train_latent_disc:
  ld_mod = LatentDisc().to(device)
  # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
  ld_mod.load_state_dict(torch.load(LATENT_DISC_DIR, map_location=device))
else:
  LD_BATCH_SIZE = 4
  LD_EPOCHS = 6
  LD_LOW_LR_FROM_EPOCH = 4

  latent_ds = LDreamDataset(
      instance_path=INSTANCE_DIR,
      preserve_path=PRESERVE_DIR,
  )

  latent_loader = DataLoader(
      latent_ds,
      batch_size=LD_BATCH_SIZE,
      shuffle=True,
  )

  ld_mod = LatentDisc().to(device)

  optimizer_ld = torch.optim.AdamW(
      ld_mod.parameters(),
      lr = 1e-4,
      betas=(.9, .999),
      weight_decay=1e-2,
      eps=1e-8
  )

  # len(latent_loader) also counts a final partial batch, so the bar total
  # matches the number of steps actually taken.
  progress_bar_ld = tqdm(
      range(len(latent_loader) * LD_EPOCHS),
      initial=0,
      desc="Steps",
  )
  losses = []
  criterion_ld = nn.MSELoss()

  for epoch in range(LD_EPOCHS):
    if epoch >= LD_LOW_LR_FROM_EPOCH:
      # Lower learning rate for the last epochs.
      for g in optimizer_ld.param_groups:
        g['lr'] = 1e-5
    for step, (image, label) in enumerate(latent_loader):
      image = image.to(dtype=torch.float16).to(device)
      # The VAE only supplies inputs here; encoding without autograd avoids
      # building a graph through it and piling gradients on its parameters.
      with torch.no_grad():
        latent = vae.encode(image).latent_dist.sample()
        latent = latent * vae.config.scaling_factor

      pred = ld_mod(latent.to(dtype=torch.float32))
      loss = criterion_ld(pred, label.to(device))
      loss_value = float(loss)
      losses.append(loss_value)

      loss.backward()
      optimizer_ld.step()
      optimizer_ld.zero_grad()
      progress_bar_ld.set_description(f"loss : {loss_value}")
      progress_bar_ld.update(1)

  # Make sure the target folder exists before writing the checkpoint.
  Path(LATENT_DISC_DIR).parent.mkdir(parents=True, exist_ok=True)
  torch.save(ld_mod.state_dict(), LATENT_DISC_DIR)

#Training

In [None]:
# The discriminator keeps training alongside the diffusion model, with a
# fresh optimizer at learning rate 1e-5.
ld_mod.requires_grad_(True)
optimizer_ld = torch.optim.AdamW(
    ld_mod.parameters(),
    lr = 1e-5,
    betas=(.9, .999),
    weight_decay=1e-2,
    eps=1e-8
)

In [None]:
# Main training loop: each batch updates the UNet / text encoder on the
# combined loss, then updates the latent discriminator on real vs. predicted
# clean latents.
progress_bar = tqdm(
    range(TRAIN_STEPS),
    initial=0,
    desc="Steps",
    disable=not accelerator.is_local_main_process
)

losses = []

for epoch in range(1, num_epoch+1):
  # num_epoch is rounded up, so stop explicitly once TRAIN_STEPS is reached.
  if step_counter >= TRAIN_STEPS:
    break
  unet.train()
  text_encoder.train()

  for step, batch in enumerate(train_loader):
    with torch.no_grad():
      latent = vae.encode(batch["pixel_values"].to(dtype=torch.float16)).latent_dist.sample()
      latent *= vae.config.scaling_factor

      # sample noise and timestep
      noise = torch.randn_like(latent)

      b_size, channel, height, width = latent.shape
      t = torch.randint(
          0,
          noise_scheduler.config.num_train_timesteps,
          (b_size,),
          device=latent.device
      ).long()

      noisy_latent = noise_scheduler.add_noise(latent, noise, t).to(dtype=torch.float32)

    # collate_fn puts the instance samples in the first half of the batch.
    n_instance = b_size // 2

    with accelerator.accumulate(unet):

      text_embed = encode_prompt(
          text_encoder,
          batch["input_ids"],
          batch["attention_mask"],
      ).to(dtype=torch.float32)

      if accelerator.unwrap_model(unet).config.in_channels == channel * 2:
        noisy_latent = torch.cat([noisy_latent, noisy_latent], dim=1).to(dtype=torch.float32)

      #predict
      pred_unet = unet(
          noisy_latent, t, text_embed, class_labels=None
      ).sample

      if pred_unet.shape[1] == 6:
        pred_unet, _ = torch.chunk(pred_unet, 2, dim=1)

      pred_unet_i, pred_unet_p = torch.chunk(pred_unet, 2, dim=0)
      noisy_latent_i, noisy_latent_p = torch.chunk(noisy_latent, 2, dim=0)
      noise_i, noise_p = torch.chunk(noise, 2, dim=0)
      t_i, t_p = torch.chunk(t, 2, dim=0)

      z0_i = get_z0(noisy_latent_i.to(device), pred_unet_i.to(device), t_i.to(device), noise_scheduler, device)

      #Loss
      preserve_loss = F.mse_loss(
          pred_unet_p.float(), noise_p.float(),
          reduction="mean"
      )

      instance_loss = sqrt_2_pi * ((F.mse_loss(
          pred_unet_i.float(), noise_i.float(),
          reduction="mean"
      )*(0.5 / std_dev**2)).exp() - 1)

      # Perceptual loss between the instance images and the decoded z0
      # estimate, applied only while step_counter <= 500.
      fp_loss = 0
      if step_counter <= 500:
        fp_loss = foreground_perceptual(
            # All instance images, so each decoded estimate is compared with
            # its own source image (identical to before for BATCH_SIZE == 1).
            batch["pixel_values"][:n_instance]
            , vae.decode(z0_i.to(dtype=torch.float16) / vae.config.scaling_factor).sample.to(dtype=torch.float32)
        )

      # Adversarial term: push the discriminator output for the predicted
      # clean latents towards 1.
      pred_fake = ld_mod(z0_i.float())
      lp_loss = F.mse_loss(pred_fake, torch.ones_like(pred_fake))

      loss = instance_loss + preserve_loss + 3e-3*fp_loss + .5*lp_loss
      # Log detached values only; keeping the live tensors would pin every
      # step's autograd graph in memory for the whole run.
      with torch.no_grad():
        losses.append((t, instance_loss.detach(), preserve_loss.detach(), 3e-3*fp_loss, .5*lp_loss))

      #backprop
      accelerator.backward(loss)
      if accelerator.sync_gradients:
        params_to_clip = (
            itertools.chain(unet.parameters(), text_encoder.parameters())
        )
        accelerator.clip_grad_norm_(params_to_clip, 1.0)

      optimizer.step()
      lr_scheduler.step()
      optimizer.zero_grad()

    #Discriminator
    with torch.no_grad():
      text_embed = encode_prompt(
          text_encoder,
          batch["input_ids"],
          batch["attention_mask"],
      ).to(dtype=torch.float32)

      pred_unet = unet(
          noisy_latent, t, text_embed, class_labels=None
      ).sample

      pred_unet_i, pred_unet_p = torch.chunk(pred_unet, 2, dim=0)
      noisy_latent_i, noisy_latent_p = torch.chunk(noisy_latent, 2, dim=0)
      t_i, t_p = torch.chunk(t, 2, dim=0)

      z0_i = get_z0(noisy_latent_i.to(device), pred_unet_i.to(device), t_i.to(device), noise_scheduler, device)

    # Real instance latents are labelled 1, predicted clean latents 0.
    real_latent = latent[:n_instance].to(dtype=torch.float32)
    latent_ld = torch.cat([real_latent, z0_i.to(dtype=torch.float32)], dim=0)
    pred_ld = ld_mod(latent_ld)
    target_ld = torch.cat([
        torch.ones(real_latent.shape[0], 1),
        torch.zeros(z0_i.shape[0], 1),
    ], dim=0).to(device)
    loss_ld = nn.MSELoss()(pred_ld, target_ld)

    # The generator backward above also left gradients on ld_mod (through
    # lp_loss). Clear them first, otherwise the discriminator step is mixed
    # with gradients that push it to call predicted latents real.
    optimizer_ld.zero_grad()
    loss_ld.backward()
    optimizer_ld.step()
    optimizer_ld.zero_grad()

    if accelerator.sync_gradients:
      step_counter += 1
      if step_counter % GRADIENT_ACCUM == 0:
        progress_bar.set_description(f"loss : {float(loss)}, ld_loss : {float(loss_ld)}")
        progress_bar.update(1)

    if step_counter >= TRAIN_STEPS:
      break

In [None]:
#Model saving
# Every process must reach this point before the weights are exported.
accelerator.wait_for_everyone()

if accelerator.is_main_process:
  # Rebuild a full pipeline from the base checkpoint, swapping in the
  # fine-tuned UNet and text encoder (unwrapped from accelerate).
  pipeline = DiffusionPipeline.from_pretrained(
      MODEL_PATH,
      revision=REVISION,
      text_encoder=accelerator.unwrap_model(text_encoder),
      unet=accelerator.unwrap_model(unet),
  )

  # Re-create the scheduler from its own config; a "learned"/"learned_range"
  # variance type is replaced by "fixed_small" before doing so.
  scheduler_args = {}
  if "variance_type" in pipeline.scheduler.config:
    variance_type = pipeline.scheduler.config.variance_type
    variance_type = (
        "fixed_small"
        if variance_type in ("learned", "learned_range")
        else variance_type
    )
    scheduler_args["variance_type"] = variance_type
  pipeline.scheduler = pipeline.scheduler.from_config(
      pipeline.scheduler.config, **scheduler_args
  )

  # Write the complete pipeline to OUTPUT_DIR.
  pipeline.save_pretrained(OUTPUT_DIR)

accelerator.end_training()

#Load Trained Model

In [None]:
# Reload the pipeline saved in OUTPUT_DIR in half precision, without the
# safety checker, and move it to the evaluation device.
pipe = StableDiffusionPipeline.from_pretrained(
    OUTPUT_DIR,
    safety_checker=None,
    use_safetensors=True,
    torch_dtype=torch.float16,
)
pipe = pipe.to(device)

#Evaluation Initialization

In [None]:
class BaseMetric():
  """Running mean and spread of a per-image scalar metric.

  Subclasses implement ``compute``. Each ``update`` call folds the new value
  into the running mean (``score``), appends it to ``score_list`` and
  recomputes ``std`` as the population standard deviation of every value
  seen so far.
  """

  def __init__(self):
    self.reset()

  def compute(self, instance_path=None, image=None):
    """Return the metric value for one image; must be overridden."""
    raise NotImplementedError

  def update(self, instance_path, image):
    """Score ``image`` with ``compute`` and fold the result into the stats.

    Args:
      instance_path: forwarded unchanged as the first argument of ``compute``.
      image: forwarded unchanged as the second argument of ``compute``.

    Returns:
      ``self``, so calls can be chained.
    """
    s = self.compute(instance_path, image)
    seen = self.count
    # Incremental mean: old mean weighted by n/(n+1), new value by 1/(n+1).
    self.score = self.score * (seen / (seen + 1.)) + s * (1. / (seen + 1.))
    self.count = seen + 1
    self.score_list = torch.cat([self.score_list, s.ravel()], dim=0)
    # The value is printed after "±", so it has to be a standard deviation
    # (square root of the mean squared deviation), not the variance.
    self.std = torch.sqrt(((self.score_list - self.score) ** 2).mean())
    return self

  def reset(self):
    """Clear every accumulated statistic and return ``self``."""
    self.count = torch.Tensor([0.]).to(device)
    self.score = torch.Tensor([0.]).to(device)
    self.std = torch.Tensor([0.]).to(device)
    self.score_list = torch.Tensor([]).to(device)
    return self

  def print(self):
    """Print ``<ClassName> : <mean> ± <std>``."""
    print(f"{self.__class__.__name__} : {self.score} ± {self.std}")

class DINO_metric(BaseMetric):
  """Mean absolute cosine similarity between token-0 embeddings.

  The generated image is compared against every reference image found in a
  folder, using the first token of the encoder's ``last_hidden_state``
  (the [CLS] token when ``model`` is a ViT encoder).
  """

  def __init__(self, processor, model):
    super().__init__()
    self.processor = processor
    self.model = model

  def _first_token_embedding(self, pil_image):
    """Return the first-token embedding of ``pil_image``, batch dim kept."""
    inputs = self.processor(images=pil_image, return_tensors="pt").to(device)
    return self.model(**inputs).last_hidden_state[:, 0]

  def compute(self, instance_path, image):
    """Average similarity of ``image`` to each reference image.

    Args:
      instance_path: directory holding the reference images.
      image: the generated image to score.

    Returns:
      0-dim tensor with the mean absolute cosine similarity.

    Raises:
      ValueError: if ``instance_path`` contains no files.
    """
    similarities = []
    with torch.no_grad():
      image_emb = self._first_token_embedding(image)
      for ins_path in Path(instance_path).iterdir():
        # Sub-directories (e.g. notebook checkpoint folders) are not images.
        if not ins_path.is_file():
          continue
        # exif_transpose returns a new image, so the file can be closed.
        with Image.open(ins_path) as handle:
          instance = exif_transpose(handle)
        ins_emb = self._first_token_embedding(instance)
        similarities.append(torch.abs(F.cosine_similarity(image_emb, ins_emb)))
    if not similarities:
      raise ValueError(f"no reference images found in {instance_path}")
    return torch.mean(torch.cat(similarities, dim=0))


class CLIP_I_metric(BaseMetric):
  """Mean absolute cosine similarity between pooled image embeddings.

  The generated image is compared against every reference image found in a
  folder, using element ``[1]`` of the vision model's output (the pooled
  output for a Hugging Face ``CLIPVisionModel``).
  """

  def __init__(self, processor, model):
    super().__init__()
    self.processor = processor
    self.model = model

  def _pooled_embedding(self, pil_image):
    """Return output ``[1]`` of the vision model for ``pil_image``."""
    pixels = self.processor(images=pil_image, return_tensors="pt")["pixel_values"]
    return self.model(pixels.to(device))[1]

  def compute(self, instance_path, image):
    """Average similarity of ``image`` to each reference image.

    Args:
      instance_path: directory holding the reference images.
      image: the generated image to score.

    Returns:
      0-dim tensor with the mean absolute cosine similarity.

    Raises:
      ValueError: if ``instance_path`` contains no files.
    """
    similarities = []
    with torch.no_grad():
      image_emb = self._pooled_embedding(image)
      for ins_path in Path(instance_path).iterdir():
        # Sub-directories (e.g. notebook checkpoint folders) are not images.
        if not ins_path.is_file():
          continue
        # exif_transpose returns a new image, so the file can be closed.
        with Image.open(ins_path) as handle:
          instance = exif_transpose(handle)
        ins_emb = self._pooled_embedding(instance)
        similarities.append(torch.abs(F.cosine_similarity(image_emb, ins_emb)))
    if not similarities:
      raise ValueError(f"no reference images found in {instance_path}")
    return torch.mean(torch.cat(similarities, dim=0))

class CLIP_T_metric(BaseMetric):
  """Absolute cosine similarity between a prompt and a generated image.

  NOTE(review): output ``[1]`` of ``CLIPTextModel`` / ``CLIPVisionModel`` is
  the pooled output *before* CLIP's projection heads, and the text model
  passed in by the evaluation cells is the diffusion pipeline's own text
  encoder rather than the one paired with ``vision_model``. The two
  embeddings may therefore not live in a shared space - confirm that this
  is the intended CLIP-T definition.
  """

  def __init__(self, processor, text_model, vision_model):
    super().__init__()
    self.processor = processor
    self.text_model = text_model
    self.vision_model = vision_model

  def compute(self, prompt, image):
    """Score how well ``image`` matches ``prompt``.

    Args:
      prompt: the text used to generate ``image``.
      image: the generated image.

    Returns:
      Tensor holding the absolute cosine similarity (one value per prompt).
    """
    with torch.no_grad():
      pixels = self.processor(images=image, return_tensors="pt")["pixel_values"]
      vision_emb = self.vision_model(pixels.to(device))[1]

      # Read the tokenizer outputs by key instead of unpacking ``.values()``,
      # which depends on dict ordering. Truncate so that an over-long prompt
      # cannot exceed the fixed padded length.
      tokens = self.processor(
          text=prompt,
          padding="max_length",
          truncation=True,
          return_tensors="pt"
      )
      text_emb = self.text_model(
          tokens["input_ids"].to(device),
          tokens["attention_mask"].to(device),
      )[1]

    # The two encoders may run in different precisions (the pipeline is loaded
    # in float16), so compare in float32.
    return torch.abs(F.cosine_similarity(text_emb.float(), vision_emb.float()))

class FID_metric():
  """Frechet Inception Distance between reference and generated images.

  The reference ("real") statistics are gathered once from a folder at
  construction time; generated images are added through ``update``.
  """

  def __init__(self, instance_path):
    """Build the metric and register every file in ``instance_path`` as real.

    Args:
      instance_path: directory holding the reference images.
    """
    # reset_real_features=False keeps the reference statistics collected below
    # alive across ``reset()``; only the generated-image statistics are cleared.
    self.metric = FrechetInceptionDistance(
        feature=64, normalize=True, reset_real_features=False
    ).to(device)
    self.preprocess = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor()
    ])
    for ins_path in Path(instance_path).iterdir():
      # Sub-directories (e.g. notebook checkpoint folders) are not images.
      if not ins_path.is_file():
        continue
      # Open each file once and close it; force three channels so grayscale
      # or RGBA references do not break the Inception input.
      with Image.open(ins_path) as handle:
        instance = exif_transpose(handle).convert("RGB")
      self.update(instance, real=True)

  def score(self):
    """Return the current FID value."""
    return self.metric.compute()

  def update(self, image, real=False):
    """Add one image to the real (``real=True``) or generated statistics."""
    image = self.preprocess(image).unsqueeze(0).to(device)
    self.metric.update(image, real=real)

  def reset(self):
    """Clear the generated-image statistics (reference statistics are kept)."""
    self.metric.reset()

  def print(self):
    """Print ``FID_metric : <value>``."""
    print(f"{self.__class__.__name__} : {self.score()}")

class SSIM_metric(BaseMetric):
  """Mean SSIM between a generated image and every reference image."""

  def __init__(self):
    super().__init__()
    # Imported here because the notebook's import cell does not bring this
    # name into scope.
    from torchmetrics.image import StructuralSimilarityIndexMeasure
    # ToTensor yields values in [0, 1], so the data range is fixed rather than
    # inferred from each image pair.
    self.metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    self.preprocess = transforms.Compose([
        # Resize(512) only fixes the shorter side; the centre crop makes every
        # image 512x512 so the generated and reference tensors always match.
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(512),
        transforms.ToTensor()
    ])

  def compute(self, instance_path, image):
    """Average SSIM of ``image`` against each file in ``instance_path``.

    Args:
      instance_path: directory holding the reference images.
      image: the generated image to score.

    Returns:
      0-dim tensor on ``device`` with the mean SSIM.

    Raises:
      ValueError: if ``instance_path`` contains no files.
    """
    scores = []
    with torch.no_grad():
      image = self.preprocess(image).unsqueeze(0).to(device)
      for ins_path in Path(instance_path).iterdir():
        # Sub-directories (e.g. notebook checkpoint folders) are not images.
        if not ins_path.is_file():
          continue
        with Image.open(ins_path) as handle:
          instance = exif_transpose(handle).convert("RGB")
        instance = self.preprocess(instance).unsqueeze(0).to(device)
        scores.append(self.metric(image, instance).detach().reshape(1))
    if not scores:
      raise ValueError(f"no reference images found in {instance_path}")
    return torch.cat(scores, dim=0).mean().to(device)

class PSNR_metric(BaseMetric):
  """Mean PSNR between a generated image and every reference image."""

  def __init__(self):
    super().__init__()
    # ToTensor yields values in [0, 1], so the peak signal is fixed at 1.0
    # rather than inferred from the data.
    self.metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
    self.preprocess = transforms.Compose([
        # Resize(512) only fixes the shorter side; the centre crop makes every
        # image 512x512 so the generated and reference tensors always match.
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(512),
        transforms.ToTensor()
    ])

  def compute(self, instance_path, image):
    """Average PSNR of ``image`` against each file in ``instance_path``.

    Args:
      instance_path: directory holding the reference images.
      image: the generated image to score.

    Returns:
      0-dim tensor on ``device`` with the mean PSNR.

    Raises:
      ValueError: if ``instance_path`` contains no files.
    """
    scores = []
    with torch.no_grad():
      image = self.preprocess(image).unsqueeze(0).to(device)
      for ins_path in Path(instance_path).iterdir():
        # Sub-directories (e.g. notebook checkpoint folders) are not images.
        if not ins_path.is_file():
          continue
        with Image.open(ins_path) as handle:
          instance = exif_transpose(handle).convert("RGB")
        instance = self.preprocess(instance).unsqueeze(0).to(device)
        scores.append(self.metric(image, instance).detach().reshape(1))
    if not scores:
      raise ValueError(f"no reference images found in {instance_path}")
    return torch.cat(scores, dim=0).mean().to(device)

class LPIPS_metric(BaseMetric):
  """Mean LPIPS distance between a generated image and every reference."""

  def __init__(self):
    super().__init__()
    # normalize=True tells the metric its inputs are in [0, 1] (what ToTensor
    # produces); the default expects [-1, 1].
    self.metric = LearnedPerceptualImagePatchSimilarity(
        net_type='squeeze', normalize=True
    ).to(device)
    self.preprocess = transforms.Compose([
        # Resize(512) only fixes the shorter side; the centre crop makes every
        # image 512x512 so the generated and reference tensors always match.
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(512),
        transforms.ToTensor()
    ])

  def compute(self, instance_path, image):
    """Average LPIPS of ``image`` against each file in ``instance_path``.

    Args:
      instance_path: directory holding the reference images.
      image: the generated image to score.

    Returns:
      0-dim tensor on ``device`` with the mean LPIPS distance.

    Raises:
      ValueError: if ``instance_path`` contains no files.
    """
    scores = []
    with torch.no_grad():
      image = self.preprocess(image).unsqueeze(0).to(device)
      for ins_path in Path(instance_path).iterdir():
        # Sub-directories (e.g. notebook checkpoint folders) are not images.
        if not ins_path.is_file():
          continue
        with Image.open(ins_path) as handle:
          instance = exif_transpose(handle).convert("RGB")
        instance = self.preprocess(instance).unsqueeze(0).to(device)
        scores.append(self.metric(image, instance).detach().reshape(1))
    if not scores:
      raise ValueError(f"no reference images found in {instance_path}")
    return torch.cat(scores, dim=0).mean().to(device)

def make_prompt(im_type, subject, background, style):
  """Assemble an evaluation prompt.

  The pieces are joined as ``a <im_type> of <subject> <background> <style>``
  with single spaces between them; ``style`` is appended verbatim.
  """
  template = "a {} of {} {} {}"
  return template.format(im_type, subject, background, style)

# Image types to sample from; "photo" is listed four times so it is drawn
# four times as often as "painting".
im_type_list = ["photo", "photo", "photo", "photo", "painting"]

# Scene descriptions appended after the subject.
background_list = ["in sahara desert covered in sand",
              "in a megacity with skyscrapers",
              "in bali beach with temples on the back",
              "in forest",
              "in the middle of new york city times square",
              "under the sea with corals",
              "on top of mount everest covered in snow",
              "in a field with cherry blossom trees",
              "on the moon with galaxies background",
              "in front of the eiffel tower",
              "in a cyberpunk alley with colorful neon lights"]

# Style suffixes, keyed by image type; each entry is appended verbatim.
style_list = {
    "photo":[", realistic lighting, sharp, vibrant color", "", ", realistic lighting", ", dark and horror", ", bright, cheerful, colorful"],
    "painting": [", van gogh style", ", colorful", ", watercolor", ", da vinci style", ", renaissance painting", ", expressionism"]
}

In [None]:
# DINO encoder and its image processor, used by the DINO similarity metric.
model_dino = ViTModel.from_pretrained('facebook/dino-vits16').to(device)
processor_dino = ViTImageProcessor.from_pretrained('facebook/dino-vits16')

# CLIP processor and vision encoder, used by the CLIP-I / CLIP-T metrics.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

#Evaluation

In [None]:
#Prompt Invariant
# Every sample is generated from the same INSTANCE_PROMPT.
dino = DINO_metric(processor_dino, model_dino)
clip_i = CLIP_I_metric(processor, vision_model)
clip_t = CLIP_T_metric(processor, pipe.text_encoder, vision_model)
fid = FID_metric(INSTANCE_DIR)
ssim = SSIM_metric()
psnr = PSNR_metric()
lpips = LPIPS_metric()

pipe.set_progress_bar_config(disable=True)

for sample_idx in tqdm(range(50)):
  prompt = INSTANCE_PROMPT
  image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]

  # Metrics that compare against the reference folder.
  for metric in (dino, clip_i):
    metric.update(INSTANCE_DIR, image)
  # Prompt/image agreement, then the generated-image side of FID.
  clip_t.update(prompt, image)
  fid.update(image)
  for metric in (ssim, psnr, lpips):
    metric.update(INSTANCE_DIR, image)

# Report in a fixed order.
for metric in (dino, clip_i, clip_t, fid, ssim, psnr, lpips):
  metric.print()

In [None]:
#Prompt Varying
# Each sample uses a randomly assembled prompt (image type, background, style).

dino = DINO_metric(processor_dino, model_dino)
clip_i = CLIP_I_metric(processor, vision_model)
clip_t = CLIP_T_metric(processor, pipe.text_encoder, vision_model)
fid = FID_metric(INSTANCE_DIR)
ssim = SSIM_metric()
psnr = PSNR_metric()
lpips = LPIPS_metric()

pipe.set_progress_bar_config(disable=True)

for sample_idx in tqdm(range(50)):
  # The style is drawn from the pool that matches the sampled image type.
  im_type = random.choice(im_type_list)
  background = random.choice(background_list)
  style = random.choice(style_list[im_type])
  prompt = make_prompt(im_type, f"sks {CLASS_NAME}", background, style)
  image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]

  # Metrics that compare against the reference folder.
  for metric in (dino, clip_i):
    metric.update(INSTANCE_DIR, image)
  # Prompt/image agreement, then the generated-image side of FID.
  clip_t.update(prompt, image)
  fid.update(image)
  for metric in (ssim, psnr, lpips):
    metric.update(INSTANCE_DIR, image)

# Report in a fixed order.
for metric in (dino, clip_i, clip_t, fid, ssim, psnr, lpips):
  metric.print()