Skip to content

Commit

Permalink
fix(api): tile stages based on input image or size param
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 2, 2023
1 parent b8aef2c commit c9a1ace
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
15 changes: 9 additions & 6 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .tile import process_tile_order
from .tile import needs_tile, process_tile_order

logger = getLogger(__name__)

Expand Down Expand Up @@ -149,13 +149,16 @@ def __call__(
"running stage %s without source image, %s", name, kwargs.keys()
)

if image is not None and (
image.width > stage_params.tile_size
or image.height > stage_params.tile_size
if needs_tile(
stage_pipe.max_tile,
stage_params.tile_size,
size=kwargs.get("size", None),
source=image,
):
tile = min(stage_pipe.max_tile, stage_params.tile_size)
logger.info(
"image larger than tile size of %s, tiling stage",
stage_params.tile_size,
tile,
)

def stage_tile(tile: Image.Image, _dims) -> Image.Image:
Expand All @@ -177,7 +180,7 @@ def stage_tile(tile: Image.Image, _dims) -> Image.Image:
image = process_tile_order(
stage_params.tile_order,
image,
stage_params.tile_size,
tile,
stage_params.outscale,
[stage_tile],
**kwargs,
Expand Down
21 changes: 19 additions & 2 deletions api/onnx_web/chain/tile.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from logging import getLogger
from math import ceil
from typing import List, Protocol, Tuple
from typing import List, Optional, Protocol, Tuple

import numpy as np
from PIL import Image

from ..params import TileOrder
from ..params import Size, TileOrder

# from skimage.exposure import match_histograms

Expand Down Expand Up @@ -37,6 +37,23 @@ def complete_tile(
return source


def needs_tile(
max_tile: int,
stage_tile: int,
size: Optional[Size] = None,
source: Optional[Image.Image] = None,
) -> bool:
tile = min(max_tile, stage_tile)

if source is not None:
return source.width > tile or source.height > tile

if size is not None:
return size.width > tile or size.height > tile

return False


def get_tile_grads(
left: int,
top: int,
Expand Down

0 comments on commit c9a1ace

Please sign in to comment.