diff --git a/docs/detection/annotators.md b/docs/detection/annotators.md index 938c49ff4..a0341eddf 100644 --- a/docs/detection/annotators.md +++ b/docs/detection/annotators.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Annotators diff --git a/docs/detection/core.md b/docs/detection/core.md index 475cdae1d..35225cec5 100644 --- a/docs/detection/core.md +++ b/docs/detection/core.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Detections diff --git a/docs/detection/utils/boxes.md b/docs/detection/utils/boxes.md index 63a323175..020cc8f99 100644 --- a/docs/detection/utils/boxes.md +++ b/docs/detection/utils/boxes.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Boxes Utils diff --git a/docs/detection/utils/iou_and_nms.md b/docs/detection/utils/iou_and_nms.md index 2b4e4fc33..7191656b7 100644 --- a/docs/detection/utils/iou_and_nms.md +++ b/docs/detection/utils/iou_and_nms.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # IoU and NMS Utils diff --git a/docs/detection/utils/polygons.md b/docs/detection/utils/polygons.md index cd9525345..8a7cf1e1c 100644 --- a/docs/detection/utils/polygons.md +++ b/docs/detection/utils/polygons.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Polygons Utils diff --git a/docs/how_to/benchmark_a_model.md b/docs/how_to/benchmark_a_model.md index bf23ee089..aa707fa73 100644 --- a/docs/how_to/benchmark_a_model.md +++ b/docs/how_to/benchmark_a_model.md @@ -1,6 +1,5 @@ --- comments: true -status: new ---  diff --git a/docs/how_to/track_objects.md b/docs/how_to/track_objects.md index 9bf17e865..2acad7740 100644 --- a/docs/how_to/track_objects.md +++ b/docs/how_to/track_objects.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Track Objects diff --git a/docs/keypoint/core.md b/docs/keypoint/core.md index acb13e156..e683ae873 100644 --- a/docs/keypoint/core.md +++ b/docs/keypoint/core.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Keypoint Detection diff --git a/docs/metrics/mean_average_precision.md b/docs/metrics/mean_average_precision.md index ce3e06a41..10f7a9777 100644 --- a/docs/metrics/mean_average_precision.md +++ b/docs/metrics/mean_average_precision.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Mean Average Precision diff --git a/docs/utils/image.md b/docs/utils/image.md index 17d94eac3..9d1c1895c 100644 --- a/docs/utils/image.md +++ b/docs/utils/image.md @@ -41,6 +41,12 @@ status: new :::supervision.utils.image.grayscale_image +
+ +:::supervision.utils.image.get_image_resolution_wh + diff --git a/pyproject.toml b/pyproject.toml index 81c12fb62..2996af1fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "supervision" description = "A set of easy-to-use utils that will come in handy in any Computer Vision project" license = { text = "MIT" } -version = "0.27.0rc4" +version = "0.27.0rc5" readme = "README.md" requires-python = ">=3.9" authors = [ diff --git a/supervision/__init__.py b/supervision/__init__.py index ccd272930..00820076f 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -123,6 +123,7 @@ from supervision.utils.image import ( ImageSink, crop_image, + get_image_resolution_wh, grayscale_image, letterbox_image, overlay_image, @@ -223,6 +224,7 @@ "filter_segments_by_distance", "fuzzy_match_index", "get_coco_class_index_mapping", + "get_image_resolution_wh", "get_polygon_center", "get_video_frames_generator", "grayscale_image", diff --git a/supervision/detection/tools/inference_slicer.py b/supervision/detection/tools/inference_slicer.py index aaecccb3d..103b3fa46 100644 --- a/supervision/detection/tools/inference_slicer.py +++ b/supervision/detection/tools/inference_slicer.py @@ -11,11 +11,9 @@ from supervision.detection.utils.boxes import move_boxes, move_oriented_boxes from supervision.detection.utils.iou_and_nms import OverlapFilter, OverlapMetric from supervision.detection.utils.masks import move_masks -from supervision.utils.image import crop_image -from supervision.utils.internal import ( - SupervisionWarnings, - warn_deprecated, -) +from supervision.draw.base import ImageType +from supervision.utils.image import crop_image, get_image_resolution_wh +from supervision.utils.internal import SupervisionWarnings def move_detections( @@ -53,111 +51,102 @@ def move_detections( class InferenceSlicer: """ - InferenceSlicer performs slicing-based inference for small target detection. This - method, often referred to as - [Slicing Adaptive Inference (SAHI)](https://ieeexplore.ieee.org/document/9897990), - involves dividing a larger image into smaller slices, performing inference on each - slice, and then merging the detections. + Perform tiled inference on large images by slicing them into overlapping patches. + + This class divides an input image into overlapping slices of configurable size + and overlap, runs inference on each slice through a user-provided callback, and + merges the resulting detections. The slicing process allows efficient processing + of large images with limited resources while preserving detection accuracy via + configurable overlap and post-processing of overlaps. Uses multi-threading for + parallel slice inference. Args: - slice_wh (Tuple[int, int]): Dimensions of each slice measured in pixels. The - tuple should be in the format `(width, height)`. - overlap_ratio_wh (Optional[Tuple[float, float]]): [⚠️ Deprecated: please set - to `None` and use `overlap_wh`] A tuple representing the - desired overlap ratio for width and height between consecutive slices. - Each value should be in the range [0, 1), where 0 means no overlap and - a value close to 1 means high overlap. - overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired - overlap for width and height between consecutive slices measured in pixels. - Each value should be greater than or equal to 0. Takes precedence over - `overlap_ratio_wh`. - overlap_filter (Union[OverlapFilter, str]): Strategy for - filtering or merging overlapping detections in slices. - iou_threshold (float): Intersection over Union (IoU) threshold - used when filtering by overlap. - overlap_metric (Union[OverlapMetric, str]): Metric used for matching detections - in slices. - callback (Callable): A function that performs inference on a given image - slice and returns detections. - thread_workers (int): Number of threads for parallel execution. - - Note: - The class ensures that slices do not exceed the boundaries of the original - image. As a result, the final slices in the row and column dimensions might be - smaller than the specified slice dimensions if the image's width or height is - not a multiple of the slice's width or height minus the overlap. + callback (Callable[[ImageType], Detections]): Inference function that takes + a sliced image and returns a `Detections` object. + slice_wh (int or tuple[int, int]): Size of each slice `(width, height)`. + If int, both width and height are set to this value. + overlap_wh (int or tuple[int, int]): Overlap size `(width, height)` between + slices. If int, both width and height are set to this value. + overlap_filter (OverlapFilter or str): Strategy to merge overlapping + detections (`NON_MAX_SUPPRESSION`, `NON_MAX_MERGE`, or `NONE`). + iou_threshold (float): IOU threshold used in merging overlap filtering. + overlap_metric (OverlapMetric or str): Metric to compute overlap + (`IOU` or `IOS`). + thread_workers (int): Number of threads for concurrent slice inference. + + Raises: + ValueError: If `slice_wh` or `overlap_wh` are invalid or inconsistent. + + Example: + ```python + import cv2 + import supervision as sv + from rfdetr import RFDETRMedium + + def callback(tile): + return model.predict(tile) + + slicer = sv.InferenceSlicer(callback, slice_wh=640, overlap_wh=100) + + image = cv2.imread("example.png") + detections = slicer(image) + ``` + + ```python + import supervision as sv + from PIL import Image + from ultralytics import YOLO + + def callback(tile): + results = model(tile)[0] + return sv.Detections.from_ultralytics(results) + + slicer = sv.InferenceSlicer(callback, slice_wh=640, overlap_wh=100) + + image = Image.open("example.png") + detections = slicer(image) + ``` """ def __init__( self, - callback: Callable[[np.ndarray], Detections], - slice_wh: tuple[int, int] = (320, 320), - overlap_ratio_wh: tuple[float, float] | None = (0.2, 0.2), - overlap_wh: tuple[int, int] | None = None, + callback: Callable[[ImageType], Detections], + slice_wh: int | tuple[int, int] = 640, + overlap_wh: int | tuple[int, int] = 100, overlap_filter: OverlapFilter | str = OverlapFilter.NON_MAX_SUPPRESSION, iou_threshold: float = 0.5, overlap_metric: OverlapMetric | str = OverlapMetric.IOU, thread_workers: int = 1, ): - if overlap_ratio_wh is not None: - warn_deprecated( - "`overlap_ratio_wh` in `InferenceSlicer.__init__` is deprecated and " - "will be removed in `supervision-0.27.0`. Please manually set it to " - "`None` and use `overlap_wh` instead." - ) + slice_wh_norm = self._normalize_slice_wh(slice_wh) + overlap_wh_norm = self._normalize_overlap_wh(overlap_wh) - self._validate_overlap(overlap_ratio_wh, overlap_wh) - self.overlap_ratio_wh = overlap_ratio_wh - self.overlap_wh = overlap_wh + self._validate_overlap(slice_wh=slice_wh_norm, overlap_wh=overlap_wh_norm) - self.slice_wh = slice_wh + self.slice_wh = slice_wh_norm + self.overlap_wh = overlap_wh_norm self.iou_threshold = iou_threshold self.overlap_metric = OverlapMetric.from_value(overlap_metric) self.overlap_filter = OverlapFilter.from_value(overlap_filter) self.callback = callback self.thread_workers = thread_workers - def __call__(self, image: np.ndarray) -> Detections: + def __call__(self, image: ImageType) -> Detections: """ - Performs slicing-based inference on the provided image using the specified - callback. + Perform tiled inference on the full image and return merged detections. Args: - image (np.ndarray): The input image on which inference needs to be - performed. The image should be in the format - `(height, width, channels)`. + image (ImageType): The full image to run inference on. Returns: - Detections: A collection of detections for the entire image after merging - results from all slices and applying NMS. - - Example: - ```python - import cv2 - import supervision as sv - from ultralytics import YOLO - - image = cv2.imread(SOURCE_IMAGE_PATH) - model = YOLO(...) - - def callback(image_slice: np.ndarray) -> sv.Detections: - result = model(image_slice)[0] - return sv.Detections.from_ultralytics(result) - - slicer = sv.InferenceSlicer( - callback=callback, - overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION, - ) - - detections = slicer(image) - ``` + Detections: Merged detections across all slices. """ - detections_list = [] - resolution_wh = (image.shape[1], image.shape[0]) + detections_list: list[Detections] = [] + resolution_wh = get_image_resolution_wh(image) + offsets = self._generate_offset( resolution_wh=resolution_wh, slice_wh=self.slice_wh, - overlap_ratio_wh=self.overlap_ratio_wh, overlap_wh=self.overlap_wh, ) @@ -171,129 +160,178 @@ def callback(image_slice: np.ndarray) -> sv.Detections: merged = Detections.merge(detections_list=detections_list) if self.overlap_filter == OverlapFilter.NONE: return merged - elif self.overlap_filter == OverlapFilter.NON_MAX_SUPPRESSION: + if self.overlap_filter == OverlapFilter.NON_MAX_SUPPRESSION: return merged.with_nms( - threshold=self.iou_threshold, overlap_metric=self.overlap_metric + threshold=self.iou_threshold, + overlap_metric=self.overlap_metric, ) - elif self.overlap_filter == OverlapFilter.NON_MAX_MERGE: + if self.overlap_filter == OverlapFilter.NON_MAX_MERGE: return merged.with_nmm( - threshold=self.iou_threshold, overlap_metric=self.overlap_metric + threshold=self.iou_threshold, + overlap_metric=self.overlap_metric, ) - else: - warnings.warn( - f"Invalid overlap filter strategy: {self.overlap_filter}", - category=SupervisionWarnings, - ) - return merged - def _run_callback(self, image, offset) -> Detections: + warnings.warn( + f"Invalid overlap filter strategy: {self.overlap_filter}", + category=SupervisionWarnings, + ) + return merged + + def _run_callback(self, image: ImageType, offset: np.ndarray) -> Detections: """ - Run the provided callback on a slice of an image. + Run detection callback on a sliced portion of the image and adjust coordinates. Args: - image (np.ndarray): The input image on which inference needs to run - offset (np.ndarray): An array of shape `(4,)` containing coordinates - for the slice. + image (ImageType): The full image. + offset (numpy.ndarray): Coordinates `(x_min, y_min, x_max, y_max)` defining + the slice region. Returns: - Detections: A collection of detections for the slice. + Detections: Detections adjusted to the full image coordinate system. """ - image_slice = crop_image(image=image, xyxy=offset) + image_slice: ImageType = crop_image(image=image, xyxy=offset) detections = self.callback(image_slice) - resolution_wh = (image.shape[1], image.shape[0]) + resolution_wh = get_image_resolution_wh(image) + detections = move_detections( - detections=detections, offset=offset[:2], resolution_wh=resolution_wh + detections=detections, + offset=offset[:2], + resolution_wh=resolution_wh, ) - return detections + @staticmethod + def _normalize_slice_wh( + slice_wh: int | tuple[int, int], + ) -> tuple[int, int]: + if isinstance(slice_wh, int): + if slice_wh <= 0: + raise ValueError( + f"`slice_wh` must be a positive integer. Received: {slice_wh}" + ) + return slice_wh, slice_wh + + if isinstance(slice_wh, tuple) and len(slice_wh) == 2: + width, height = slice_wh + if width <= 0 or height <= 0: + raise ValueError( + f"`slice_wh` values must be positive. Received: {slice_wh}" + ) + return width, height + + raise ValueError( + "`slice_wh` must be an int or a tuple of two positive integers " + "(slice_w, slice_h). " + f"Received: {slice_wh}" + ) + + @staticmethod + def _normalize_overlap_wh( + overlap_wh: int | tuple[int, int], + ) -> tuple[int, int]: + if isinstance(overlap_wh, int): + if overlap_wh < 0: + raise ValueError( + "`overlap_wh` must be a non negative integer. " + f"Received: {overlap_wh}" + ) + return overlap_wh, overlap_wh + + if isinstance(overlap_wh, tuple) and len(overlap_wh) == 2: + overlap_w, overlap_h = overlap_wh + if overlap_w < 0 or overlap_h < 0: + raise ValueError( + f"`overlap_wh` values must be non negative. Received: {overlap_wh}" + ) + return overlap_w, overlap_h + + raise ValueError( + "`overlap_wh` must be an int or a tuple of two non negative integers " + "(overlap_w, overlap_h). " + f"Received: {overlap_wh}" + ) + @staticmethod def _generate_offset( resolution_wh: tuple[int, int], slice_wh: tuple[int, int], - overlap_ratio_wh: tuple[float, float] | None, - overlap_wh: tuple[int, int] | None, + overlap_wh: tuple[int, int], ) -> np.ndarray: """ - Generate offset coordinates for slicing an image based on the given resolution, - slice dimensions, and overlap ratios. + Generate bounding boxes defining the coordinates of image slices with overlap. Args: - resolution_wh (Tuple[int, int]): A tuple representing the width and height - of the image to be sliced. - slice_wh (Tuple[int, int]): Dimensions of each slice measured in pixels. The - tuple should be in the format `(width, height)`. - overlap_ratio_wh (Optional[Tuple[float, float]]): A tuple representing the - desired overlap ratio for width and height between consecutive slices. - Each value should be in the range [0, 1), where 0 means no overlap and - a value close to 1 means high overlap. - overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired - overlap for width and height between consecutive slices measured in - pixels. Each value should be greater than or equal to 0. + resolution_wh (tuple[int, int]): Image resolution `(width, height)`. + slice_wh (tuple[int, int]): Size of each slice `(width, height)`. + overlap_wh (tuple[int, int]): Overlap size between slices `(width, height)`. Returns: - np.ndarray: An array of shape `(n, 4)` containing coordinates for each - slice in the format `[xmin, ymin, xmax, ymax]`. - - Note: - The function ensures that slices do not exceed the boundaries of the - original image. As a result, the final slices in the row and column - dimensions might be smaller than the specified slice dimensions if the - image's width or height is not a multiple of the slice's width or - height minus the overlap. + numpy.ndarray: Array of shape `(num_slices, 4)` with each row as + `(x_min, y_min, x_max, y_max)` coordinates for a slice. """ slice_width, slice_height = slice_wh image_width, image_height = resolution_wh - overlap_width = ( - overlap_wh[0] - if overlap_wh is not None - else int(overlap_ratio_wh[0] * slice_width) + overlap_width, overlap_height = overlap_wh + + stride_x = slice_width - overlap_width + stride_y = slice_height - overlap_height + + def _compute_axis_starts( + image_size: int, + slice_size: int, + stride: int, + ) -> list[int]: + if image_size <= slice_size: + return [0] + + if stride == slice_size: + return np.arange(0, image_size, stride).tolist() + + last_start = image_size - slice_size + starts = np.arange(0, last_start, stride).tolist() + if not starts or starts[-1] != last_start: + starts.append(last_start) + return starts + + x_starts = _compute_axis_starts( + image_size=image_width, + slice_size=slice_width, + stride=stride_x, ) - overlap_height = ( - overlap_wh[1] - if overlap_wh is not None - else int(overlap_ratio_wh[1] * slice_height) + y_starts = _compute_axis_starts( + image_size=image_height, + slice_size=slice_height, + stride=stride_y, ) - width_stride = slice_width - overlap_width - height_stride = slice_height - overlap_height + x_min, y_min = np.meshgrid(x_starts, y_starts) + x_max = np.clip(x_min + slice_width, 0, image_width) + y_max = np.clip(y_min + slice_height, 0, image_height) - ws = np.arange(0, image_width, width_stride) - hs = np.arange(0, image_height, height_stride) - - xmin, ymin = np.meshgrid(ws, hs) - xmax = np.clip(xmin + slice_width, 0, image_width) - ymax = np.clip(ymin + slice_height, 0, image_height) - - offsets = np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4) + offsets = np.stack( + [x_min, y_min, x_max, y_max], + axis=-1, + ).reshape(-1, 4) return offsets @staticmethod def _validate_overlap( - overlap_ratio_wh: tuple[float, float] | None, - overlap_wh: tuple[int, int] | None, + slice_wh: tuple[int, int], + overlap_wh: tuple[int, int], ) -> None: - if overlap_ratio_wh is not None and overlap_wh is not None: + overlap_w, overlap_h = overlap_wh + slice_w, slice_h = slice_wh + + if overlap_w < 0 or overlap_h < 0: raise ValueError( - "Both `overlap_ratio_wh` and `overlap_wh` cannot be provided. " - "Please provide only one of them." + "Overlap values must be greater than or equal to 0. " + f"Received: {overlap_wh}" ) - if overlap_ratio_wh is None and overlap_wh is None: + + if overlap_w >= slice_w or overlap_h >= slice_h: raise ValueError( - "Either `overlap_ratio_wh` or `overlap_wh` must be provided. " - "Please provide one of them." + "`overlap_wh` must be smaller than `slice_wh` in both dimensions " + f"to keep a positive stride. Received overlap_wh={overlap_wh}, " + f"slice_wh={slice_wh}." ) - - if overlap_ratio_wh is not None: - if not (0 <= overlap_ratio_wh[0] < 1 and 0 <= overlap_ratio_wh[1] < 1): - raise ValueError( - "Overlap ratios must be in the range [0, 1). " - f"Received: {overlap_ratio_wh}" - ) - if overlap_wh is not None: - if not (overlap_wh[0] >= 0 and overlap_wh[1] >= 0): - raise ValueError( - "Overlap values must be greater than or equal to 0. " - f"Received: {overlap_wh}" - ) diff --git a/supervision/detection/utils/__init__.py b/supervision/detection/utils/__init__.py index c4b0f3870..e69de29bb 100644 --- a/supervision/detection/utils/__init__.py +++ b/supervision/detection/utils/__init__.py @@ -1,3 +0,0 @@ -from supervision.detection.utils.iou_and_nms import box_iou_batch - -__all__ = ["box_iou_batch"] diff --git a/supervision/utils/image.py b/supervision/utils/image.py index e8931f219..4d4348d69 100644 --- a/supervision/utils/image.py +++ b/supervision/utils/image.py @@ -6,6 +6,7 @@ import cv2 import numpy as np import numpy.typing as npt +from PIL import Image from supervision.annotators.base import ImageType from supervision.draw.color import Color, unify_to_bgr @@ -15,7 +16,6 @@ from supervision.utils.internal import deprecated -@ensure_cv2_image_for_standalone_function def crop_image( image: ImageType, xyxy: npt.NDArray[int] | list[int] | tuple[int, int, int, int], @@ -65,9 +65,19 @@ def crop_image( """ # noqa E501 // docs if isinstance(xyxy, (list, tuple)): xyxy = np.array(xyxy) + xyxy = np.round(xyxy).astype(int) x_min, y_min, x_max, y_max = xyxy.flatten() - return image[y_min:y_max, x_min:x_max] + + if isinstance(image, np.ndarray): + return image[y_min:y_max, x_min:x_max] + + if isinstance(image, Image.Image): + return image.crop((x_min, y_min, x_max, y_max)) + + raise TypeError( + f"`image` must be a numpy.ndarray or PIL.Image.Image. Received {type(image)}" + ) @ensure_cv2_image_for_standalone_function @@ -460,6 +470,62 @@ def grayscale_image(image: ImageType) -> ImageType: return cv2.cvtColor(grayscaled, cv2.COLOR_GRAY2BGR) +def get_image_resolution_wh(image: ImageType) -> tuple[int, int]: + """ + Get image width and height as a tuple `(width, height)` for various image formats. + + Supports both `numpy.ndarray` images (with shape `(H, W, ...)`) and + `PIL.Image.Image` inputs. + + Args: + image (`numpy.ndarray` or `PIL.Image.Image`): Input image. + + Returns: + (`tuple[int, int]`): Image resolution as `(width, height)`. + + Raises: + ValueError: If a `numpy.ndarray` image has fewer than 2 dimensions. + TypeError: If `image` is not a supported type (`numpy.ndarray` or + `PIL.Image.Image`). + + Examples: + ```python + import cv2 + import supervision as sv + + image = cv2.imread("example.png") + sv.get_image_resolution_wh(image) + # (1920, 1080) + ``` + + ```python + from PIL import Image + import supervision as sv + + image = Image.open("example.png") + sv.get_image_resolution_wh(image) + # (1920, 1080) + ``` + """ + if isinstance(image, np.ndarray): + if image.ndim < 2: + raise ValueError( + "NumPy image must have at least 2 dimensions (H, W, ...). " + f"Received shape: {image.shape}" + ) + height, width = image.shape[:2] + return int(width), int(height) + + if isinstance(image, Image.Image): + width, height = image.size + return int(width), int(height) + + raise TypeError( + "`image` must be a numpy.ndarray or PIL.Image.Image. " + f"Received type: {type(image)}" + ) + + class ImageSink: def __init__( self, diff --git a/test/detection/tools/test_inference_slicer.py b/test/detection/tools/test_inference_slicer.py index 2185b77f2..7c313841f 100644 --- a/test/detection/tools/test_inference_slicer.py +++ b/test/detection/tools/test_inference_slicer.py @@ -1,13 +1,10 @@ from __future__ import annotations -from contextlib import ExitStack as DoesNotRaise - import numpy as np import pytest from supervision.detection.core import Detections from supervision.detection.tools.inference_slicer import InferenceSlicer -from supervision.detection.utils.iou_and_nms import OverlapFilter @pytest.fixture @@ -20,54 +17,10 @@ def callback(_: np.ndarray) -> Detections: return callback -@pytest.mark.parametrize( - "slice_wh, overlap_ratio_wh, overlap_wh, expected_overlap, exception", - [ - # Valid case: explicit overlap_wh in pixels - ((128, 128), None, (26, 26), (26, 26), DoesNotRaise()), - # Valid case: overlap_wh in pixels - ((128, 128), None, (20, 20), (20, 20), DoesNotRaise()), - # Invalid case: negative overlap_wh, should raise ValueError - ((128, 128), None, (-10, 20), None, pytest.raises(ValueError)), - # Invalid case: no overlaps defined - ((128, 128), None, None, None, pytest.raises(ValueError)), - # Valid case: overlap_wh = 50 pixels - ((256, 256), None, (50, 50), (50, 50), DoesNotRaise()), - # Valid case: overlap_wh = 60 pixels - ((200, 200), None, (60, 60), (60, 60), DoesNotRaise()), - # Valid case: small overlap_wh values - ((100, 100), None, (0.1, 0.1), (0.1, 0.1), DoesNotRaise()), - # Invalid case: negative overlap_wh values - ((128, 128), None, (-10, -10), None, pytest.raises(ValueError)), - # Invalid case: overlap_wh greater than slice size - ((128, 128), None, (150, 150), (150, 150), DoesNotRaise()), - # Valid case: zero overlap - ((128, 128), None, (0, 0), (0, 0), DoesNotRaise()), - ], -) -def test_inference_slicer_overlap( - mock_callback, - slice_wh: tuple[int, int], - overlap_ratio_wh: tuple[float, float] | None, - overlap_wh: tuple[int, int] | None, - expected_overlap: tuple[int, int] | None, - exception: Exception, -) -> None: - with exception: - slicer = InferenceSlicer( - callback=mock_callback, - slice_wh=slice_wh, - overlap_ratio_wh=overlap_ratio_wh, - overlap_wh=overlap_wh, - overlap_filter=OverlapFilter.NONE, - ) - assert slicer.overlap_wh == expected_overlap - - @pytest.mark.parametrize( "resolution_wh, slice_wh, overlap_wh, expected_offsets", [ - # Case 1: No overlap, exact slices fit within image dimensions + # Case 1: Square image, square slices, no overlap ( (256, 256), (128, 128), @@ -81,7 +34,7 @@ def test_inference_slicer_overlap( ] ), ), - # Case 2: Overlap of 64 pixels in both directions + # Case 2: Square image, square slices, non-zero overlap ( (256, 256), (128, 128), @@ -91,96 +44,154 @@ def test_inference_slicer_overlap( [0, 0, 128, 128], [64, 0, 192, 128], [128, 0, 256, 128], - [192, 0, 256, 128], [0, 64, 128, 192], [64, 64, 192, 192], [128, 64, 256, 192], - [192, 64, 256, 192], [0, 128, 128, 256], [64, 128, 192, 256], [128, 128, 256, 256], - [192, 128, 256, 256], - [0, 192, 128, 256], - [64, 192, 192, 256], - [128, 192, 256, 256], - [192, 192, 256, 256], ] ), ), - # Case 3: Image not perfectly divisible by slice size (no overlap) + # Case 3: Rectangle image (horizontal), square slices, no overlap ( - (300, 300), - (128, 128), + (192, 128), + (64, 64), (0, 0), np.array( [ - [0, 0, 128, 128], - [128, 0, 256, 128], - [256, 0, 300, 128], - [0, 128, 128, 256], - [128, 128, 256, 256], - [256, 128, 300, 256], - [0, 256, 128, 300], - [128, 256, 256, 300], - [256, 256, 300, 300], + [0, 0, 64, 64], + [64, 0, 128, 64], + [128, 0, 192, 64], + [0, 64, 64, 128], + [64, 64, 128, 128], + [128, 64, 192, 128], ] ), ), - # Case 4: Overlap of 32 pixels, image not perfectly divisible by slice size + # Case 4: Rectangle image (horizontal), square slices, non-zero overlap ( - (300, 300), - (128, 128), + (192, 128), + (64, 64), (32, 32), np.array( [ - [0, 0, 128, 128], - [96, 0, 224, 128], - [192, 0, 300, 128], - [288, 0, 300, 128], - [0, 96, 128, 224], - [96, 96, 224, 224], - [192, 96, 300, 224], - [288, 96, 300, 224], - [0, 192, 128, 300], - [96, 192, 224, 300], - [192, 192, 300, 300], - [288, 192, 300, 300], - [0, 288, 128, 300], - [96, 288, 224, 300], - [192, 288, 300, 300], - [288, 288, 300, 300], + [0, 0, 64, 64], + [32, 0, 96, 64], + [64, 0, 128, 64], + [96, 0, 160, 64], + [128, 0, 192, 64], + [0, 32, 64, 96], + [32, 32, 96, 96], + [64, 32, 128, 96], + [96, 32, 160, 96], + [128, 32, 192, 96], + [0, 64, 64, 128], + [32, 64, 96, 128], + [64, 64, 128, 128], + [96, 64, 160, 128], + [128, 64, 192, 128], ] ), ), - # Case 5: Image smaller than slice size (no overlap) + # Case 5: Rectangle image (vertical), square slices, no overlap ( - (100, 100), - (128, 128), + (128, 192), + (64, 64), (0, 0), np.array( [ - [0, 0, 100, 100], + [0, 0, 64, 64], + [64, 0, 128, 64], + [0, 64, 64, 128], + [64, 64, 128, 128], + [0, 128, 64, 192], + [64, 128, 128, 192], + ] + ), + ), + # Case 6: Rectangle image (vertical), square slices, non-zero overlap + ( + (128, 192), + (64, 64), + (32, 32), + np.array( + [ + [0, 0, 64, 64], + [32, 0, 96, 64], + [64, 0, 128, 64], + [0, 32, 64, 96], + [32, 32, 96, 96], + [64, 32, 128, 96], + [0, 64, 64, 128], + [32, 64, 96, 128], + [64, 64, 128, 128], + [0, 96, 64, 160], + [32, 96, 96, 160], + [64, 96, 128, 160], + [0, 128, 64, 192], + [32, 128, 96, 192], + [64, 128, 128, 192], + ] + ), + ), + # Case 7: Square image, rectangular slices (horizontal), no overlap + ( + (160, 160), + (80, 40), + (0, 0), + np.array( + [ + [0, 0, 80, 40], + [80, 0, 160, 40], + [0, 40, 80, 80], + [80, 40, 160, 80], + [0, 80, 80, 120], + [80, 80, 160, 120], + [0, 120, 80, 160], + [80, 120, 160, 160], + ] + ), + ), + # Case 8: Square image, rectangular slices (vertical), non-zero overlap + ( + (160, 160), + (40, 80), + (10, 20), + np.array( + [ + [0, 0, 40, 80], + [30, 0, 70, 80], + [60, 0, 100, 80], + [90, 0, 130, 80], + [120, 0, 160, 80], + [0, 60, 40, 140], + [30, 60, 70, 140], + [60, 60, 100, 140], + [90, 60, 130, 140], + [120, 60, 160, 140], + [0, 80, 40, 160], + [30, 80, 70, 160], + [60, 80, 100, 160], + [90, 80, 130, 160], + [120, 80, 160, 160], ] ), ), - # Case 6: Overlap_wh is greater than the slice size - ((256, 256), (128, 128), (150, 150), np.array([]).reshape(0, 4)), ], ) def test_generate_offset( resolution_wh: tuple[int, int], slice_wh: tuple[int, int], - overlap_wh: tuple[int, int] | None, + overlap_wh: tuple[int, int], expected_offsets: np.ndarray, ) -> None: offsets = InferenceSlicer._generate_offset( resolution_wh=resolution_wh, slice_wh=slice_wh, - overlap_ratio_wh=None, overlap_wh=overlap_wh, ) - # Verify that the generated offsets match the expected offsets assert np.array_equal(offsets, expected_offsets), ( f"Expected {expected_offsets}, got {offsets}" ) diff --git a/test/utils/test_image.py b/test/utils/test_image.py index 6ae9567b9..688f938b7 100644 --- a/test/utils/test_image.py +++ b/test/utils/test_image.py @@ -1,7 +1,13 @@ import numpy as np +import pytest from PIL import Image, ImageChops -from supervision.utils.image import letterbox_image, resize_image +from supervision.utils.image import ( + crop_image, + get_image_resolution_wh, + letterbox_image, + resize_image, +) def test_resize_image_for_opencv_image() -> None: @@ -94,3 +100,61 @@ def test_letterbox_image_for_pillow_image() -> None: assert difference.getbbox() is None, ( "Expected padding to be added top and bottom with padding added top and bottom" ) + + +@pytest.mark.parametrize( + "image, xyxy, expected_size", + [ + # NumPy RGB + ( + np.zeros((4, 6, 3), dtype=np.uint8), + (2, 1, 5, 3), + (3, 2), # width = 5-2, height = 3-1 + ), + # NumPy grayscale + ( + np.zeros((5, 5), dtype=np.uint8), + (1, 1, 4, 4), + (3, 3), + ), + # Pillow RGB + ( + Image.new("RGB", (6, 4), color=0), + (2, 1, 5, 3), + (3, 2), + ), + # Pillow grayscale + ( + Image.new("L", (5, 5), color=0), + (1, 1, 4, 4), + (3, 3), + ), + ], +) +def test_crop_image(image, xyxy, expected_size): + cropped = crop_image(image=image, xyxy=xyxy) + if isinstance(image, np.ndarray): + assert isinstance(cropped, np.ndarray) + assert cropped.shape[1] == expected_size[0] # width + assert cropped.shape[0] == expected_size[1] # height + else: + assert isinstance(cropped, Image.Image) + assert cropped.size == expected_size + + +@pytest.mark.parametrize( + "image, expected", + [ + # NumPy RGB + (np.zeros((4, 6, 3), dtype=np.uint8), (6, 4)), + # NumPy grayscale + (np.zeros((10, 20), dtype=np.uint8), (20, 10)), + # Pillow RGB + (Image.new("RGB", (6, 4), color=0), (6, 4)), + # Pillow grayscale + (Image.new("L", (20, 10), color=0), (20, 10)), + ], +) +def test_get_image_resolution_wh(image, expected): + resolution = get_image_resolution_wh(image) + assert resolution == expected