Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/detection/utils/masks.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ status: new
</div>

:::supervision.detection.utils.masks.contains_multiple_segments

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.masks.filter_segments_by_distance">filter_segments_by_distance</a></h2>
</div>

:::supervision.detection.utils.masks.filter_segments_by_distance
2 changes: 2 additions & 0 deletions supervision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
calculate_masks_centroids,
contains_holes,
contains_multiple_segments,
filter_segments_by_distance,
move_masks,
)
from supervision.detection.utils.polygons import (
Expand Down Expand Up @@ -219,6 +220,7 @@
"draw_text",
"edit_distance",
"filter_polygons_by_area",
"filter_segments_by_distance",
"fuzzy_match_index",
"get_coco_class_index_mapping",
"get_polygon_center",
Expand Down
138 changes: 138 additions & 0 deletions supervision/detection/utils/masks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Literal

import cv2
import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -260,3 +262,139 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
resized_masks = masks[:, yv, xv]

return resized_masks.reshape(masks.shape[0], new_height, new_width)


def filter_segments_by_distance(
mask: npt.NDArray[np.bool_],
absolute_distance: float | None = 100.0,
relative_distance: float | None = None,
connectivity: int = 8,
mode: Literal["edge", "centroid"] = "edge",
) -> npt.NDArray[np.bool_]:
"""
Keep the largest connected component and any other components within a distance
threshold.

Distance can be absolute in pixels or relative to the image diagonal.

Args:
mask: Boolean mask HxW.
absolute_distance: Max allowed distance in pixels to the main component.
Ignored if `relative_distance` is provided.
relative_distance: Fraction of the diagonal. If set, threshold = fraction * sqrt(H^2 + W^2).
connectivity: Defines which neighboring pixels are considered connected.
- 4-connectedness: Only orthogonal neighbors.
```
[ ][X][ ]
[X][O][X]
[ ][X][ ]
```
- 8-connectedness: Includes diagonal neighbors.
```
[X][X][X]
[X][O][X]
[X][X][X]
```
Default is 8.
mode: Defines how distance between components is measured.
- "edge": Uses distance between nearest edges (via distance transform).
- "centroid": Uses distance between component centroids.

Returns:
Boolean mask after filtering.

Examples:
```python
import numpy as np
import supervision as sv

mask = np.array([
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
], dtype=bool)

sv.filter_segments_by_distance(
mask,
absolute_distance=2,
mode="edge",
connectivity=8
).astype(int)

# np.array([
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# ], dtype=bool)

# The nearby 2×2 block at columns 6–7 is kept because its edge distance
# is within 2 pixels. The distant block at columns 9-10 is removed.
```
""" # noqa E501 // docs
if mask.dtype != bool:
raise TypeError("mask must be boolean")

height, width = mask.shape
if not np.any(mask):
return mask.copy()

image = mask.astype(np.uint8)
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
image, connectivity=connectivity
)

if num_labels <= 1:
return mask.copy()

areas = stats[1:, cv2.CC_STAT_AREA]
main_label = 1 + int(np.argmax(areas))

if relative_distance is not None:
diagonal = float(np.hypot(height, width))
threshold = float(relative_distance) * diagonal
else:
threshold = float(absolute_distance)

keep_labels = np.zeros(num_labels, dtype=bool)
keep_labels[main_label] = True

if mode == "centroid":
differences = centroids[1:] - centroids[main_label]
distances = np.sqrt(np.sum(differences**2, axis=1))
nearby = 1 + np.where(distances <= threshold)[0]
keep_labels[nearby] = True
elif mode == "edge":
main_mask = (labels == main_label).astype(np.uint8)
inverse = 1 - main_mask
distance_transform = cv2.distanceTransform(inverse, cv2.DIST_L2, 3)
for label in range(1, num_labels):
if label == main_label:
continue
component = labels == label
if not np.any(component):
continue
min_distance = float(distance_transform[component].min())
if min_distance <= threshold:
keep_labels[label] = True
else:
raise ValueError("mode must be 'edge' or 'centroid'")

return keep_labels[labels]
Loading