From cd91fed100221b453e97c34f2d714130ec860e94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Castro?= Date: Fri, 9 Feb 2024 13:34:52 -0300 Subject: [PATCH] Allow for giving a single score for the whole object --- norfair/tracker.py | 11 +++++++++-- norfair/utils.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/norfair/tracker.py b/norfair/tracker.py index 275eceee..5477d8f2 100644 --- a/norfair/tracker.py +++ b/norfair/tracker.py @@ -746,7 +746,7 @@ class Detection: Parameters ---------- points : np.ndarray - Points detected. Must be a rank 2 array with shape `(n_points, n_dimensions)` where n_dimensions is 2 or 3. + Points detected. Must be a rank 2 array with shape `(n_points, n_dimensions)`. scores : np.ndarray, optional An array of length `n_points` which assigns a score to each of the points defined in `points`. @@ -770,12 +770,19 @@ class Detection: def __init__( self, points: np.ndarray, - scores: np.ndarray = None, + scores: Union[float, int, np.ndarray] = None, data: Any = None, label: Hashable = None, embedding=None, ): self.points = validate_points(points) + + if isinstance(scores, np.ndarray): + assert len(scores) == len( + self.points + ), "scores should be a np.ndarray with it's length being equal to the amount of points." + else: + scores = np.zeros((len(points),)) + scores self.scores = scores self.data = data self.label = label diff --git a/norfair/utils.py b/norfair/utils.py index 98e0d903..10a64a93 100644 --- a/norfair/utils.py +++ b/norfair/utils.py @@ -20,7 +20,7 @@ def validate_points(points: np.ndarray) -> np.array: def raise_detection_error_message(points): message = "\n[red]INPUT ERROR:[/red]\n" - message += f"Each `Detection` object should have a property `points` of shape (num_of_points_to_track, 2), not {points.shape}. Check your `Detection` list creation code.\n" + message += f"Each `Detection` object should have a property `points` of shape (n_points, n_dimensions), not {points.shape}. Check your `Detection` list creation code.\n" message += "You can read the documentation for the `Detection` class here:\n" message += "https://tryolabs.github.io/norfair/reference/tracker/#norfair.tracker.Detection\n" raise ValueError(message)