diff --git a/docs/detection/utils/masks.md b/docs/detection/utils/masks.md index 9e53a6baa..99097bef6 100644 --- a/docs/detection/utils/masks.md +++ b/docs/detection/utils/masks.md @@ -22,3 +22,9 @@ status: new :::supervision.detection.utils.masks.contains_multiple_segments + +
def filter_segments_by_distance(
    mask: npt.NDArray[np.bool_],
    absolute_distance: float | None = 100.0,
    relative_distance: float | None = None,
    connectivity: int = 8,
    mode: Literal["edge", "centroid"] = "edge",
) -> npt.NDArray[np.bool_]:
    """
    Keep the largest connected component and any other components within a distance
    threshold.

    Distance can be absolute in pixels or relative to the image diagonal.
    Distances are Euclidean and measured between pixel centres, so two components
    separated by a gap of `g` empty pixels along an axis are `g + 1` pixels apart
    in "edge" mode.

    Args:
        mask: Boolean mask HxW.
        absolute_distance: Max allowed distance in pixels to the main component.
            Ignored if `relative_distance` is provided.
        relative_distance: Fraction of the diagonal. If set, threshold = fraction * sqrt(H^2 + W^2).
        connectivity: Defines which neighboring pixels are considered connected.
            - 4-connectedness: Only orthogonal neighbors.
              ```
              [ ][X][ ]
              [X][O][X]
              [ ][X][ ]
              ```
            - 8-connectedness: Includes diagonal neighbors.
              ```
              [X][X][X]
              [X][O][X]
              [X][X][X]
              ```
            Default is 8.
        mode: Defines how distance between components is measured.
            - "edge": Uses the distance between the nearest pixels of the two
              components (via an exact Euclidean distance transform).
            - "centroid": Uses distance between component centroids.

    Returns:
        Boolean mask after filtering. An empty mask is returned as a copy.

    Raises:
        TypeError: If `mask` is not a boolean array.
        ValueError: If `mask` is not two-dimensional, if both `absolute_distance`
            and `relative_distance` are `None`, or if `mode` is not "edge" or
            "centroid".

    Examples:
        ```python
        import numpy as np
        import supervision as sv

        mask = np.array([
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ], dtype=bool)

        sv.filter_segments_by_distance(
            mask,
            absolute_distance=3,
            mode="edge",
            connectivity=8
        ).astype(int)

        # np.array([
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        # ])

        # The nearby 2x2 block at columns 6-7 is kept because its closest pixel
        # (column 6) is 3 pixels from the main component (column 3). The distant
        # block at columns 9-10 is more than 3 pixels away and is removed.
        ```
    """  # noqa E501 // docs
    if mask.dtype != bool:
        raise TypeError("mask must be boolean")
    if mask.ndim != 2:
        raise ValueError("mask must be a 2D array of shape (H, W)")

    if not np.any(mask):
        return mask.copy()

    height, width = mask.shape
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        mask.astype(np.uint8), connectivity=connectivity
    )

    # Label 0 is the background; a non-empty mask always yields at least one
    # foreground label, so this guard is purely defensive.
    if num_labels <= 1:
        return mask.copy()

    # The main component is the foreground label with the largest pixel area.
    main_label = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))

    if relative_distance is not None:
        threshold = float(relative_distance) * float(np.hypot(height, width))
    elif absolute_distance is not None:
        threshold = float(absolute_distance)
    else:
        raise ValueError(
            "either absolute_distance or relative_distance must be provided"
        )

    # keep_labels[i] tells whether component i survives; background stays False.
    keep_labels = np.zeros(num_labels, dtype=bool)

    if mode == "centroid":
        offsets = centroids[1:] - centroids[main_label]
        keep_labels[1:] = np.sqrt(np.sum(offsets**2, axis=1)) <= threshold
    elif mode == "edge":
        # Every pixel outside the main component receives its exact Euclidean
        # distance to the nearest main-component pixel. The precise mask avoids
        # the 3x3 chamfer approximation, which underestimates axis-aligned
        # distances by roughly 4.5%.
        outside_main = (labels != main_label).astype(np.uint8)
        distance_to_main = cv2.distanceTransform(
            outside_main, cv2.DIST_L2, cv2.DIST_MASK_PRECISE
        )
        # Single pass over the image: per-label minimum distance to the main
        # component (float64 so the comparison with `threshold` is not rounded).
        nearest = np.full(num_labels, np.inf, dtype=np.float64)
        np.minimum.at(nearest, labels.ravel(), distance_to_main.ravel())
        keep_labels[1:] = nearest[1:] <= threshold
    else:
        raise ValueError("mode must be 'edge' or 'centroid'")

    # The main component is always kept, regardless of the threshold.
    keep_labels[main_label] = True

    return keep_labels[labels]
# Distances in the case comments are Euclidean distances between pixel centres:
# a gap of `g` empty columns between two blobs means their nearest pixels are
# `g + 1` pixels apart.
@pytest.mark.parametrize(
    "mask, connectivity, mode, absolute_distance, relative_distance, expected_result, exception",  # noqa: E501
    [
        # single component, unchanged
        (
            np.array(
                [
                    [0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 0],
                    [0, 1, 1, 1, 0, 0],
                    [0, 1, 1, 1, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            8,
            "edge",
            2.0,
            None,
            np.array(
                [
                    [0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 0],
                    [0, 1, 1, 1, 0, 0],
                    [0, 1, 1, 1, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            DoesNotRaise(),
        ),
        # two components, nearest pixels 2 px apart (one empty column), kept
        # with abs=2
        (
            np.array(
                [
                    [0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 1],
                    [0, 1, 1, 1, 0, 1],
                    [0, 1, 1, 1, 0, 1],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            8,
            "edge",
            2.0,
            None,
            np.array(
                [
                    [0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 1],
                    [0, 1, 1, 1, 0, 1],
                    [0, 1, 1, 1, 0, 1],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            DoesNotRaise(),
        ),
        # centroid mode, far centroids, dropped with small relative threshold
        (
            np.array(
                [
                    [1, 1, 1, 0, 0, 0],
                    [1, 1, 1, 0, 0, 0],
                    [1, 1, 1, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1],
                ],
                dtype=bool,
            ),
            8,
            "centroid",
            None,
            0.3,  # diagonal ~8.49, threshold ~2.55, centroid gap ~4.61
            np.array(
                [
                    [1, 1, 1, 0, 0, 0],
                    [1, 1, 1, 0, 0, 0],
                    [1, 1, 1, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            DoesNotRaise(),
        ),
        # centroid mode, larger relative threshold, kept
        (
            np.array(
                [
                    [1, 1, 1, 0, 0, 0],
                    [1, 1, 1, 0, 0, 0],
                    [1, 1, 1, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1],
                ],
                dtype=bool,
            ),
            8,
            "centroid",
            None,
            0.6,  # diagonal ~8.49, threshold ~5.09, centroid gap ~4.61
            np.array(
                [
                    [1, 1, 1, 0, 0, 0],
                    [1, 1, 1, 0, 0, 0],
                    [1, 1, 1, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1],
                ],
                dtype=bool,
            ),
            DoesNotRaise(),
        ),
        # empty mask
        (
            np.zeros((4, 4), dtype=bool),
            4,
            "edge",
            2.0,
            None,
            np.zeros((4, 4), dtype=bool),
            DoesNotRaise(),
        ),
        # full mask
        (
            np.ones((4, 4), dtype=bool),
            8,
            "centroid",
            None,
            0.2,
            np.ones((4, 4), dtype=bool),
            DoesNotRaise(),
        ),
        # two 3x3 blocks, nearest pixels 2 px apart (one empty column), kept
        # with abs=2
        (
            np.array(
                [
                    [0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 1, 1, 1],
                    [0, 1, 1, 1, 0, 1, 1, 1],
                    [0, 1, 1, 1, 0, 1, 1, 1],
                    [0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            8,
            "edge",
            2.0,
            None,
            np.array(
                [
                    [0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 1, 1, 1],
                    [0, 1, 1, 1, 0, 1, 1, 1],
                    [0, 1, 1, 1, 0, 1, 1, 1],
                    [0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            DoesNotRaise(),
        ),
        # two components, nearest pixels 4 px apart (three empty columns),
        # dropped with abs=2
        (
            np.array(
                [
                    [0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 0, 0, 1, 1],
                    [0, 1, 1, 1, 0, 0, 0, 1, 1],
                    [0, 1, 1, 1, 0, 0, 0, 1, 1],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            8,
            "edge",
            2.0,  # threshold below the 4 px separation, so the right blob goes
            None,
            np.array(
                [
                    [0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 0, 0, 0, 0],
                    [0, 1, 1, 1, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            DoesNotRaise(),
        ),
        # non-boolean mask is rejected
        (
            np.zeros((4, 4), dtype=np.uint8),
            8,
            "edge",
            2.0,
            None,
            None,
            pytest.raises(TypeError),
        ),
        # unknown mode is rejected
        (
            np.array(
                [
                    [1, 1, 0, 0],
                    [1, 1, 0, 0],
                    [0, 0, 0, 0],
                    [0, 0, 0, 0],
                ],
                dtype=bool,
            ),
            8,
            "nearest",
            2.0,
            None,
            None,
            pytest.raises(ValueError),
        ),
    ],
)
def test_filter_segments_by_distance_sweep(
    mask: npt.NDArray,
    connectivity: int,
    mode: str,
    absolute_distance: float | None,
    relative_distance: float | None,
    expected_result: npt.NDArray | None,
    exception: Exception,
) -> None:
    """Check kept/dropped components and error paths of the distance filter."""
    with exception:
        result = filter_segments_by_distance(
            mask=mask,
            connectivity=connectivity,
            mode=mode,  # type: ignore[arg-type]
            absolute_distance=absolute_distance,
            relative_distance=relative_distance,
        )
        assert np.array_equal(result, expected_result)