fix(api): update LPW pipeline (#298)
ssube committed Mar 28, 2023
1 parent c0ece24 commit 93fcfd1
Showing 2 changed files with 18 additions and 18 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -30,6 +30,7 @@ This is an incomplete list of new and interesting features, with links to the us
- [hosted on Github Pages](https://ssube.github.io/onnx-web), from your CDN, or locally
- [persists your recent images and progress as you change tabs](docs/user-guide.md#image-history)
- queue up multiple images and retry errors
+- translations available for English, French, German, and Spanish (please open an issue for more)
- supports many `diffusers` pipelines
- [txt2img](docs/user-guide.md#txt2img-tab)
- [img2img](docs/user-guide.md#img2img-tab)
35 changes: 17 additions & 18 deletions api/onnx_web/diffusers/lpw_stable_diffusion_onnx.py
@@ -8,15 +8,15 @@
from typing import Callable, List, Optional, Union

import numpy as np
+import PIL
import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTokenizer

import diffusers
-import PIL
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import deprecate, logging
-from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTokenizer
+from diffusers.utils import logging


try:
@@ -201,14 +201,14 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
return tokens, weights


-def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
for i in range(len(tokens)):
-tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
if no_boseos_middle:
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
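
Note on the padding change above: the old line filled every slot after the prompt with `eos`, while the new line fills with a dedicated pad token and keeps a single `eos` at the end of the sequence. A minimal sketch with placeholder token ids (not real CLIP ids) shows the difference:

bos, eos, pad = 0, 1, 2        # placeholder ids, for illustration only
max_length = 8
tokens = [101, 102, 103]       # a short tokenized prompt

# old behaviour: pad the tail with eos
old = [bos] + tokens + [eos] * (max_length - 1 - len(tokens))
# -> [0, 101, 102, 103, 1, 1, 1, 1]

# new behaviour: pad with the pad token, then close with a single eos
new = [bos] + tokens + [pad] * (max_length - 1 - len(tokens) - 1) + [eos]
# -> [0, 101, 102, 103, 2, 2, 2, 1]
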
@@ -347,12 +347,14 @@ def get_weighted_text_embeddings(
# pad the length of tokens and weights
bos = pipe.tokenizer.bos_token_id
eos = pipe.tokenizer.eos_token_id
+pad = getattr(pipe.tokenizer, "pad_token_id", eos)
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
+pad,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
@@ -364,6 +366,7 @@
max_length,
bos,
eos,
+pad,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
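
The `pad` id threaded into both `pad_tokens_and_weights` calls above is looked up with a fallback to `eos`, so tokenizers without an explicit pad token keep the old behaviour. A small sketch of that lookup against a stock CLIP tokenizer (the repo id here is only an example):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
bos = tokenizer.bos_token_id
eos = tokenizer.eos_token_id
pad = getattr(tokenizer, "pad_token_id", eos)  # fall back to eos if the attribute is missing
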
@@ -408,7 +411,7 @@ def get_weighted_text_embeddings(

def preprocess_image(image):
w, h = image.size
-w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
@@ -418,7 +421,7 @@ def preprocess_image(image):
def preprocess_mask(mask, scale_factor=8):
mask = mask.convert("L")
w, h = mask.size
-w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
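
The two `preprocess_*` changes above only swap a `map(lambda ...)` for a generator expression; rounding each dimension down to the nearest multiple of 32 is unchanged. A quick equivalence check:

w, h = 513, 769
w_old, h_old = map(lambda x: x - x % 32, (w, h))   # previous form
w_new, h_new = (x - x % 32 for x in (w, h))        # form used after this change
assert (w_old, h_old) == (w_new, h_new) == (512, 768)
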
@@ -446,7 +449,7 @@ def __init__(
unet: OnnxRuntimeModel,
scheduler: SchedulerMixin,
safety_checker: OnnxRuntimeModel,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__(
@@ -473,7 +476,7 @@ def __init__(
unet: OnnxRuntimeModel,
scheduler: SchedulerMixin,
safety_checker: OnnxRuntimeModel,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
):
super().__init__(
vae_encoder=vae_encoder,
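
The constructor annotations above move from `CLIPFeatureExtractor` to `CLIPImageProcessor`. In recent `transformers` releases the former is a deprecated alias of the latter, so this should only change the type hint, not the runtime object. A minimal loading sketch (the repo id is only an example):

from transformers import CLIPImageProcessor

feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
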
@@ -672,7 +675,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
-callback_steps: Optional[int] = 1,
+callback_steps: int = 1,
**kwargs,
):
r"""
@@ -749,10 +752,6 @@ def __call__(
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
image = init_image or image

# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
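
With the deprecation shim removed above, `__call__` no longer accepts `init_image`; callers must pass `image=` directly. The earlier `__call__` hunk and the `text2img`, `img2img`, and `inpaint` hunks below also tighten `callback_steps` from `Optional[int]` to a plain `int`. A hypothetical call after the change, assuming `pipe` is an instance of the LPW pipeline defined in this file and that `init_pil_image` and `progress_callback` already exist:

result = pipe(
    prompt="a photo of an astronaut riding a horse",
    image=init_pil_image,        # was: init_image=init_pil_image
    strength=0.75,
    num_inference_steps=25,
    callback=progress_callback,
    callback_steps=1,            # now typed as int rather than Optional[int]
)
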
@@ -887,7 +886,7 @@ def text2img(
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
-callback_steps: Optional[int] = 1,
+callback_steps: int = 1,
**kwargs,
):
r"""
@@ -978,7 +977,7 @@ def img2img(
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
-callback_steps: Optional[int] = 1,
+callback_steps: int = 1,
**kwargs,
):
r"""
@@ -1070,7 +1069,7 @@ def inpaint(
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
-callback_steps: Optional[int] = 1,
+callback_steps: int = 1,
**kwargs,
):
r"""
