Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prefer chain stage parameters over request parameters #178

Merged
merged 15 commits into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions api/onnx_web/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from . import logging
from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion
from .diffusion.load import get_latents_from_seed, load_pipeline
from .diffusion.load import get_latents_from_seed, load_pipeline, optimize_pipeline
from .diffusion.run import (
run_blend_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .diffusion.stub_scheduler import StubScheduler
from .image import (
expand_image,
mask_filter_gaussian_multiply,
Expand All @@ -17,11 +20,30 @@
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
valid_image,
)
from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
from .server.upscale import run_upscale_correction
from .utils import (
from .onnx import OnnxNet, OnnxTensor
from .params import (
Border,
ImageParams,
Param,
Point,
Size,
StageParams,
UpscaleParams,
)
from .server import (
DeviceParams,
DevicePoolExecutor,
ModelCache,
ServerContext,
apply_patch_basicsr,
apply_patch_codeformer,
apply_patch_facexlib,
apply_patches,
)
from .upscale import run_upscale_correction
from .utils import (
base_join,
get_and_clamp_float,
get_and_clamp_int,
Expand Down
8 changes: 4 additions & 4 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from ..output import save_image
from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug
from ..server import JobContext, ProgressCallback, ServerContext
from ..utils import is_debug
from .utils import process_tile_order

logger = getLogger(__name__)
Expand Down Expand Up @@ -61,12 +61,12 @@ class ChainPipeline:

def __init__(
self,
stages: List[PipelineStage] = [],
stages: List[PipelineStage] = None,
):
"""
Create a new pipeline that will run the given stages.
"""
self.stages = list(stages)
self.stages = list(stages or [])

def append(self, stage: PipelineStage):
"""
Expand Down
28 changes: 14 additions & 14 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from logging import getLogger
from typing import Optional

import numpy as np
import torch
Expand All @@ -8,8 +7,7 @@

from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext
from ..server import JobContext, ProgressCallback, ServerContext

logger = getLogger(__name__)

Expand All @@ -19,15 +17,17 @@ def blend_img2img(
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
strength: float,
prompt: Optional[str] = None,
callback: ProgressCallback = None,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info("blending image using img2img, %s steps: %s", params.steps, prompt)
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt
)

pipe = load_pipeline(
server,
Expand All @@ -41,25 +41,25 @@ def blend_img2img(
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
prompt,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source_image,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
strength=params.strength,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
prompt,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source_image,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
strength=params.strength,
callback=callback,
)

Expand Down
52 changes: 27 additions & 25 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug
from ..server import JobContext, ProgressCallback, ServerContext
from ..utils import is_debug
from .utils import process_tile_order

logger = getLogger(__name__)
Expand All @@ -22,46 +22,50 @@ def blend_inpaint(
server: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
expand: Border,
mask_image: Optional[Image.Image] = None,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
callback: ProgressCallback = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
expand = expand.with_args(**kwargs)
source = source or stage_source
logger.info(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
)

if mask_image is None:
if stage_mask is None:
# if no mask was provided, keep the full source image
mask_image = Image.new("RGB", source_image.size, "black")
stage_mask = Image.new("RGB", source.size, "black")

source_image, mask_image, noise_image, _full_dims = expand_image(
source_image,
mask_image,
source, stage_mask, noise, _full_dims = expand_image(
source,
stage_mask,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)

if is_debug():
save_image(server, "last-source.png", source_image)
save_image(server, "last-mask.png", mask_image)
save_image(server, "last-noise.png", noise_image)
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)

def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*image.size)
mask = mask_image.crop((left, top, left + tile, top + tile))
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))

if is_debug():
save_image(server, "tile-source.png", image)
save_image(server, "tile-mask.png", mask)
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)

latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline(
Expand All @@ -81,9 +85,9 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=image,
image=tile_source,
latents=latents,
mask_image=mask,
mask_image=tile_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
Expand All @@ -96,9 +100,9 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=image,
image=tile_source,
latents=latents,
mask_image=mask,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
Expand All @@ -107,9 +111,7 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):

return result.images[0]

output = process_tile_order(
stage.tile_order, source_image, SizeChart.auto, 1, [outpaint]
)
output = process_tile_order(stage.tile_order, source, SizeChart.auto, 1, [outpaint])

logger.info("final output image size", output.size)
logger.info("final output image size: %s", output.size)
return output
16 changes: 8 additions & 8 deletions api/onnx_web/chain/blend_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from onnx_web.output import save_image

from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug
from ..server import JobContext, ProgressCallback, ServerContext
from ..utils import is_debug

logger = getLogger(__name__)

Expand All @@ -19,24 +19,24 @@ def blend_mask(
_stage: StageParams,
_params: ImageParams,
*,
resized: Optional[List[Image.Image]] = None,
mask: Optional[Image.Image] = None,
sources: Optional[List[Image.Image]] = None,
stage_mask: Optional[Image.Image] = None,
_callback: ProgressCallback = None,
**kwargs,
) -> Image.Image:
logger.info("blending image using mask")

mult_mask = Image.new("RGBA", mask.size, color="black")
mult_mask.alpha_composite(mask)
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
mult_mask.alpha_composite(stage_mask)
mult_mask = mult_mask.convert("L")

if is_debug():
save_image(server, "last-mask.png", mask)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)

resized = [
valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size)
for s in resized
for s in sources
]

return Image.composite(resized[0], resized[1], mult_mask)
9 changes: 6 additions & 3 deletions api/onnx_web/chain/correct_codeformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from PIL import Image

from ..params import ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext
from ..utils import ServerContext
from ..server import JobContext, ServerContext

logger = getLogger(__name__)

Expand All @@ -25,6 +24,10 @@ def correct_codeformer(
# must be within the load function for patch to take effect
from codeformer import CodeFormer

source = stage_source or source

upscale = upscale.with_args(**kwargs)

device = job.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return pipe(stage_source or source)
return pipe(source)
14 changes: 9 additions & 5 deletions api/onnx_web/chain/correct_gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from PIL import Image

from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext
from ..utils import ServerContext, run_gc
from ..server import JobContext, ServerContext
from ..utils import run_gc

logger = getLogger(__name__)

Expand Down Expand Up @@ -50,20 +50,24 @@ def correct_gfpgan(
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
upscale = upscale.with_args(**kwargs)
source = stage_source or source

if upscale.correction_model is None:
logger.warn("no face model given, skipping")
return source_image
return source

logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device()
gfpgan = load_gfpgan(server, stage, upscale, device)

output = np.array(source_image)
output = np.array(source)
_, _, output = gfpgan.enhance(
output,
has_aligned=False,
Expand Down
12 changes: 7 additions & 5 deletions api/onnx_web/chain/persist_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from ..output import save_image
from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext
from ..server import JobContext, ServerContext

logger = getLogger(__name__)

Expand All @@ -15,11 +14,14 @@ def persist_disk(
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
output: str,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
dest = save_image(server, output, source_image)
source = stage_source or source

dest = save_image(server, output, source)
logger.info("saved image to %s", dest)
return source_image
return source
Loading