Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,18 +296,24 @@ def from_ultralytics(cls, ultralytics_results) -> Detections:
class_id=np.arange(len(ultralytics_results)),
)

class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int)
class_names = np.array([ultralytics_results.names[i] for i in class_id])
return cls(
xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(),
confidence=ultralytics_results.boxes.conf.cpu().numpy(),
class_id=class_id,
mask=extract_ultralytics_masks(ultralytics_results),
tracker_id=ultralytics_results.boxes.id.int().cpu().numpy()
if ultralytics_results.boxes.id is not None
else None,
data={CLASS_NAME_DATA_FIELD: class_names},
)
if (
hasattr(ultralytics_results, "boxes")
and ultralytics_results.boxes is not None
):
class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int)
class_names = np.array([ultralytics_results.names[i] for i in class_id])
return cls(
xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(),
confidence=ultralytics_results.boxes.conf.cpu().numpy(),
class_id=class_id,
mask=extract_ultralytics_masks(ultralytics_results),
tracker_id=ultralytics_results.boxes.id.int().cpu().numpy()
if ultralytics_results.boxes.id is not None
else None,
data={CLASS_NAME_DATA_FIELD: class_names},
)

return cls.empty()

@classmethod
def from_yolo_nas(cls, yolo_nas_results) -> Detections:
Expand Down
63 changes: 35 additions & 28 deletions supervision/detection/utils/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,60 +95,67 @@ def pad_boxes(xyxy: np.ndarray, px: int, py: int | None = None) -> np.ndarray:


def denormalize_boxes(
    xyxy: np.ndarray,
    resolution_wh: tuple[int, int],
    normalization_factor: float = 1.0,
) -> np.ndarray:
    """
    Convert normalized bounding box coordinates to absolute pixel coordinates.

    Multiplies each bounding box coordinate by image size and divides by
    `normalization_factor`, mapping values from normalized `[0, normalization_factor]`
    to absolute pixel values for a given resolution. The input array is never
    modified; a scaled copy is returned.

    Args:
        xyxy (`numpy.ndarray`): Normalized bounding boxes of shape `(N, 4)`,
            where each row is `(x_min, y_min, x_max, y_max)`, values in
            `[0, normalization_factor]`. A single box of shape `(4,)` is also
            accepted.
        resolution_wh (`tuple[int, int]`): Target image resolution as `(width, height)`.
        normalization_factor (`float`): Maximum value of input coordinate range.
            Defaults to `1.0`.

    Returns:
        (`numpy.ndarray`): Array of the same shape as `xyxy` with absolute
            coordinates in `(x_min, y_min, x_max, y_max)` format. Integer input
            is promoted to `float64` so scaled values are not truncated.

    Examples:
        ```python
        import numpy as np
        import supervision as sv

        xyxy = np.array([
            [0.1, 0.2, 0.5, 0.6],
            [0.3, 0.4, 0.7, 0.8],
            [0.2, 0.1, 0.6, 0.5]
        ])

        sv.denormalize_boxes(xyxy, (1280, 720))
        # array([
        #     [128., 144., 640., 432.],
        #     [384., 288., 896., 576.],
        #     [256.,  72., 768., 360.]
        # ])
        ```

        ```python
        import numpy as np
        import supervision as sv

        xyxy = np.array([
            [256., 128., 768., 640.]
        ])

        sv.denormalize_boxes(xyxy, (1280, 720), normalization_factor=1024.0)
        # array([
        #     [320.,  90., 960., 450.]
        # ])
        ```
    """
    width, height = resolution_wh

    # Work on a copy so the caller's array is left untouched.
    result = np.array(xyxy, copy=True)
    if not np.issubdtype(result.dtype, np.floating):
        # Assigning scaled values back into an integer array would silently
        # truncate them, so promote non-float input first.
        result = result.astype(np.float64)

    # Index the LAST axis (the coordinate axis), not the first one: `result[[0, 2]]`
    # would pick whole rows (boxes) of a batch. The ellipsis keeps this correct
    # for both a `(N, 4)` batch and a single `(4,)` box.
    result[..., [0, 2]] = (result[..., [0, 2]] * width) / normalization_factor
    result[..., [1, 3]] = (result[..., [1, 3]] * height) / normalization_factor

    return result

Expand Down
40 changes: 18 additions & 22 deletions supervision/detection/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,26 +538,24 @@ def from_google_gemini_2_0(
return np.empty((0, 4)), None, np.empty((0,), dtype=str)

labels = []
boxes_list = []
xyxy = []

for item in data:
if "box_2d" not in item or "label" not in item:
continue
labels.append(item["label"])
box = item["box_2d"]
# Gemini bbox order is [y_min, x_min, y_max, x_max]
boxes_list.append(
denormalize_boxes(
np.array([box[1], box[0], box[3], box[2]]).astype(np.float64),
resolution_wh=(w, h),
normalization_factor=1000,
)
)
xyxy.append([box[1], box[0], box[3], box[2]])

if not boxes_list:
if len(xyxy) == 0:
return np.empty((0, 4)), None, np.empty((0,), dtype=str)

xyxy = np.array(boxes_list)
xyxy = denormalize_boxes(
np.array(xyxy, dtype=np.float64),
resolution_wh=(w, h),
normalization_factor=1000,
)
class_name = np.array(labels)
class_id = None

Expand Down Expand Up @@ -649,10 +647,10 @@ def from_google_gemini_2_5(
box = item["box_2d"]
# Gemini bbox order is [y_min, x_min, y_max, x_max]
absolute_bbox = denormalize_boxes(
np.array([box[1], box[0], box[3], box[2]]).astype(np.float64),
np.array([[box[1], box[0], box[3], box[2]]]).astype(np.float64),
resolution_wh=(w, h),
normalization_factor=1000,
)
)[0]
boxes_list.append(absolute_bbox)

if "mask" in item:
Expand Down Expand Up @@ -735,7 +733,7 @@ def from_google_gemini_2_5(
def from_moondream(
result: dict,
resolution_wh: tuple[int, int],
) -> tuple[np.ndarray]:
) -> np.ndarray:
"""
Parse and scale bounding boxes from moondream JSON output.

Expand Down Expand Up @@ -773,7 +771,7 @@ def from_moondream(
if "objects" not in result or not isinstance(result["objects"], list):
return np.empty((0, 4), dtype=float)

denormalize_xyxy = []
xyxy = []

for item in result["objects"]:
if not all(k in item for k in ["x_min", "y_min", "x_max", "y_max"]):
Expand All @@ -784,14 +782,12 @@ def from_moondream(
x_max = item["x_max"]
y_max = item["y_max"]

denormalize_xyxy.append(
denormalize_boxes(
np.array([x_min, y_min, x_max, y_max]).astype(np.float64),
resolution_wh=(w, h),
)
)
xyxy.append([x_min, y_min, x_max, y_max])

if not denormalize_xyxy:
if len(xyxy) == 0:
return np.empty((0, 4))

return np.array(denormalize_xyxy, dtype=float)
return denormalize_boxes(
np.array(xyxy).astype(np.float64),
resolution_wh=(w, h),
)
92 changes: 91 additions & 1 deletion test/detection/utils/test_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import numpy as np
import pytest

from supervision.detection.utils.boxes import clip_boxes, move_boxes, scale_boxes
from supervision.detection.utils.boxes import (
clip_boxes,
denormalize_boxes,
move_boxes,
scale_boxes,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -142,3 +147,88 @@ def test_scale_boxes(
with exception:
result = scale_boxes(xyxy=xyxy, factor=factor)
assert np.array_equal(result, expected_result)


@pytest.mark.parametrize(
    "xyxy, resolution_wh, normalization_factor, expected_result, exception",
    [
        (
            np.empty(shape=(0, 4)),
            (1280, 720),
            1.0,
            np.empty(shape=(0, 4)),
            DoesNotRaise(),
        ),  # empty array
        (
            np.array([[0.1, 0.2, 0.5, 0.6]]),
            (1280, 720),
            1.0,
            np.array([[128.0, 144.0, 640.0, 432.0]]),
            DoesNotRaise(),
        ),  # single box with default normalization
        (
            np.array([[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8]]),
            (1280, 720),
            1.0,
            np.array([[128.0, 144.0, 640.0, 432.0], [384.0, 288.0, 896.0, 576.0]]),
            DoesNotRaise(),
        ),  # two boxes with default normalization
        (
            np.array(
                [[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8], [0.2, 0.1, 0.6, 0.5]]
            ),
            (1280, 720),
            1.0,
            np.array(
                [
                    [128.0, 144.0, 640.0, 432.0],
                    [384.0, 288.0, 896.0, 576.0],
                    [256.0, 72.0, 768.0, 360.0],
                ]
            ),
            DoesNotRaise(),
        ),  # three boxes - regression test for issue #1959
        (
            np.array([[10.0, 20.0, 50.0, 60.0]]),
            (100, 200),
            100.0,
            np.array([[10.0, 40.0, 50.0, 120.0]]),
            DoesNotRaise(),
        ),  # single box with custom normalization factor
        (
            np.array([[10.0, 20.0, 50.0, 60.0], [30.0, 40.0, 70.0, 80.0]]),
            (100, 200),
            100.0,
            np.array([[10.0, 40.0, 50.0, 120.0], [30.0, 80.0, 70.0, 160.0]]),
            DoesNotRaise(),
        ),  # two boxes with custom normalization factor
        (
            np.array([[0.0, 0.0, 1.0, 1.0]]),
            (1920, 1080),
            1.0,
            np.array([[0.0, 0.0, 1920.0, 1080.0]]),
            DoesNotRaise(),
        ),  # full frame box
        (
            np.array([[0.5, 0.5, 0.5, 0.5]]),
            (640, 480),
            1.0,
            np.array([[320.0, 240.0, 320.0, 240.0]]),
            DoesNotRaise(),
        ),  # zero-area box (point)
    ],
)
def test_denormalize_boxes(
    xyxy: np.ndarray,
    resolution_wh: tuple[int, int],
    normalization_factor: float,
    expected_result: np.ndarray,
    exception: Exception,
) -> None:
    """Check that `denormalize_boxes` scales every box of a batch to pixels.

    Each case supplies normalized boxes, the target `(width, height)`, the
    normalization range, the expected pixel coordinates, and the context
    manager the call is run under.
    """
    with exception:
        result = denormalize_boxes(
            xyxy=xyxy,
            resolution_wh=resolution_wh,
            normalization_factor=normalization_factor,
        )
        # `np.allclose` broadcasts its operands and is vacuously true for empty
        # arrays, so the shape has to be compared explicitly; otherwise a
        # wrongly shaped result could still pass.
        assert result.shape == expected_result.shape
        assert np.allclose(result, expected_result)