Skip to content

Commit

Permalink
fix(api): remove nested tiling in highres
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 2, 2023
1 parent eef055e commit a7be651
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 74 deletions.
2 changes: 2 additions & 0 deletions api/onnx_web/chain/__init__.py
Expand Up @@ -17,6 +17,7 @@
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_simple import UpscaleSimpleStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage

Expand All @@ -39,6 +40,7 @@
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-simple": UpscaleSimpleStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}
2 changes: 1 addition & 1 deletion api/onnx_web/chain/base.py
Expand Up @@ -11,7 +11,7 @@
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .utils import process_tile_order
from .tile import process_tile_order

logger = getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/blend_inpaint.py
Expand Up @@ -13,7 +13,7 @@
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_order
from .tile import process_tile_order

logger = getLogger(__name__)

Expand Down
58 changes: 58 additions & 0 deletions api/onnx_web/chain/highres.py
@@ -0,0 +1,58 @@
from logging import getLogger
from typing import Optional

from ..chain.base import ChainPipeline
from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams

logger = getLogger(__name__)


def stage_highres(
stage: StageParams,
params: ImageParams,
highres: HighresParams,
upscale: UpscaleParams,
chain: Optional[ChainPipeline] = None,
) -> ChainPipeline:
logger.info("staging highres pipeline at %s", highres.scale)

if chain is None:
chain = ChainPipeline()

if highres.iterations < 1:
logger.debug("no highres iterations, skipping")
return chain

if highres.method == "upscale":
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
stage,
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
overlap=params.overlap,
)
else:
logger.debug("using simple upscaling for highres")
chain.stage(
UpscaleSimpleStage(),
stage,
overlap=params.overlap,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
)

chain.stage(
BlendImg2ImgStage(),
stage,
overlap=params.overlap,
strength=highres.strength,
)

return chain
File renamed without changes.
4 changes: 3 additions & 1 deletion api/onnx_web/chain/upscale.py
Expand Up @@ -44,13 +44,14 @@ def stage_upscale_correction(
chain: Optional[ChainPipeline] = None,
pre_stages: List[PipelineStage] = None,
post_stages: List[PipelineStage] = None,
**kwargs,
) -> ChainPipeline:
"""
This is a convenience method for a chain pipeline that will run upscaling and
correction, based on the `upscale` params.
"""
logger.info(
"running upscaling and correction pipeline at %s:%s",
"staging upscaling and correction pipeline at %s:%s",
upscale.scale,
upscale.outscale,
)
Expand All @@ -63,6 +64,7 @@ def stage_upscale_correction(
chain.append((stage, pre_params, pre_opts))

upscale_opts = {
**kwargs,
"upscale": upscale,
}
upscale_stage = None
Expand Down
38 changes: 4 additions & 34 deletions api/onnx_web/chain/upscale_highres.py
@@ -1,14 +1,13 @@
from logging import getLogger
from typing import Any, Optional
from typing import Optional

from PIL import Image

from ..chain import BlendImg2ImgStage, ChainPipeline
from ..chain.highres import stage_highres
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
from .upscale import stage_upscale_correction

logger = getLogger(__name__)

Expand All @@ -18,14 +17,13 @@ def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
highres: HighresParams,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
pipeline: Optional[Any] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
Expand All @@ -34,35 +32,7 @@ def run(
if highres.scale <= 1:
return source

chain = ChainPipeline()
scaled_size = (source.width * highres.scale, source.height * highres.scale)

# TODO: upscaling within the same stage prevents tiling from happening and causes OOM
if highres.method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif highres.method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
StageParams(),
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
)

chain.stage(
BlendImg2ImgStage(),
StageParams(),
overlap=params.overlap,
strength=highres.strength,
)
chain = stage_highres(stage, params, highres, upscale)

return chain(
job,
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_outpaint.py
Expand Up @@ -13,7 +13,7 @@
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import complete_tile, process_tile_grid, process_tile_order
from .tile import complete_tile, process_tile_grid, process_tile_order

logger = getLogger(__name__)

Expand Down
46 changes: 46 additions & 0 deletions api/onnx_web/chain/upscale_simple.py
@@ -0,0 +1,46 @@
from logging import getLogger
from typing import Optional

from PIL import Image

from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext

logger = getLogger(__name__)


class UpscaleSimpleStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
method: str,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source

if upscale.scale <= 1:
logger.debug(
"simple upscale stage run with scale of %s, skipping", upscale.scale
)
return source

scaled_size = (source.width * upscale.scale, source.height * upscale.scale)

if method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.warning("unknown upscaling method: %s", method)

return source
74 changes: 38 additions & 36 deletions api/onnx_web/diffusers/run.py
Expand Up @@ -3,12 +3,13 @@

from PIL import Image

from onnx_web.chain.highres import stage_highres

from ..chain import (
BlendImg2ImgStage,
BlendMaskStage,
ChainPipeline,
SourceTxt2ImgStage,
UpscaleHighresStage,
UpscaleOutpaintStage,
)
from ..chain.upscale import split_upscale, stage_upscale_correction
Expand Down Expand Up @@ -60,14 +61,12 @@ def run_txt2img_pipeline(

# apply highres
for _i in range(highres.iterations):
chain.stage(
UpscaleHighresStage(),
StageParams(
outscale=highres.scale,
),
highres=highres,
upscale=upscale,
overlap=params.overlap,
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)

# apply upscaling and correction, after highres
Expand Down Expand Up @@ -141,23 +140,22 @@ def run_img2img_pipeline(
)

# loopback through multiple img2img iterations
if params.loopback > 0:
for _i in range(params.loopback):
chain.stage(
BlendImg2ImgStage(),
stage,
strength=strength,
)
for _i in range(params.loopback):
chain.stage(
BlendImg2ImgStage(),
stage,
strength=strength,
)

# highres, if selected
if highres.iterations > 0:
for _i in range(highres.iterations):
chain.stage(
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
)
for _i in range(highres.iterations):
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)

# apply upscaling and correction, after highres
stage_upscale_correction(
Expand Down Expand Up @@ -233,12 +231,14 @@ def run_inpaint_pipeline(
)

# apply highres
chain.stage(
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
)
for _i in range(highres.iterations):
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)

# apply upscaling and correction
stage_upscale_correction(
Expand Down Expand Up @@ -299,12 +299,14 @@ def run_upscale_pipeline(
)

# apply highres
chain.stage(
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
)
for _i in range(highres.iterations):
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)

# apply upscaling and correction, after highres
stage_upscale_correction(
Expand Down

0 comments on commit a7be651

Please sign in to comment.