Skip to content

Commit

Permalink
fix(api): use kwargs for chain stages
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 1, 2023
1 parent 7a73c9f commit 2d10252
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 89 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/__init__.py
Expand Up @@ -17,7 +17,7 @@
run_upscale_pipeline,
)
from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import append_upscale_correction
from .diffusers.upscale import stage_upscale_correction
from .image.utils import (
expand_image,
valid_image,
Expand Down
22 changes: 21 additions & 1 deletion api/onnx_web/chain/base.py
Expand Up @@ -78,11 +78,31 @@ def __init__(

def append(self, stage: PipelineStage):
"""
DEPRECATED: use `stage` instead
Append an additional stage to this pipeline.
"""
if stage is not None:
self.stages.append(stage)

def run(
self,
job: WorkerContext,
server: ServerContext,
params: ImageParams,
source: Optional[Image.Image],
callback: Optional[ProgressCallback],
**kwargs
) -> Image.Image:
"""
TODO: handle List[Image] inputs and outputs
"""
return self(job, server, params, source=source, callback=callback, **kwargs)

def stage(self, callback: StageCallback, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self

def __call__(
self,
job: WorkerContext,
Expand All @@ -93,7 +113,7 @@ def __call__(
**pipeline_kwargs
) -> Image.Image:
"""
TODO: handle List[Image] inputs and outputs
DEPRECATED: use `run` instead
"""
if callback is not None:
callback = ChainProgress.from_progress(callback)
Expand Down
6 changes: 3 additions & 3 deletions api/onnx_web/chain/upscale_highres.py
Expand Up @@ -4,8 +4,8 @@
from PIL import Image

from ..chain.base import ChainPipeline
from ..chain.img2img import blend_img2img
from ..diffusers.upscale import append_upscale_correction
from ..chain.blend_img2img import blend_img2img
from ..diffusers.upscale import stage_upscale_correction
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
Expand Down Expand Up @@ -45,7 +45,7 @@ def upscale_highres(
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
append_upscale_correction(
stage_upscale_correction(
StageParams(),
params,
upscale=upscale.with_args(
Expand Down
137 changes: 54 additions & 83 deletions api/onnx_web/diffusers/run.py
Expand Up @@ -24,7 +24,7 @@
from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast
from ..worker import WorkerContext
from .upscale import append_upscale_correction, split_upscale
from .upscale import split_upscale, stage_upscale_correction
from .utils import parse_prompt

logger = getLogger(__name__)
Expand All @@ -42,20 +42,16 @@ def run_txt2img_pipeline(
# prepare the chain pipeline and first stage
chain = ChainPipeline()
stage = StageParams()
chain.append(
(
source_txt2img,
stage,
{
"size": size,
},
)
chain.stage(
source_txt2img,
stage,
size=size,
)

# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=first_upscale,
Expand All @@ -64,22 +60,21 @@ def run_txt2img_pipeline(

# apply highres
for _i in range(highres.iterations):
chain.append(
(
upscale_highres,
stage,
{
"highres": highres,
"upscale": upscale,
},
)
chain.stage(
upscale_highres,
StageParams(
outscale=highres.scale,
),
highres=highres,
upscale=upscale,
overlap=params.overlap,
)

# apply upscaling and correction, after highres
append_upscale_correction(
StageParams(),
stage_upscale_correction(
stage,
params,
upscale=upscale,
upscale=after_upscale,
chain=chain,
)

Expand Down Expand Up @@ -128,20 +123,16 @@ def run_img2img_pipeline(
# prepare the chain pipeline and first stage
chain = ChainPipeline()
stage = StageParams()
chain.append(
(
blend_img2img,
stage,
{
"strength": strength,
},
)
chain.stage(
blend_img2img,
stage,
strength=strength,
)

# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=first_upscale,
Expand All @@ -151,32 +142,24 @@ def run_img2img_pipeline(
# loopback through multiple img2img iterations
if params.loopback > 0:
for _i in range(params.loopback):
chain.append(
(
blend_img2img,
stage,
{
"strength": strength,
},
)
chain.stage(
blend_img2img,
stage,
strength=strength,
)

# highres, if selected
if highres.iterations > 0:
for _i in range(highres.iterations):
chain.append(
(
upscale_highres,
stage,
{
"highres": highres,
"upscale": upscale,
},
)
chain.stage(
upscale_highres,
stage,
highres=highres,
upscale=upscale,
)

# apply upscaling and correction, after highres
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=after_upscale,
Expand Down Expand Up @@ -237,34 +220,26 @@ def run_inpaint_pipeline(
# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams(tile_order=tile_order)
chain.append(
(
upscale_outpaint,
stage,
{
"border": border,
"stage_mask": mask,
"fill_color": fill_color,
"mask_filter": mask_filter,
"noise_source": noise_source,
},
)
chain.stage(
upscale_outpaint,
stage,
border=border,
stage_mask=mask,
fill_color=fill_color,
mask_filter=mask_filter,
noise_source=noise_source,
)

# apply highres
chain.append(
(
upscale_highres,
stage,
{
"highres": highres,
"upscale": upscale,
},
)
chain.stage(
upscale_highres,
stage,
highres=highres,
upscale=upscale,
)

# apply upscaling and correction
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=upscale,
Expand Down Expand Up @@ -313,27 +288,23 @@ def run_upscale_pipeline(
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=first_upscale,
chain=chain,
)

# apply highres
chain.append(
(
upscale_highres,
stage,
{
"highres": highres,
"upscale": upscale,
},
)
chain.stage(
upscale_highres,
stage,
highres=highres,
upscale=upscale,
)

# apply upscaling and correction, after highres
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=after_upscale,
Expand Down Expand Up @@ -380,7 +351,7 @@ def run_blend_pipeline(
stage.append((blend_mask, stage, None))

# apply upscaling and correction
append_upscale_correction(
stage_upscale_correction(
stage,
params,
upscale=upscale,
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/diffusers/upscale.py
Expand Up @@ -36,7 +36,7 @@ def split_upscale(
)


def append_upscale_correction(
def stage_upscale_correction(
stage: StageParams,
params: ImageParams,
*,
Expand Down

0 comments on commit 2d10252

Please sign in to comment.