Skip to content

Commit

Permalink
fix(api): restore separate upscale and correction stages
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 18, 2023
1 parent 118695d commit f534fbb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
10 changes: 5 additions & 5 deletions api/onnx_web/diffusion/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,29 +97,29 @@ def optimize_pipeline(
try:
pipe.enable_attention_slicing()
except Exception as e:
logger.warning("error enabling attention slicing: %s", e)
logger.warning("error while enabling attention slicing: %s", e)

if "vae-slicing" in server.optimizations:
logger.debug("enabling VAE slicing on SD pipeline")
try:
pipe.enable_vae_slicing()
except Exception as e:
logger.warning("error enabling VAE slicing: %s", e)
logger.warning("error while enabling VAE slicing: %s", e)

if "sequential-cpu-offload" in server.optimizations:
logger.debug("enabling sequential CPU offload on SD pipeline")
try:
pipe.enable_sequential_cpu_offload()
except Exception as e:
logger.warning("error enabling sequential CPU offload: %s", e)
logger.warning("error while enabling sequential CPU offload: %s", e)

elif "model-cpu-offload" in server.optimizations:
# TODO: check for accelerate
logger.debug("enabling model CPU offload on SD pipeline")
try:
pipe.enable_model_cpu_offload()
except Exception as e:
logger.warning("error enabling model CPU offload: %s", e)
logger.warning("error while enabling model CPU offload: %s", e)


if "memory-efficient-attention" in server.optimizations:
Expand All @@ -128,7 +128,7 @@ def optimize_pipeline(
try:
pipe.enable_xformers_memory_efficient_attention()
except Exception as e:
logger.warning("error enabling memory efficient attention: %s", e)
logger.warning("error while enabling memory efficient attention: %s", e)


def load_pipeline(
Expand Down
14 changes: 7 additions & 7 deletions api/onnx_web/server/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def run_upscale_correction(

chain = ChainPipeline()

upscale_stage = None
if upscale.scale > 1:
if "esrgan" in upscale.upscale_model:
esrgan_params = StageParams(
Expand All @@ -42,23 +43,22 @@ def run_upscale_correction(
upscale_stage = (upscale_resrgan, esrgan_params, None)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (upscale_stable_diffusion, sd_stage, None)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (upscale_stable_diffusion, sd_params, None)
else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
upscale_stage = None

correct_stage = None
if upscale.faces:
face_stage = StageParams(
face_params = StageParams(
tile_size=stage.tile_size, outscale=upscale.face_outscale
)
if "codeformer" in upscale.correction_model:
correct_stage = (correct_codeformer, face_stage, None)
correct_stage = (correct_codeformer, face_params, None)
elif "gfpgan" in upscale.correction_model:
correct_stage = (correct_gfpgan, face_stage, None)
correct_stage = (correct_gfpgan, face_params, None)
else:
logger.warn("unknown correction model: %s", upscale.correction_model)
correct_stage = None

if upscale.upscale_order == "correction-both":
chain.append(correct_stage)
Expand Down

0 comments on commit f534fbb

Please sign in to comment.