Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion supervision/draw/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def blue(cls) -> Color:

@dataclass
class ColorPalette:
colors: List[Color] = field(default_factory=lambda: DEFAULT_COLOR_PALETTE)
colors: List[Color] = field(
default_factory=lambda: [
Color.from_hex(color_hex) for color_hex in DEFAULT_COLOR_PALETTE
]
)

@classmethod
def from_hex(cls, color_hex_list: List[str]):
Expand Down
File renamed without changes.
File renamed without changes.
205 changes: 205 additions & 0 deletions supervision/tools/detections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from typing import List, Optional, Union

import cv2
import numpy as np

from supervision.draw.color import Color, ColorPalette


class Detections:
def __init__(
self,
xyxy: np.ndarray,
confidence: np.ndarray,
class_id: np.ndarray,
tracker_id: Optional[np.ndarray] = None,
):
"""
Data class containing information about the detections in a video frame.

:param xyxy: np.ndarray : An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2]
:param confidence: np.ndarray : An array of shape (n,) containing the confidence scores of the detections.
:param class_id: np.ndarray : An array of shape (n,) containing the class ids of the detections.
:param tracker_id: Optional[np.ndarray] : An array of shape (n,) containing the tracker ids of the detections.
"""
self.xyxy: np.ndarray = xyxy
self.confidence: np.ndarray = confidence
self.class_id: np.ndarray = class_id
self.tracker_id: Optional[np.ndarray] = tracker_id

n = len(self.xyxy)
validators = [
(isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)),
(isinstance(self.confidence, np.ndarray) and self.confidence.shape == (n,)),
(isinstance(self.class_id, np.ndarray) and self.class_id.shape == (n,)),
self.tracker_id is None
or (
isinstance(self.tracker_id, np.ndarray)
and self.tracker_id.shape == (n,)
),
]
if not all(validators):
raise ValueError(
"xyxy must be 2d np.ndarray with (n, 4) shape, "
"confidence must be 1d np.ndarray with (n,) shape, "
"class_id must be 1d np.ndarray with (n,) shape, "
"tracker_id must be None or 1d np.ndarray with (n,) shape"
)

def __len__(self):
"""
Returns the number of detections in the Detections object.
"""
return len(self.xyxy)

def __iter__(self):
"""
Iterates over the Detections object and yield a tuple of (xyxy, confidence, class_id, tracker_id) for each detection.
"""
for i in range(len(self.xyxy)):
yield (
self.xyxy[i],
self.confidence[i],
self.class_id[i],
self.tracker_id[i] if self.tracker_id is not None else None,
)

@classmethod
def from_yolov5(cls, yolov5_output: np.ndarray):
"""
Creates a Detections instance from a YOLOv5 output tensor

:param yolov5_output: np.ndarray : The output tensor from YOLOv5
:return: Detections : A Detections instance representing the detections in the frame

Example:
detections = Detections.from_yolov5(yolov5_output)
"""
xyxy = yolov5_output[:, :4]
confidence = yolov5_output[:, 4]
class_id = yolov5_output[:, 5].astype(int)
return cls(xyxy, confidence, class_id)

def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray]:
"""
Filter the detections by applying a mask

:param mask: np.ndarray : A mask of shape (n,) containing a boolean value for each detection indicating if it should be included in the filtered detections
:param inplace: bool : If True, the original data will be modified and self will be returned.
:return: Optional[np.ndarray] : A new instance of Detections with the filtered detections, if inplace is set to False. None otherwise.
"""
if inplace:
self.xyxy = self.xyxy[mask]
self.confidence = self.confidence[mask]
self.class_id = self.class_id[mask]
self.tracker_id = (
self.tracker_id[mask] if self.tracker_id is not None else None
)
return self
else:
return Detections(
xyxy=self.xyxy[mask],
confidence=self.confidence[mask],
class_id=self.class_id[mask],
tracker_id=self.tracker_id[mask]
if self.tracker_id is not None
else None,
)


class BoxAnnotator:
def __init__(
self,
color: Union[Color, ColorPalette],
thickness: int = 2,
text_color: Color = Color.black(),
text_scale: float = 0.5,
text_thickness: int = 1,
text_padding: int = 10,
):
"""
A class for drawing bounding boxes on an image using detections provided.

:param color: Union[Color, ColorPalette] : The color to draw the bounding box, can be a single color or a color palette
:param thickness: int : The thickness of the bounding box lines, default is 2
:param text_color: Color : The color of the text on the bounding box, default is white
:param text_scale: float : The scale of the text on the bounding box, default is 0.5
:param text_thickness: int : The thickness of the text on the bounding box, default is 1
:param text_padding: int : The padding around the text on the bounding box, default is 5
"""
self.color: Union[Color, ColorPalette] = color
self.thickness: int = thickness
self.text_color: Color = text_color
self.text_scale: float = text_scale
self.text_thickness: int = text_thickness
self.text_padding: int = text_padding

def annotate(
self,
frame: np.ndarray,
detections: Detections,
labels: Optional[List[str]] = None,
) -> np.ndarray:
"""
Draws bounding boxes on the frame using the detections provided.

:param frame: np.ndarray : The image on which the bounding boxes will be drawn
:param detections: Detections : The detections for which the bounding boxes will be drawn
:param labels: Optional[List[str]] : An optional list of labels corresponding to each detection. If labels is provided, the confidence score of the detection will be replaced with the label.
:return: np.ndarray : The image with the bounding boxes drawn on it
"""
font = cv2.FONT_HERSHEY_SIMPLEX
for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections):
color = (
self.color.by_idx(class_id)
if isinstance(self.color, ColorPalette)
else self.color
)
text = (
f"{confidence:0.2f}"
if (labels is None or len(detections) != len(labels))
else labels[i]
)

x1, y1, x2, y2 = xyxy.astype(int)
text_width, text_height = cv2.getTextSize(
text=text,
fontFace=font,
fontScale=self.text_scale,
thickness=self.text_thickness,
)[0]

text_x = x1 + self.text_padding
text_y = y1 - self.text_padding

text_background_x1 = x1
text_background_y1 = y1 - 2 * self.text_padding - text_height

text_background_x2 = x1 + 2 * self.text_padding + text_width
text_background_y2 = y1

cv2.rectangle(
img=frame,
pt1=(x1, y1),
pt2=(x2, y2),
color=color.as_bgr(),
thickness=self.thickness,
)
cv2.rectangle(
img=frame,
pt1=(text_background_x1, text_background_y1),
pt2=(text_background_x2, text_background_y2),
color=color.as_bgr(),
thickness=cv2.FILLED,
)
cv2.putText(
img=frame,
text=text,
org=(text_x, text_y),
fontFace=font,
fontScale=self.text_scale,
color=self.text_color.as_rgb(),
thickness=self.text_thickness,
lineType=cv2.LINE_AA,
)
return frame
Empty file added test/geometry/__init__.py
Empty file.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from supervision.commons.dataclasses import Vector, Point
from supervision.geometry.dataclasses import Vector, Point


@pytest.mark.parametrize(
Expand Down