Skip to content

Commit

Permalink
feat(api): collect progress from chain pipelines (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 12, 2023
1 parent 27a3fa8 commit d9fc908
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 8 deletions.
26 changes: 24 additions & 2 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from PIL import Image

from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..output import save_image
from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug
Expand All @@ -30,6 +30,24 @@ def __call__(
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]


class ChainProgress:
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0

def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step

self.step = step
self.parent(self.get_total(), timestep, latents)

def get_total(self) -> int:
return self.step + self.total


class ChainPipeline:
"""
Run many stages in series, passing the image results from each to the next, and processing
Expand Down Expand Up @@ -57,11 +75,15 @@ def __call__(
server: ServerContext,
params: ImageParams,
source: Image.Image,
callback: ProgressCallback = None,
**pipeline_kwargs
) -> Image.Image:
"""
TODO: handle List[Image] outputs
TODO: handle List[Image] inputs and outputs
"""
if callback is not None:
callback = ChainProgress(callback, start=callback.step)

start = monotonic()
logger.info(
"running pipeline on source image with dimensions %sx%s",
Expand Down
5 changes: 4 additions & 1 deletion api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image

from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams
from ..utils import ServerContext
Expand All @@ -23,6 +23,7 @@ def blend_img2img(
*,
strength: float,
prompt: Optional[str] = None,
callback: ProgressCallback,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
Expand All @@ -46,6 +47,7 @@ def blend_img2img(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
Expand All @@ -57,6 +59,7 @@ def blend_img2img(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
callback=callback,
)

output = result.images[0]
Expand Down
5 changes: 4 additions & 1 deletion api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image

from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
Expand All @@ -29,6 +29,7 @@ def blend_inpaint(
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
callback: ProgressCallback,
**kwargs,
) -> Image.Image:
logger.info("upscaling image by expanding borders", expand)
Expand Down Expand Up @@ -83,6 +84,7 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
Expand All @@ -97,6 +99,7 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)

return result.images[0]
Expand Down
5 changes: 4 additions & 1 deletion api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from diffusers import OnnxStableDiffusionPipeline
from PIL import Image

from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams
from ..utils import ServerContext
Expand All @@ -22,6 +22,7 @@ def source_txt2img(
*,
size: Size,
prompt: str = None,
callback: ProgressCallback = None,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
Expand Down Expand Up @@ -53,6 +54,7 @@ def source_txt2img(
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
Expand All @@ -65,6 +67,7 @@ def source_txt2img(
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
)

output = result.images[0]
Expand Down
5 changes: 4 additions & 1 deletion api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw

from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
Expand All @@ -30,6 +30,7 @@ def upscale_outpaint(
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
callback: ProgressCallback,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
Expand Down Expand Up @@ -92,6 +93,7 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
Expand All @@ -106,6 +108,7 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)

# once part of the image has been drawn, keep it
Expand Down
4 changes: 3 additions & 1 deletion api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image

from ..device_pool import JobContext
from ..device_pool import JobContext, ProgressCallback
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
Expand Down Expand Up @@ -67,6 +67,7 @@ def upscale_stable_diffusion(
*,
upscale: UpscaleParams,
prompt: str = None,
callback: ProgressCallback,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
Expand All @@ -80,4 +81,5 @@ def upscale_stable_diffusion(
source,
generator=generator,
num_inference_steps=params.steps,
callback=callback,
).images[0]
5 changes: 4 additions & 1 deletion api/onnx_web/device_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

logger = getLogger(__name__)

ProgressCallback = Callable[[int, int, Any], None]


class JobContext:
cancel: Value = None
Expand Down Expand Up @@ -51,8 +53,9 @@ def get_device(self) -> DeviceParams:
def get_progress(self) -> int:
return self.progress.value

def get_progress_callback(self) -> Callable[..., None]:
def get_progress_callback(self) -> ProgressCallback:
def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step
if self.is_cancelled():
raise Exception("job has been cancelled")
else:
Expand Down

0 comments on commit d9fc908

Please sign in to comment.