Skip to content

Commit

Permalink
feat(api): enable 1x upscaling models
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 30, 2023
1 parent 11e643b commit 7abe6dc
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions api/onnx_web/chain/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,31 +68,30 @@ def stage_upscale_correction(
"upscale": upscale,
}
upscale_stage: Optional[PipelineStage] = None
if upscale.scale > 1:
if "bsrgan" in upscale.upscale_model:
bsrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
elif "esrgan" in upscale.upscale_model:
esrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
elif "swinir" in upscale.upscale_model:
swinir_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
else:
logger.warning("unknown upscaling model: %s", upscale.upscale_model)
if "bsrgan" in upscale.upscale_model:
bsrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
elif "esrgan" in upscale.upscale_model:
esrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
elif "swinir" in upscale.upscale_model:
swinir_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
else:
logger.warning("unknown upscaling model: %s", upscale.upscale_model)

correct_stage: Optional[PipelineStage] = None
if upscale.faces:
Expand Down

0 comments on commit 7abe6dc

Please sign in to comment.