Skip to content

Correct confusion matrix calculation in function evaluate_detection_batch #1853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
Open
88 changes: 61 additions & 27 deletions supervision/metrics/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def from_detections(
# ])
```
"""

prediction_tensors = []
target_tensors = []
for prediction, target in zip(predictions, targets):
Expand Down Expand Up @@ -284,9 +283,28 @@ def evaluate_detection_batch(
"""
result_matrix = np.zeros((num_classes + 1, num_classes + 1))

# Filter predictions by confidence threshold
conf_idx = 5
confidence = predictions[:, conf_idx]
detection_batch_filtered = predictions[confidence > conf_threshold]
detection_batch_filtered = predictions[confidence >= conf_threshold]

if len(detection_batch_filtered) == 0:
# No detections pass the confidence threshold - every ground-truth (GT) box is counted as a false negative (FN)
class_id_idx = 4
true_classes = np.array(targets[:, class_id_idx], dtype=np.int16)
for gt_class in true_classes:
result_matrix[gt_class, num_classes] += 1
return result_matrix

if len(targets) == 0:
# No ground truth - all detections are FP
class_id_idx = 4
detection_classes = np.array(
detection_batch_filtered[:, class_id_idx], dtype=np.int16
)
for det_class in detection_classes:
result_matrix[num_classes, det_class] += 1
return result_matrix

class_id_idx = 4
true_classes = np.array(targets[:, class_id_idx], dtype=np.int16)
Expand All @@ -296,35 +314,51 @@ def evaluate_detection_batch(
true_boxes = targets[:, :class_id_idx]
detection_boxes = detection_batch_filtered[:, :class_id_idx]

# Calculate IoU matrix
iou_batch = box_iou_batch(
boxes_true=true_boxes, boxes_detection=detection_boxes
)
matched_idx = np.asarray(iou_batch > iou_threshold).nonzero()

if matched_idx[0].shape[0]:
matches = np.stack(
(matched_idx[0], matched_idx[1], iou_batch[matched_idx]), axis=1
)
matches = ConfusionMatrix._drop_extra_matches(matches=matches)
else:
matches = np.zeros((0, 3))

matched_true_idx, matched_detection_idx, _ = matches.transpose().astype(
np.int16
)

for i, true_class_value in enumerate(true_classes):
j = matched_true_idx == i
if matches.shape[0] > 0 and sum(j) == 1:
result_matrix[
true_class_value, detection_classes[matched_detection_idx[j]]
] += 1 # TP
else:
result_matrix[true_class_value, num_classes] += 1 # FN

for i, detection_class_value in enumerate(detection_classes):
if not any(matched_detection_idx == i):
result_matrix[num_classes, detection_class_value] += 1 # FP
# Find all valid matches (IoU > threshold, regardless of class)
valid_matches = []
for gt_idx in range(len(true_classes)):
for det_idx in range(len(detection_classes)):
iou = iou_batch[gt_idx, det_idx]
if iou > iou_threshold:
gt_class = true_classes[gt_idx]
det_class = detection_classes[det_idx]
class_match = gt_class == det_class
valid_matches.append((gt_idx, det_idx, iou, class_match))

# Sort matches by class match first (True before False), then by IoU descending
# This prioritizes correct class predictions over higher IoU with wrong class
valid_matches.sort(key=lambda x: (x[3], x[2]), reverse=True)

# Greedily assign matches, ensuring each GT
# and detection is matched at most once
matched_gt_idx = set()
matched_det_idx = set()

for gt_idx, det_idx, iou, class_match in valid_matches:
if gt_idx not in matched_gt_idx and det_idx not in matched_det_idx:
# Valid spatial match - record the class prediction
gt_class = true_classes[gt_idx]
det_class = detection_classes[det_idx]

# Increment the (GT class, detected class) cell: the diagonal is a correct classification (TP), off-diagonal is a misclassification
result_matrix[gt_class, det_class] += 1
matched_gt_idx.add(gt_idx)
matched_det_idx.add(det_idx)

# Count unmatched ground truth as FN
for gt_idx, gt_class in enumerate(true_classes):
if gt_idx not in matched_gt_idx:
result_matrix[gt_class, num_classes] += 1

# Count unmatched detections as FP
for det_idx, det_class in enumerate(detection_classes):
if det_idx not in matched_det_idx:
result_matrix[num_classes, det_class] += 1

return result_matrix

Expand Down