Skip to content

Commit

Permalink
fix(api): detect all mask keys, immediately bubble up cancellation er…
Browse files Browse the repository at this point in the history
…rors
  • Loading branch information
ssube committed Dec 5, 2023
1 parent b29837d commit 95a62b1
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions api/onnx_web/chain/pipeline.py
Expand Up @@ -5,7 +5,7 @@

from PIL import Image

from ..errors import RetryException
from ..errors import CancelledException, RetryException
from ..output import save_image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
Expand Down Expand Up @@ -146,7 +146,7 @@ def __call__(
kwargs.pop("params")

# the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = "mask" in stage_kwargs or any(
must_tile = has_mask(stage_kwargs) or any(
[
needs_tile(
stage_pipe.max_tile,
Expand Down Expand Up @@ -192,6 +192,10 @@ def stage_tile(
save_image(server, f"last-tile-{j}.png", image)

return tile_result
except CancelledException as err:
worker.retries = 0
logger.exception("job was cancelled while tiling")
raise err
except Exception:
worker.retries = worker.retries - 1
logger.exception(
Expand Down Expand Up @@ -234,6 +238,10 @@ def stage_tile(
# does not like, so it throws
stage_sources = stage_result
break
except CancelledException as err:
worker.retries = 0
logger.exception("job was cancelled during stage")
raise err
except Exception:
worker.retries = worker.retries - 1
logger.exception(
Expand Down Expand Up @@ -264,3 +272,9 @@ def stage_tile(
len(stage_sources),
)
return stage_sources


MASK_KEYS = ["mask", "stage_mask", "tile_mask"]

def has_mask(args: List[str]) -> bool:
return any([key in args for key in MASK_KEYS])

0 comments on commit 95a62b1

Please sign in to comment.