Skip to content

Support for IOS Matching Metric. Introduced the mask_non_max_merge function for handling non-maximum merging of masks #1774

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e4b5a8e
feat(detection): Support for IOS Matching Metric
SunHao-AI Jan 9, 2025
9cd6549
feat(detection): Support for IOS Matching Metric
SunHao-AI Jan 9, 2025
8d29796
Merge remote-tracking branch 'origin/develop' into develop
SunHao-AI Jan 9, 2025
4679334
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Jan 9, 2025
39766f4
Merge branch 'roboflow:develop' into develop
SunHao-AI Jan 16, 2025
9a9324a
Merge branch 'roboflow:develop' into develop
SunHao-AI Jun 25, 2025
646ac83
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Jun 25, 2025
d002858
ci: 注释掉创建 GitHub App token 的步骤
SunHao-AI Jun 25, 2025
eb170bc
ci:取消注释 GitHub App token 步骤
SunHao-AI Jun 25, 2025
13e857c
refactor(detection): 重构合并检测对象的逻辑
SunHao-AI Jul 7, 2025
0b120f1
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Jul 7, 2025
f47f032
Merge branch 'roboflow:develop' into develop
SunHao-AI Jul 7, 2025
50c0d54
Merge branch 'develop' into develop
soumik12345 Jul 14, 2025
74d5f0f
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Jul 14, 2025
9517987
chore: make pre-commit happy
soumik12345 Jul 14, 2025
fb16bd9
chore: change match_metric to overlap_metric
soumik12345 Jul 15, 2025
42e9c5f
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Jul 15, 2025
7849831
chore: make pre-commit happy
soumik12345 Jul 15, 2025
3fed5f9
update: docs
soumik12345 Jul 15, 2025
fd25f4b
add: docs for mask_non_max_merge
soumik12345 Jul 15, 2025
ebf4b93
add: test for mask_non_max_merge
soumik12345 Jul 15, 2025
b3024e7
chore: make pre-commit happy
soumik12345 Jul 15, 2025
007d8f5
chore: make enum comparisons
soumik12345 Jul 15, 2025
ae77077
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Jul 15, 2025
4977fd6
add: test_mask_non_max_merge
soumik12345 Jul 15, 2025
9a38ebd
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Jul 15, 2025
7ea8c2e
chore: remove excessive comments in group_overlapping_masks
soumik12345 Jul 15, 2025
c11a7de
chore: refactor docs for
soumik12345 Jul 15, 2025
84c9a34
update: docstring
soumik12345 Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from supervision.detection.overlap_filter import (
box_non_max_merge,
box_non_max_suppression,
mask_non_max_merge,
mask_non_max_suppression,
)
from supervision.detection.tools.transformers import (
Expand All @@ -27,6 +28,7 @@
get_data_item,
is_data_equal,
is_metadata_equal,
mask_iou_batch,
mask_to_xyxy,
merge_data,
merge_metadata,
Expand Down Expand Up @@ -1321,7 +1323,10 @@ def box_area(self) -> np.ndarray:
return (self.xyxy[:, 3] - self.xyxy[:, 1]) * (self.xyxy[:, 2] - self.xyxy[:, 0])

def with_nms(
self, threshold: float = 0.5, class_agnostic: bool = False
self,
threshold: float = 0.5,
class_agnostic: bool = False,
match_metric: str = "IOU",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`match_metric` should be converted into an `OverlapMetric` enum, added to `supervision.detection.overlap_filter`, and documented in a similar way to `OverlapFilter`.

All occurrences of `match_metric: str = "IOU"` should be converted into `overlap_metric: OverlapMetric = OverlapMetric.IOU`.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OverlapMetric should be added to double_detection_filter docs

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would even go as far as providing math formulas for IOU and IOS

) -> Detections:
"""
Performs non-max suppression on detection set. If the detections result
Expand All @@ -1334,6 +1339,8 @@ def with_nms(
class_agnostic (bool): Whether to perform class-agnostic
non-maximum suppression. If True, the class_id of each detection
will be ignored. Defaults to False.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
Detections: A new Detections object containing the subset of detections
Expand Down Expand Up @@ -1367,17 +1374,25 @@ def with_nms(

if self.mask is not None:
indices = mask_non_max_suppression(
predictions=predictions, masks=self.mask, iou_threshold=threshold
predictions=predictions,
masks=self.mask,
iou_threshold=threshold,
match_metric=match_metric,
)
else:
indices = box_non_max_suppression(
predictions=predictions, iou_threshold=threshold
predictions=predictions,
iou_threshold=threshold,
match_metric=match_metric,
)

return self[indices]

def with_nmm(
self, threshold: float = 0.5, class_agnostic: bool = False
self,
threshold: float = 0.5,
class_agnostic: bool = False,
match_metric: str = "IOU",
) -> Detections:
"""
Perform non-maximum merging on the current set of object detections.
Expand All @@ -1388,6 +1403,8 @@ def with_nmm(
class_agnostic (bool): Whether to perform class-agnostic
non-maximum merging. If True, the class_id of each detection
will be ignored. Defaults to False.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
Detections: A new Detections object containing the subset of detections
Expand Down Expand Up @@ -1421,15 +1438,25 @@ def with_nmm(
)
)

merge_groups = box_non_max_merge(
predictions=predictions, iou_threshold=threshold
)
if self.mask is not None:
merge_groups = mask_non_max_merge(
predictions=predictions,
masks=self.mask,
iou_threshold=threshold,
match_metric=match_metric,
)
else:
merge_groups = box_non_max_merge(
predictions=predictions,
iou_threshold=threshold,
match_metric=match_metric,
)

result = []
for merge_group in merge_groups:
unmerged_detections = [self[i] for i in merge_group]
merged_detections = merge_inner_detections_objects(
unmerged_detections, threshold
unmerged_detections, threshold, match_metric
)
result.append(merged_detections)

Expand Down Expand Up @@ -1529,7 +1556,7 @@ def merge_inner_detection_object_pair(


def merge_inner_detections_objects(
detections: List[Detections], threshold=0.5
detections: List[Detections], threshold=0.5, match_metric: str = "IOU"
) -> Detections:
"""
Given N detections each of length 1 (exactly one object inside), combine them into a
Expand All @@ -1541,8 +1568,11 @@ def merge_inner_detections_objects(
"""
detections_1 = detections[0]
for detections_2 in detections[1:]:
box_iou = box_iou_batch(detections_1.xyxy, detections_2.xyxy)[0]
if box_iou < threshold:
if detections_1.mask is not None and detections_2.mask is not None:
iou = mask_iou_batch(detections_1.mask, detections_2.mask, match_metric)[0]
else:
iou = box_iou_batch(detections_1.xyxy, detections_2.xyxy, match_metric)[0]
if iou < threshold:
break
detections_1 = merge_inner_detection_object_pair(detections_1, detections_2)
return detections_1
Expand Down
136 changes: 129 additions & 7 deletions supervision/detection/overlap_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def mask_non_max_suppression(
predictions: np.ndarray,
masks: np.ndarray,
iou_threshold: float = 0.5,
match_metric: str = "IOU",
mask_dimension: int = 640,
) -> np.ndarray:
"""
Expand All @@ -57,6 +58,8 @@ def mask_non_max_suppression(
dimensions of each mask.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".
mask_dimension (int): The dimension to which the masks should be
resized before computing IOU values. Defaults to 640.

Expand All @@ -81,7 +84,7 @@ def mask_non_max_suppression(
predictions = predictions[sort_index]
masks = masks[sort_index]
masks_resized = resize_masks(masks, mask_dimension)
ious = mask_iou_batch(masks_resized, masks_resized)
ious = mask_iou_batch(masks_resized, masks_resized, match_metric)
categories = predictions[:, 5]

keep = np.ones(rows, dtype=bool)
Expand All @@ -94,7 +97,7 @@ def mask_non_max_suppression(


def box_non_max_suppression(
predictions: np.ndarray, iou_threshold: float = 0.5
predictions: np.ndarray, iou_threshold: float = 0.5, match_metric: str = "IOU"
) -> np.ndarray:
"""
Perform Non-Maximum Suppression (NMS) on object detection predictions.
Expand All @@ -105,6 +108,8 @@ def box_non_max_suppression(
or `(x_min, y_min, x_max, y_max, score, class)`.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
np.ndarray: A boolean array indicating which predictions to keep after n
Expand All @@ -130,7 +135,7 @@ def box_non_max_suppression(

boxes = predictions[:, :4]
categories = predictions[:, 5]
ious = box_iou_batch(boxes, boxes)
ious = box_iou_batch(boxes, boxes, match_metric)
ious = ious - np.eye(rows)

keep = np.ones(rows, dtype=bool)
Expand All @@ -148,7 +153,9 @@ def box_non_max_suppression(


def group_overlapping_boxes(
predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5
predictions: npt.NDArray[np.float64],
iou_threshold: float = 0.5,
match_metric: str = "IOU",
) -> List[List[int]]:
"""
Apply greedy version of non-maximum merging to avoid detecting too many
Expand All @@ -160,6 +167,8 @@ def group_overlapping_boxes(
and the confidence scores.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression. Defaults to 0.5.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
List[List[int]]: Groups of prediction indices be merged.
Expand All @@ -179,7 +188,9 @@ def group_overlapping_boxes(
break

merge_candidate = np.expand_dims(predictions[idx], axis=0)
ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4])
ious = box_iou_batch(
predictions[order][:, :4], merge_candidate[:, :4], match_metric
)
ious = ious.flatten()

above_threshold = ious >= iou_threshold
Expand All @@ -189,9 +200,71 @@ def group_overlapping_boxes(
return merge_groups


def mask_non_max_merge(
    predictions: np.ndarray,
    masks: np.ndarray,
    iou_threshold: float = 0.5,
    mask_dimension: int = 640,
    match_metric: str = "IOU",
) -> List[List[int]]:
    """
    Perform Non-Maximum Merging (NMM) on segmentation predictions.

    Masks are first resized with `resize_masks` and then grouped with
    `group_overlapping_masks`. When `predictions` carries a class column,
    grouping is done separately for every class id, so a merge group never
    mixes predictions of different classes.

    Args:
        predictions (np.ndarray): A 2D array of object detection predictions in
            the format of `(x_min, y_min, x_max, y_max, score)`
            or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or
            `(N, 6)`, where N is the number of predictions.
        masks (np.ndarray): A 3D array of binary masks corresponding to the
            predictions. Shape: `(N, H, W)`, where N is the number of
            predictions, and H, W are the dimensions of each mask.
        iou_threshold (float): The overlap threshold at or above which two
            masks are placed in the same merge group. Defaults to 0.5.
        mask_dimension (int): The dimension to which the masks should be
            resized before computing overlap values. Defaults to 640.
        match_metric (str): Metric used for matching detections in slices.
            "IOU" or "IOS". Defaults to "IOU".

    Returns:
        List[List[int]]: Groups of prediction indices to be merged. Indices
            refer to rows of `predictions`. Each group may have 1 or more
            elements.

    Raises:
        ValueError: If an empty merge group is produced while grouping
            per class.
    """
    masks_resized = resize_masks(masks, mask_dimension)

    # Only 5 columns means there is no class column: group everything together.
    if predictions.shape[1] == 5:
        return group_overlapping_masks(
            predictions, masks_resized, iou_threshold, match_metric
        )

    category_ids = predictions[:, 5]
    merge_groups: List[List[int]] = []
    for category_id in np.unique(category_ids):
        curr_indices = np.where(category_ids == category_id)[0]
        merge_class_groups = group_overlapping_masks(
            predictions[curr_indices],
            masks_resized[curr_indices],
            iou_threshold,
            match_metric,
        )

        # Groups hold positions within the per-class subset; map them back
        # to row indices of the full `predictions` array.
        for merge_class_group in merge_class_groups:
            merge_groups.append(curr_indices[merge_class_group].tolist())

    for merge_group in merge_groups:
        if len(merge_group) == 0:
            raise ValueError(
                f"Empty group detected when non-max-merging detections: {merge_groups}"
            )
    return merge_groups


def box_non_max_merge(
predictions: npt.NDArray[np.float64],
iou_threshold: float = 0.5,
match_metric: str = "IOU",
) -> List[List[int]]:
"""
Apply greedy version of non-maximum merging per category to avoid detecting
Expand All @@ -204,20 +277,22 @@ def box_non_max_merge(
detections of different classes to be merged.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression. Defaults to 0.5.
match_metric (str): Metric used for matching detections in slices.
"IOU" or "IOS". Defaults "IOU".

Returns:
List[List[int]]: Groups of prediction indices be merged.
Each group may have 1 or more elements.
"""
if predictions.shape[1] == 5:
return group_overlapping_boxes(predictions, iou_threshold)
return group_overlapping_boxes(predictions, iou_threshold, match_metric)

category_ids = predictions[:, 5]
merge_groups = []
for category_id in np.unique(category_ids):
curr_indices = np.where(category_ids == category_id)[0]
merge_class_groups = group_overlapping_boxes(
predictions[curr_indices], iou_threshold
predictions[curr_indices], iou_threshold, match_metric
)

for merge_class_group in merge_class_groups:
Expand All @@ -231,6 +306,53 @@ def box_non_max_merge(
return merge_groups


def group_overlapping_masks(
    predictions: npt.NDArray[np.float64],
    masks: npt.NDArray[np.float64],
    iou_threshold: float = 0.5,
    match_metric: str = "IOU",
) -> List[List[int]]:
    """
    Greedily group predictions whose masks overlap, so that each group can
    later be merged into a single detection.

    Predictions are visited from the highest score to the lowest. The current
    highest-scoring prediction opens a new group and absorbs every remaining
    prediction whose overlap with it, as computed by `mask_iou_batch`, is at
    least `iou_threshold`. Absorbed predictions are not considered again.

    Args:
        predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)`
            containing the bounding boxes coordinates in format
            `[x1, y1, x2, y2]` and the confidence scores.
        masks (npt.NDArray[np.float64]): A 3D array of binary masks
            corresponding to the predictions.
        iou_threshold (float): The overlap threshold at or above which a
            prediction joins the current group. Defaults to 0.5.
        match_metric (str): Metric used for matching detections in slices.
            "IOU" or "IOS". Defaults to "IOU".

    Returns:
        List[List[int]]: Groups of prediction indices to be merged.
            Each group may have 1 or more elements; the first index of a
            group is its highest-scoring prediction.
    """
    groups: List[List[int]] = []

    # Ascending score order: the best remaining candidate is always last.
    remaining = predictions[:, 4].argsort()

    while len(remaining) > 0:
        top = int(remaining[-1])
        remaining = remaining[:-1]

        # Nothing left to compare against: the last prediction stands alone.
        if len(remaining) == 0:
            groups.append([top])
            break

        candidate_mask = masks[top][np.newaxis]
        overlaps = mask_iou_batch(
            masks[remaining], candidate_mask, match_metric
        ).flatten()

        is_match = overlaps >= iou_threshold
        # Flip so the absorbed indices are listed from highest to lowest score.
        absorbed = np.flip(remaining[is_match]).tolist()
        groups.append([top, *absorbed])
        remaining = remaining[~is_match]

    return groups


class OverlapFilter(Enum):
"""
Enum specifying the strategy for filtering overlapping detections.
Expand Down
Loading