Other Acceleration tricks (cloneofsimo#93)
* feat: face segmentation mask
cloneofsimo committed Dec 29, 2022
1 parent 6231f01 commit eacf501
Showing 4 changed files with 128 additions and 16 deletions.
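
For orientation before the diff: the commit's face mask is opt-in through the new `use_face_segmentation_condition` flag on `train()` in `cli_lora_pti.py`. A minimal invocation sketch, assuming the usual `instance_data_dir`, `pretrained_model_name_or_path`, and `output_dir` arguments that this diff does not show; the paths and token choices below are placeholders, not from the commit:

    from lora_diffusion.cli_lora_pti import train

    train(
        instance_data_dir="./data/person",        # placeholder path
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        output_dir="./output",                    # placeholder path
        use_template="object",
        placeholder_tokens="<s>",
        initializer_tokens="person",
        resolution=512,
        train_batch_size=1,
        use_face_segmentation_condition=True,     # new flag added in this commit
    )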
1 change: 1 addition & 0 deletions lora_diffusion/__init__.py
@@ -1,2 +1,3 @@
from .lora import *
from .dataset import *
+from .utils import *
52 changes: 45 additions & 7 deletions lora_diffusion/cli_lora_pti.py
@@ -142,21 +142,26 @@ def collate_fn(examples):
"input_ids": input_ids,
"pixel_values": pixel_values,
}

if examples[0].get("mask", None) is not None:
batch["mask"] = torch.stack([example["mask"] for example in examples])

return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
    )

    return train_dataloader
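
As a quick sanity check of the collate change above, a sketch with dummy tensors (shapes illustrative only): the "mask" key is stacked into the batch only when the dataset emitted one, so runs without the flag are unchanged.

    import torch

    # Fake examples shaped like the dataset output with the mask flag enabled.
    examples = [
        {"pixel_values": torch.randn(3, 512, 512), "mask": torch.rand(1, 512, 512)}
        for _ in range(2)
    ]

    batch = {"pixel_values": torch.stack([e["pixel_values"] for e in examples])}
    if examples[0].get("mask", None) is not None:
        batch["mask"] = torch.stack([e["mask"] for e in examples])

    print(batch["mask"].shape)  # torch.Size([2, 1, 512, 512])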


@torch.autocast("cuda")
-def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
+def loss_step(
+    batch, unet, vae, text_encoder, scheduler, weight_dtype, t_mutliplier=1.0
+):
    latents = vae.encode(
        batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
    ).latent_dist.sample()
@@ -167,7 +172,7 @@ def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):

    timesteps = torch.randint(
        0,
-        scheduler.config.num_train_timesteps,
+        int(scheduler.config.num_train_timesteps * t_mutliplier),
        (bsz,),
        device=latents.device,
    )
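
The `t_mutliplier` parameter (spelling as in the commit) introduced above shrinks the range timesteps are drawn from, so a value below 1.0 skips the noisiest end of the schedule. A standalone sketch with a typical 1000-step schedule; the numbers here are illustrative:

    import torch

    num_train_timesteps = 1000  # typical DDPM schedule length
    t_mutliplier = 0.8          # value used during tuning later in this diff

    # Only timesteps in [0, 800) can be sampled; the noisiest 20% are skipped.
    timesteps = torch.randint(0, int(num_train_timesteps * t_mutliplier), (4,))
    assert int(timesteps.max()) < 800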
@@ -186,6 +191,31 @@ def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
    else:
        raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

+    if batch.get("mask", None) is not None:
+
+        mask = (
+            batch["mask"]
+            .to(model_pred.device)
+            .reshape(
+                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
+            )
+        )
+        # downsample to the latent resolution of model_pred; the 0.1 floor
+        # keeps the background contributing to the loss
+        mask = (
+            F.interpolate(
+                mask.float(),
+                size=model_pred.shape[-2:],
+                mode="nearest",
+            )
+            + 0.1
+        )
+
+        mask = mask / mask.mean()
+
+        model_pred = model_pred * mask
+
+        target = target * mask
+
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    return loss
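
A self-contained sketch of the weighting above, with dummy tensors: the image-resolution mask is downsampled to latent resolution, floored at 0.1 so the background still contributes, and divided by its mean so the expected loss magnitude stays comparable to the unmasked case. The 64x64 latent assumes a 512px image and the 8x VAE downsampling factor used in the reshape above.

    import torch
    import torch.nn.functional as F

    model_pred = torch.randn(2, 4, 64, 64)   # dummy latent-space prediction
    target = torch.randn(2, 4, 64, 64)       # dummy latent-space target
    mask = torch.rand(2, 1, 512, 512)        # dummy image-resolution face mask

    mask = F.interpolate(mask, size=model_pred.shape[-2:], mode="nearest") + 0.1
    mask = mask / mask.mean()  # keep the overall loss scale roughly unchanged

    loss = F.mse_loss(
        (model_pred * mask).float(), (target * mask).float(), reduction="mean"
    )
    print(loss.item())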

@@ -273,7 +303,15 @@ def perform_tuning(
        for batch in dataloader:
            optimizer.zero_grad()

-            loss = loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype)
+            loss = loss_step(
+                batch,
+                unet,
+                vae,
+                text_encoder,
+                scheduler,
+                weight_dtype,
+                t_mutliplier=0.8,
+            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
@@ -322,7 +360,7 @@ def train(
    class_data_dir: Optional[str] = None,
    stochastic_attribute: Optional[str] = None,
    perform_inversion: bool = True,
-    use_template: Optional[str] = Literal[None, "object", "style"],
+    use_template: Literal[None, "object", "style"] = None,
    placeholder_tokens: str = "<s>",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: str = "dog",
@@ -332,7 +370,6 @@ def train(
    num_class_images: int = 100,
    seed: int = 42,
    resolution: int = 512,
-    center_crop: bool = False,
    color_jitter: bool = True,
    train_batch_size: int = 1,
    sample_batch_size: int = 1,
@@ -350,6 +387,7 @@ def train(
    learning_rate_ti: float = 5e-4,
    continue_inversion: bool = True,
    continue_inversion_lr: Optional[float] = None,
+    use_face_segmentation_condition: bool = False,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 100,
@@ -413,8 +451,8 @@ def train(
        class_prompt=class_prompt,
        tokenizer=tokenizer,
        size=resolution,
-        center_crop=center_crop,
        color_jitter=color_jitter,
+        use_face_segmentation_condition=use_face_segmentation_condition,
    )

    train_dataloader = text2img_dataloader(
79 changes: 70 additions & 9 deletions lora_diffusion/dataset.py
@@ -1,11 +1,12 @@
from torch.utils.data import Dataset

from typing import List, Tuple, Dict, Union, Optional
-from PIL import Image
+from PIL import Image, ImageFilter
from torchvision import transforms
from pathlib import Path

+import cv2
+import random
+import numpy as np

OBJECT_TEMPLATE = [
"a photo of a {}",
@@ -90,12 +91,12 @@ def __init__(
        class_prompt=None,
        size=512,
        h_flip=True,
-        center_crop=False,
        color_jitter=False,
        resize=True,
+        use_face_segmentation_condition=False,
+        blur_amount: int = 70,
    ):
        self.size = size
-        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.resize = resize

@@ -121,25 +122,32 @@ def __init__(
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.h_flip = h_flip
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                )
                if resize
                else transforms.Lambda(lambda x: x),
-                transforms.ColorJitter(0.2, 0.1)
+                transforms.ColorJitter(0.1, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
-                transforms.RandomHorizontalFlip()
-                if h_flip
-                else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

+        self.use_face_segmentation_condition = use_face_segmentation_condition
+        if self.use_face_segmentation_condition:
+            import mediapipe as mp
+
+            mp_face_detection = mp.solutions.face_detection
+            self.face_detection = mp_face_detection.FaceDetection(
+                model_selection=1, min_detection_confidence=0.5
+            )
+        self.blur_amount = blur_amount

    def __len__(self):
        return self._length

@@ -163,6 +171,59 @@ def __getitem__(self, index):
            for token, value in self.token_map.items():
                text = text.replace(token, value)

+        if self.use_face_segmentation_condition:
+            image = cv2.imread(
+                str(self.instance_images_path[index % self.num_instance_images])
+            )
+            results = self.face_detection.process(
+                cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            )
+            black_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
+
+            if results.detections:
+
+                for detection in results.detections:
+
+                    x_min = int(
+                        detection.location_data.relative_bounding_box.xmin
+                        * image.shape[1]
+                    )
+                    y_min = int(
+                        detection.location_data.relative_bounding_box.ymin
+                        * image.shape[0]
+                    )
+                    width = int(
+                        detection.location_data.relative_bounding_box.width
+                        * image.shape[1]
+                    )
+                    height = int(
+                        detection.location_data.relative_bounding_box.height
+                        * image.shape[0]
+                    )
+
+                    # fill the detected face box with white in the mask
+                    black_image[y_min : y_min + height, x_min : x_min + width] = 255
+
+            # blur the mask so its edges fall off smoothly
+            black_image = Image.fromarray(black_image, mode="L").filter(
+                ImageFilter.GaussianBlur(radius=self.blur_amount)
+            )
+            # to tensor
+            black_image = transforms.ToTensor()(black_image)
+            # resize to the training resolution, like the instance image
+            black_image = transforms.Resize(
+                self.size, interpolation=transforms.InterpolationMode.BILINEAR
+            )(black_image)
+
+            example["mask"] = black_image
+
+        if self.h_flip and random.random() > 0.5:
+            hflip = transforms.RandomHorizontalFlip(p=1)
+
+            example["instance_images"] = hflip(example["instance_images"])
+            if self.use_face_segmentation_condition:
+                example["mask"] = hflip(example["mask"])

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
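To eyeball the masks this dataset change will produce, the detection-to-mask steps above can be run standalone. A sketch assuming mediapipe and OpenCV are installed (`pip install mediapipe opencv-python`); `face_mask_preview` is a hypothetical helper name, not part of the repo:

    import cv2
    import mediapipe as mp
    import numpy as np
    from PIL import Image, ImageFilter


    def face_mask_preview(path: str, blur_amount: int = 70) -> Image.Image:
        # Hypothetical helper mirroring the __getitem__ logic above.
        image = cv2.imread(path)
        detector = mp.solutions.face_detection.FaceDetection(
            model_selection=1, min_detection_confidence=0.5
        )
        results = detector.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        if results.detections:
            for det in results.detections:
                box = det.location_data.relative_bounding_box
                x = int(box.xmin * image.shape[1])
                y = int(box.ymin * image.shape[0])
                w = int(box.width * image.shape[1])
                h = int(box.height * image.shape[0])
                mask[y : y + h, x : x + w] = 255  # white box over each face
        return Image.fromarray(mask, mode="L").filter(
            ImageFilter.GaussianBlur(radius=blur_amount)
        )


    face_mask_preview("photo.jpg").show()  # placeholder filename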
12 changes: 12 additions & 0 deletions lora_diffusion/utils.py
@@ -0,0 +1,12 @@
+from PIL import Image
+
+
+def image_grid(_imgs, rows, cols):
+
+    w, h = _imgs[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+    grid_w, grid_h = grid.size
+
+    for i, img in enumerate(_imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
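
A usage sketch for the new helper, assuming a list of equally sized PIL images (the solid-color images are placeholders):

    from PIL import Image

    from lora_diffusion.utils import image_grid

    imgs = [Image.new("RGB", (256, 256), c) for c in ("red", "green", "blue", "gray")]
    grid = image_grid(imgs, rows=2, cols=2)  # 512x512 composite
    grid.save("grid.png")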
