Skip to content

Commit

Permalink
fix(api): pass correct outscale to highres stages
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Nov 25, 2023
1 parent b1328fd commit 6ecdae4
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/chain/highres.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def stage_highres(

chain.stage(
BlendImg2ImgStage(),
stage,
stage.with_args(outscale=1),
overlap=params.vae_overlap,
prompt_index=prompt_index + i,
strength=highres.strength,
Expand Down
1 change: 0 additions & 1 deletion api/onnx_web/chain/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def blend_tiles(
value = np.zeros(scaled_size)

for left, top, tile_image in tiles:
# TODO: histogram equalization
equalized = np.array(tile_image).astype(np.float32)
mask = np.ones_like(equalized[:, :, 0])

Expand Down
17 changes: 8 additions & 9 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,29 +64,28 @@ def run_txt2img_pipeline(

# apply upscaling and correction, before highres
highres_size = params.unet_tile
stage = StageParams(tile_size=highres_size)

if params.is_panorama():
if server.has_feature("panorama-highres"):
# run the whole highres pass with one panorama call
highres_size = tile_size * highres.scale

chain.stage(
BlendDenoiseStage(),
stage,
StageParams(tile_size=highres_size),
)

if server.has_feature("panorama-highres"):
highres_size = tile_size * highres.scale

first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
stage_upscale_correction(
stage,
StageParams(outscale=first_upscale.outscale, tile_size=highres_size),
params,
chain=chain,
upscale=first_upscale,
)

# apply highres
stage_highres(
stage,
StageParams(outscale=highres.scale, tile_size=highres_size),
params,
highres,
upscale,
Expand All @@ -96,7 +95,7 @@ def run_txt2img_pipeline(

# apply upscaling and correction, after highres
stage_upscale_correction(
stage,
StageParams(outscale=after_upscale.outscale, tile_size=highres_size),
params,
chain=chain,
upscale=after_upscale,
Expand Down
11 changes: 11 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,17 @@ def __init__(
self.tile_order = tile_order
self.tile_size = tile_size

def with_args(
self,
**kwargs,
):
return StageParams(
name=kwargs.get("name", self.name),
outscale=kwargs.get("outscale", self.outscale),
tile_order=kwargs.get("tile_order", self.tile_order),
tile_size=kwargs.get("tile_size", self.tile_size),
)


class UpscaleParams:
def __init__(
Expand Down
3 changes: 2 additions & 1 deletion api/tests/test_diffusers/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ def test_highres(self):
3.0,
1,
1,
unet_tile=256,
),
Size(256, 256),
["test-txt2img-highres.png"],
UpscaleParams("test"),
UpscaleParams("test", scale=2, outscale=2),
HighresParams(True, 2, 0, 0),
)

Expand Down

0 comments on commit 6ecdae4

Please sign in to comment.