Skip to content

Commit

Permalink
fix(api): accumulate progress from inpaint pipelines (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 12, 2023
1 parent b85c806 commit 034be32
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
8 changes: 4 additions & 4 deletions api/logging.yaml
Expand Up @@ -5,14 +5,14 @@ formatters:
handlers:
console:
class: logging.StreamHandler
level: INFO
level: DEBUG
formatter: simple
stream: ext://sys.stdout
loggers:
'':
level: INFO
level: DEBUG
handlers: [console]
propagate: True
root:
level: INFO
handlers: [console]
level: DEBUG
handlers: [console]
8 changes: 6 additions & 2 deletions api/onnx_web/chain/base.py
Expand Up @@ -47,6 +47,11 @@ def __call__(self, step: int, timestep: int, latents: Any) -> None:
def get_total(self) -> int:
return self.step + self.total

@classmethod
def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0
return ChainProgress(parent, start=start)


class ChainPipeline:
"""
Expand Down Expand Up @@ -82,8 +87,7 @@ def __call__(
TODO: handle List[Image] inputs and outputs
"""
if callback is not None:
start = callback.step if hasattr(callback, "step") else 0
callback = ChainProgress(callback, start=start)
callback = ChainProgress.from_progress(callback)

start = monotonic()
logger.info(
Expand Down
21 changes: 19 additions & 2 deletions api/onnx_web/diffusion/run.py
Expand Up @@ -6,6 +6,8 @@
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops

from onnx_web.chain.base import ChainProgress

from ..chain import upscale_outpaint
from ..device_pool import JobContext
from ..output import save_image, save_params
Expand Down Expand Up @@ -65,7 +67,13 @@ def run_txt2img_pipeline(

image = result.images[0]
image = run_upscale_correction(
job, server, StageParams(), params, image, upscale=upscale, callback=progress,
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)

dest = save_image(server, output, image)
Expand Down Expand Up @@ -123,7 +131,13 @@ def run_img2img_pipeline(

image = result.images[0]
image = run_upscale_correction(
job, server, StageParams(), params, image, upscale=upscale, callback=progress,
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)

dest = save_image(server, output, image)
Expand Down Expand Up @@ -157,6 +171,9 @@ def run_inpaint_pipeline(
progress = job.get_progress_callback()
stage = StageParams(tile_order=tile_order)

# calling the upscale_outpaint stage directly needs accumulating progress
progress = ChainProgress.from_progress(progress)

image = upscale_outpaint(
job,
server,
Expand Down

0 comments on commit 034be32

Please sign in to comment.