Skip to content

Commit

Permalink
feat(api): make pipeline stages support multiple images
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 4, 2023
1 parent f718087 commit 3718525
Show file tree
Hide file tree
Showing 22 changed files with 486 additions and 438 deletions.
6 changes: 4 additions & 2 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,19 @@ def __call__(
job: WorkerContext,
server: ServerContext,
params: ImageParams,
source: Optional[Image.Image] = None,
source: List[Image.Image],
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs
) -> Image.Image:
) -> List[Image.Image]:
"""
DEPRECATED: use `run` instead
"""
if callback is not None:
callback = ChainProgress.from_progress(callback)

start = monotonic()

# TODO: turn this into stage images
image = source

if source is not None:
Expand Down
74 changes: 37 additions & 37 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional

import numpy as np
import torch
Expand All @@ -22,15 +22,14 @@ def run(
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
strength: float,
callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt
)
Expand Down Expand Up @@ -59,39 +58,40 @@ def run(
elif pipe_type == "pix2pix":
pipe_params["image_guidance_scale"] = strength

if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
outputs = []
for source in sources:
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)

rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)

output = result.images[0]
outputs.extend(result.images)

logger.info("final output image size: %sx%s", output.width, output.height)
return output
return outputs
154 changes: 78 additions & 76 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from logging import getLogger
from typing import Callable, Optional, Tuple
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
from PIL import Image

from ..diffusers.load import load_pipeline
from ..diffusers.utils import get_latents_from_seed
from ..diffusers.utils import get_latents_from_seed, parse_prompt
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
Expand All @@ -26,7 +26,7 @@ def run(
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
expand: Border,
stage_source: Optional[Image.Image] = None,
Expand All @@ -36,95 +36,97 @@ def run(
noise_source: Callable = noise_source_histogram,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
params = params.with_args(**kwargs)
expand = expand.with_args(**kwargs)
source = source or stage_source
logger.info(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
)

if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")

source, stage_mask, noise, _full_dims = expand_image(
source,
stage_mask,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)

if is_debug():
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)

pipe_type = "lpw" if params.lpw() else "inpaint"
_prompt_pairs, loras, inversions = parse_prompt(params)
pipe_type = params.get_valid_pipeline("inpaint")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
# TODO: add LoRAs and TIs
inversions=inversions,
loras=loras,
)

def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
outputs = []
for source in sources:
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")

source, stage_mask, noise, _full_dims = expand_image(
source,
stage_mask,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)

if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)

latents = get_latents_from_seed(params.seed, size)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=tile_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))

return result.images[0]
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)

output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=params.overlap,
)
latents = get_latents_from_seed(params.seed, size)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=tile_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)

logger.info("final output image size: %s", output.size)
return output
return result.images[0]

outputs.append(
process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=params.overlap,
)
)
9 changes: 5 additions & 4 deletions api/onnx_web/chain/blend_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ def run(
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
*,
alpha: float,
sources: Optional[List[Image.Image]] = None,
stage_source: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
logger.info("blending image using linear interpolation")
) -> List[Image.Image]:
logger.info("blending source images using linear interpolation")

return Image.blend(sources[1], sources[0], alpha)
return [Image.blend(source, stage_source, alpha) for source in sources]
8 changes: 4 additions & 4 deletions api/onnx_web/chain/blend_mask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional

from PIL import Image

Expand All @@ -20,13 +20,13 @@ def run(
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
logger.info("blending image using mask")

mult_mask = Image.new("RGBA", stage_mask.size, color="black")
Expand All @@ -37,4 +37,4 @@ def run(
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)

return Image.composite(stage_source, source, mult_mask)
return [Image.composite(stage_source, source, mult_mask) for source in sources]
10 changes: 4 additions & 6 deletions api/onnx_web/chain/correct_codeformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional

from PIL import Image

Expand All @@ -18,20 +18,18 @@ def run(
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
stage_source: Optional[Image.Image] = None,
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from codeformer import CodeFormer

source = stage_source or source

upscale = upscale.with_args(**kwargs)

device = job.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return pipe(source)
return [pipe(source) for source in sources]

0 comments on commit 3718525

Please sign in to comment.