Skip to content

Commit

Permalink
feat: split up UNet and VAE tile size and overlap/stride params
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Nov 5, 2023
1 parent e9b1375 commit e8d7d9a
Show file tree
Hide file tree
Showing 19 changed files with 201 additions and 148 deletions.
6 changes: 3 additions & 3 deletions api/onnx_web/chain/highres.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,22 @@ def stage_highres(
outscale=highres.scale,
),
chain=chain,
overlap=params.overlap,
overlap=params.vae_overlap,
)
else:
logger.debug("using simple upscaling for highres")
chain.stage(
UpscaleSimpleStage(),
stage,
method=highres.method,
overlap=params.overlap,
overlap=params.vae_overlap,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
)

chain.stage(
BlendImg2ImgStage(),
stage,
overlap=params.overlap,
overlap=params.vae_overlap,
prompt_index=prompt_index + i,
strength=highres.strength,
)
Expand Down
4 changes: 2 additions & 2 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def run(
)

if params.is_xl():
tile_size = max(stage.tile_size, params.tiles)
tile_size = max(stage.tile_size, params.unet_tile)
else:
tile_size = params.tiles
tile_size = params.unet_tile

# this works for panorama as well, because tile_size is already max(tile_size, *size)
latent_size = size.min(tile_size, tile_size)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_bsrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,5 @@ def steps(
params: ImageParams,
size: Size,
) -> int:
tile = min(params.tiles, self.max_tile)
tile = min(params.unet_tile, self.max_tile)
return size.width // tile * size.height // tile
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def run(
outputs.append(source)
continue

tile_size = params.tiles
tile_size = params.unet_tile
size = Size(*source.size)
latent_size = size.min(tile_size, tile_size)

Expand Down
17 changes: 8 additions & 9 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,13 @@ def load_pipeline(

# update panorama params
if params.is_panorama():
latent_window = params.tiles // 8
latent_stride = params.stride // 8

pipe.set_window_size(latent_window, latent_stride)
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8
logger.debug("setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", params.unet_tile, unet_stride, params.vae_tile, params.vae_overlap)
pipe.set_window_size(params.unet_tile // 8, unet_stride)

for vae in VAE_COMPONENTS:
if hasattr(pipe, vae):
getattr(pipe, vae).set_window_size(latent_window, params.overlap)
getattr(pipe, vae).set_window_size(params.vae_tile // 8, params.vae_overlap)

run_gc([device])

Expand Down Expand Up @@ -626,8 +625,8 @@ def patch_pipeline(
server,
original_decoder,
decoder=True,
window=params.tiles,
overlap=params.overlap,
window=params.unet_tile,
overlap=params.vae_overlap,
)
logger.debug("patched VAE decoder with wrapper")

Expand All @@ -637,8 +636,8 @@ def patch_pipeline(
server,
original_encoder,
decoder=False,
window=params.tiles,
overlap=params.overlap,
window=params.unet_tile,
overlap=params.vae_overlap,
)
logger.debug("patched VAE encoder with wrapper")

Expand Down
18 changes: 9 additions & 9 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def run_txt2img_pipeline(
) -> None:
# if using panorama, the pipeline will tile itself (views)
if params.is_panorama() or params.is_xl():
tile_size = max(params.tiles, size.width, size.height)
tile_size = max(params.unet_tile, size.width, size.height)
else:
tile_size = params.tiles
tile_size = params.unet_tile

# prepare the chain pipeline and first stage
chain = ChainPipeline()
Expand All @@ -57,11 +57,11 @@ def run_txt2img_pipeline(
),
size=size,
prompt_index=0,
overlap=params.overlap,
overlap=params.vae_overlap,
)

# apply upscaling and correction, before highres
stage = StageParams(tile_size=params.tiles)
stage = StageParams(tile_size=params.unet_tile)
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
stage_upscale_correction(
Expand Down Expand Up @@ -139,14 +139,14 @@ def run_img2img_pipeline(
# prepare the chain pipeline and first stage
chain = ChainPipeline()
stage = StageParams(
tile_size=params.tiles,
tile_size=params.unet_tile,
)
chain.stage(
BlendImg2ImgStage(),
stage,
prompt_index=0,
strength=strength,
overlap=params.overlap,
overlap=params.vae_overlap,
)

# apply upscaling and correction, before highres
Expand Down Expand Up @@ -236,7 +236,7 @@ def run_inpaint_pipeline(
full_res_inpaint_padding: float,
) -> None:
logger.debug("building inpaint pipeline")
tile_size = params.tiles
tile_size = params.unet_tile

if mask is None:
# if no mask was provided, keep the full source image
Expand Down Expand Up @@ -332,7 +332,7 @@ def run_inpaint_pipeline(
fill_color=fill_color,
mask_filter=mask_filter,
noise_source=noise_source,
overlap=params.overlap,
overlap=params.vae_overlap,
prompt_index=0,
)

Expand Down Expand Up @@ -410,7 +410,7 @@ def run_upscale_pipeline(
) -> None:
# set up the chain pipeline, no base stage for upscaling
chain = ChainPipeline()
stage = StageParams(tile_size=params.tiles)
stage = StageParams(tile_size=params.unet_tile)

# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
Expand Down
34 changes: 20 additions & 14 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,10 @@ class ImageParams:
input_negative_prompt: str
loopback: int
tiled_vae: bool
tiles: int
overlap: float
unet_tile: int
unet_overlap: float
vae_tile: int
vae_overlap: float

def __init__(
self,
Expand All @@ -224,9 +226,10 @@ def __init__(
input_negative_prompt: Optional[str] = None,
loopback: int = 0,
tiled_vae: bool = False,
tiles: int = 512,
overlap: float = 0.25,
stride: int = 64,
unet_overlap: float = 0.25,
unet_tile: int = 512,
vae_overlap: float = 0.25,
vae_tile: int = 512,
) -> None:
self.model = model
self.pipeline = pipeline
Expand All @@ -243,9 +246,10 @@ def __init__(
self.input_negative_prompt = input_negative_prompt or negative_prompt
self.loopback = loopback
self.tiled_vae = tiled_vae
self.tiles = tiles
self.overlap = overlap
self.stride = stride
self.unet_overlap = unet_overlap
self.unet_tile = unet_tile
self.vae_overlap = vae_overlap
self.vae_tile = vae_tile

def do_cfg(self):
return self.cfg > 1.0
Expand Down Expand Up @@ -312,9 +316,10 @@ def tojson(self) -> Dict[str, Optional[Param]]:
"input_negative_prompt": self.input_negative_prompt,
"loopback": self.loopback,
"tiled_vae": self.tiled_vae,
"tiles": self.tiles,
"overlap": self.overlap,
"stride": self.stride,
"unet_overlap": self.unet_overlap,
"unet_tile": self.unet_tile,
"vae_overlap": self.vae_overlap,
"vae_tile": self.vae_tile,
}

def with_args(self, **kwargs):
Expand All @@ -334,9 +339,10 @@ def with_args(self, **kwargs):
kwargs.get("input_negative_prompt", self.input_negative_prompt),
kwargs.get("loopback", self.loopback),
kwargs.get("tiled_vae", self.tiled_vae),
kwargs.get("tiles", self.tiles),
kwargs.get("overlap", self.overlap),
kwargs.get("stride", self.stride),
kwargs.get("unet_overlap", self.unet_overlap),
kwargs.get("unet_tile", self.unet_tile),
kwargs.get("vae_overlap", self.vae_overlap),
kwargs.get("vae_tile", self.vae_tile),
)


Expand Down
50 changes: 27 additions & 23 deletions api/onnx_web/server/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,32 +117,35 @@ def build_params(
get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
tiles = get_and_clamp_int(
tiled_vae = get_boolean(data, "tiled_vae", get_config_value("tiled_vae"))
unet_overlap = get_and_clamp_float(
data,
"tiles",
get_config_value("tiles"),
get_config_value("tiles", "max"),
get_config_value("tiles", "min"),
"unet_overlap",
get_config_value("unet_overlap"),
get_config_value("unet_overlap", "max"),
get_config_value("unet_overlap", "min"),
)
overlap = get_and_clamp_float(
unet_tile = get_and_clamp_int(
data,
"overlap",
get_config_value("overlap"),
get_config_value("overlap", "max"),
get_config_value("overlap", "min"),
"unet_tile",
get_config_value("unet_tile"),
get_config_value("unet_tile", "max"),
get_config_value("unet_tile", "min"),
)
stride = get_and_clamp_int(
vae_overlap = get_and_clamp_float(
data,
"stride",
get_config_value("stride"),
get_config_value("stride", "max"),
get_config_value("stride", "min"),
"vae_overlap",
get_config_value("vae_overlap"),
get_config_value("vae_overlap", "max"),
get_config_value("vae_overlap", "min"),
)
vae_tile = get_and_clamp_int(
data,
"vae_tile",
get_config_value("vae_tile"),
get_config_value("vae_tile", "max"),
get_config_value("vae_tile", "min"),
)

if stride > tiles:
logger.info("limiting stride to tile size, %s > %s", stride, tiles)
stride = tiles

seed = int(data.get("seed", -1))
if seed == -1:
Expand All @@ -163,9 +166,10 @@ def build_params(
control=control,
loopback=loopback,
tiled_vae=tiled_vae,
tiles=tiles,
overlap=overlap,
stride=stride,
unet_overlap=unet_overlap,
unet_tile=unet_tile,
vae_overlap=vae_overlap,
vae_tile=vae_tile,
)

return params
Expand Down
44 changes: 25 additions & 19 deletions api/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,6 @@
"max": 4,
"step": 1
},
"overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"pipeline": {
"default": "",
"keys": [
Expand Down Expand Up @@ -197,21 +191,9 @@
"max": 1,
"step": 0.01
},
"stride": {
"default": 128,
"min": 64,
"max": 512,
"step": 64
},
"tiledVAE": {
"tiled_vae": {
"default": false
},
"tiles": {
"default": 512,
"min": 128,
"max": 2048,
"step": 128
},
"tileOrder": {
"default": "spiral",
"keys": [
Expand All @@ -225,6 +207,18 @@
"max": 1024,
"step": 8
},
"unet_overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"unet_tile": {
"default": 512,
"min": 128,
"max": 2048,
"step": 128
},
"upscaleOrder": {
"default": "correction-first",
"keys": [
Expand All @@ -237,6 +231,18 @@
"default": "",
"keys": []
},
"vae_overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"vae_tile": {
"default": 512,
"min": 256,
"max": 1024,
"step": 128
},
"width": {
"default": 512,
"min": 128,
Expand Down
4 changes: 2 additions & 2 deletions api/scripts/test-release.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,12 @@ def __init__(
),
TestCase(
"txt2img-panorama-1024x768-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiledVAE=true",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiled_vae=true",
max_attempts=VERY_SLOW_TEST,
),
TestCase(
"img2img-panorama-1024x768-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiledVAE=true",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiled_vae=true",
source="txt2img-panorama-1024x768-muffin-0",
max_attempts=VERY_SLOW_TEST,
),
Expand Down

0 comments on commit e8d7d9a

Please sign in to comment.