Skip to content

Commit

Permalink
feat(api): implement upscaling and correction as a chain pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 27, 2023
1 parent 76e25ac commit bcaf0f7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
19 changes: 11 additions & 8 deletions api/onnx_web/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ class StageParams:

def __init__(
self,
name: Optional[str] = None,
tile_size: int = 512,
outscale: int = 1,
) -> None:
self.name = name
self.tile_size = tile_size
self.outscale = outscale

Expand Down Expand Up @@ -48,7 +50,7 @@ class ChainPipeline:

def __init__(
self,
stages: List[PipelineStage],
stages: List[PipelineStage] = [],
):
'''
Create a new pipeline that will run the given stages.
Expand All @@ -70,11 +72,12 @@ def __call__(self, ctx: ServerContext, params: ImageParams, source: Image.Image)
image = source

for stage_fn, stage_params, stage_kwargs in self.stages:
print('running pipeline stage on result image with dimensions %sx%s' %
image.size)
name = stage_params.label or stage_fn.__name__
print('running pipeline stage %s on result image with dimensions %sx%s' %
(name, image.width, image.height))
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
print('source image larger than tile size, tiling stage',
stage_params.tile_size)
print('source image larger than tile size of %s, tiling stage' % (
stage_params.tile_size))

def stage_tile(tile: Image.Image) -> Image.Image:
tile = stage_fn(ctx, stage_params, params, tile,
Expand All @@ -85,12 +88,12 @@ def stage_tile(tile: Image.Image) -> Image.Image:
image = process_tiles(
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
print('source image within tile size, run stage')
print('source image within tile size, running stage')
image = stage_fn(ctx, stage_params, params, image,
**stage_kwargs)

print('finished running pipeline stage, result size: %sx%s' %
image.size)
print('finished running pipeline stage %s, result size: %sx%s' %
(name, image.width, image.height))

print('finished running pipeline, result size: %sx%s' % image.size)
return image
11 changes: 6 additions & 5 deletions api/onnx_web/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,20 +215,21 @@ def run_upscale_correction(
) -> Image.Image:
print('running upscale pipeline')

chain = ChainPipeline()

if upscale.scale > 1:
if 'esrgan' in upscale.upscale_model:
stage = StageParams(tile_size=stage.tile_size,
outscale=upscale.outscale)
image = upscale_resrgan(ctx, stage, params, image, upscale=upscale)
chain.append((upscale_resrgan, stage, {'upscale': upscale}))
elif 'stable-diffusion' in upscale.upscale_model:
mini_tile = min(128, stage.tile_size)
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
image = upscale_stable_diffusion(
ctx, stage, params, image, upscale=upscale)
chain.append((upscale_stable_diffusion, stage, {'upscale': upscale}))

if upscale.faces:
stage = StageParams(tile_size=stage.tile_size,
outscale=upscale.outscale)
image = upscale_gfpgan(ctx, stage, params, image, upscale=upscale)
chain.append((upscale_gfpgan, stage, {'upscale': upscale}))

return image
return chain(ctx, params, image)

0 comments on commit bcaf0f7

Please sign in to comment.