diff --git a/supervision/detection/line_counter.py b/supervision/detection/line_counter.py index ea97fcb83..c668354eb 100644 --- a/supervision/detection/line_counter.py +++ b/supervision/detection/line_counter.py @@ -27,15 +27,20 @@ def __init__(self, start: Point, end: Point): self.in_count: int = 0 self.out_count: int = 0 - def trigger(self, detections: Detections): + def trigger(self, detections: Detections) -> np.ndarray: """ Update the in_count and out_count for the detections that cross the line. Attributes: detections (Detections): The detections for which to update the counts. + Returns: + np.ndarray: A boolean array indicating + which detection has crossed the line on the either sides """ - for xyxy, _, confidence, class_id, tracker_id in detections: + crossed = np.full(len(detections), False) + + for i, (xyxy, _, confidence, class_id, tracker_id) in enumerate(detections): # handle detections with no tracker_id if tracker_id is None: continue @@ -67,8 +72,13 @@ def trigger(self, detections: Detections): self.tracker_state[tracker_id] = tracker_state if tracker_state: self.in_count += 1 + crossed[i] = True + else: self.out_count += 1 + crossed[i] = True + + return crossed class LineZoneAnnotator: