Skip to content

Commit

Permalink
feat(api): pass tile size param to most pipeline stages
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 2, 2023
1 parent c515d25 commit d8ec93a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
4 changes: 3 additions & 1 deletion api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..diffusers.load import load_pipeline
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
from ..params import ImageParams, Size, StageParams
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
Expand All @@ -16,6 +16,8 @@


class SourceTxt2ImgStage(BaseStage):
max_tile = SizeChart.unlimited

def run(
self,
job: WorkerContext,
Expand Down
12 changes: 8 additions & 4 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def run_txt2img_pipeline(
) -> None:
# prepare the chain pipeline and first stage
chain = ChainPipeline()
stage = StageParams()
stage = StageParams(
tile_size=params.tiles,
)
chain.stage(
SourceTxt2ImgStage(),
stage,
Expand Down Expand Up @@ -122,7 +124,9 @@ def run_img2img_pipeline(

# prepare the chain pipeline and first stage
chain = ChainPipeline()
stage = StageParams()
stage = StageParams(
tile_size=params.tiles,
)
chain.stage(
BlendImg2ImgStage(),
stage,
Expand Down Expand Up @@ -219,7 +223,7 @@ def run_inpaint_pipeline(

# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams(tile_order=tile_order)
stage = StageParams(tile_order=tile_order, tile_size=params.tiles)
chain.stage(
UpscaleOutpaintStage(),
stage,
Expand Down Expand Up @@ -286,7 +290,7 @@ def run_upscale_pipeline(
) -> None:
# set up the chain pipeline, no base stage for upscaling
chain = ChainPipeline()
stage = StageParams()
stage = StageParams(tile_size=params.tiles)

# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
Expand Down
2 changes: 2 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


class SizeChart(IntEnum):
unlimited = 0
mini = 128 # small tile for very expensive models
half = 256 # half tile for outpainting
auto = 512 # auto tile size
Expand All @@ -22,6 +23,7 @@ class SizeChart(IntEnum):
hd4k = 2**12
hd8k = 2**13
hd16k = 2**14
hd32k = 2**15
hd64k = 2**16


Expand Down

0 comments on commit d8ec93a

Please sign in to comment.