In [1]:
import torch
import numpy as np
import cv2
import time
from ultralytics import RTDETR, YOLO
import supervision as sv
from supervision import BoundingBoxAnnotator

In [2]:
class YOLOClass:
    """Live webcam object detector built on an Ultralytics model, annotated with supervision.

    Loads a detection model once. Calling the instance reads frames from an
    OpenCV capture device and runs inference on each frame. It then draws
    boxes, labels and an FPS counter, and shows the result in an OpenCV
    window until 'q' is pressed.
    """

    def __init__(self, capture_index, weights='runs/detect/train27/weights/last.pt', model_cls=None):
        """Load the model and set up the annotators.

        Args:
            capture_index: Source passed to ``cv2.VideoCapture`` (a camera index here).
            weights: Path to the model weights. Defaults to the checkpoint this
                notebook previously hard-coded.
            model_cls: Ultralytics model class to instantiate (e.g. ``YOLO`` or
                ``RTDETR``). ``None`` means ``YOLO``.
        """
        self.capture_index = capture_index
        # Preference order: CUDA, then Apple MPS, then CPU.
        # NOTE(review): this device is only reported; predict() below does not pass it explicitly.
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')
        print("Using device:", self.device)
        if model_cls is None:
            model_cls = YOLO
        self.model = model_cls(weights)
        # Mapping {class_id: class_name}; attribute name (sic) kept for backward compatibility.
        self.CLASS_NAMES_DITC = self.model.model.names
        print("Classes: ", self.CLASS_NAMES_DITC)
        self.box_annotator = sv.BoundingBoxAnnotator()
        self.label_annotator = sv.LabelAnnotator()
        # Labels drawn for the most recent frame (set by plot_bboxes).
        self.labels = []

    def plot_bboxes(self, results, frame):
        """Draw boxes and "<class name> <confidence>" labels for the first result onto ``frame``.

        Args:
            results: Output of ``self.model.predict``; only ``results[0]`` is used.
            frame: Image to annotate (the frame that was passed to ``predict``).

        Returns:
            The annotated frame.
        """
        boxes = results[0].boxes.cpu().numpy()
        detections = sv.Detections(
            xyxy=boxes.xyxy,
            confidence=boxes.conf,
            class_id=boxes.cls.astype(np.int32),
        )

        # Build the labels *before* annotating so the label annotator actually uses them.
        self.labels = [
            f"{self.CLASS_NAMES_DITC[int(cls_id)]} {score:.2f}"
            for cls_id, score in zip(detections.class_id, detections.confidence)
        ]

        # Boxes first, then labels, so box outlines do not paint over the label text.
        frame = self.box_annotator.annotate(frame, detections)
        frame = self.label_annotator.annotate(frame, detections=detections, labels=self.labels)
        return frame

    def __call__(self, width=1280, height=720):
        """Run live detection until 'q' is pressed or no frame can be read.

        Args:
            width: Requested capture width in pixels (the camera may not honor it).
            height: Requested capture height in pixels (the camera may not honor it).

        Raises:
            AssertionError: If the capture device cannot be opened.
        """
        cap = cv2.VideoCapture(self.capture_index)
        try:
            assert cap.isOpened(), f"Failed to open camera {self.capture_index}"
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

            while cap.isOpened():
                start_time = time.perf_counter()
                ret, frame = cap.read()
                if not ret:
                    # Camera disconnected / stream ended: frame is not a valid image.
                    print("No frame received; stopping.")
                    break
                results = self.model.predict(frame)
                frame = self.plot_bboxes(results, frame)
                elapsed = time.perf_counter() - start_time
                fps = 1.0 / elapsed if elapsed > 0 else float('inf')
                print(f"FPS: {fps:.2f}")

                cv2.putText(frame, f"{fps:.2f} FPS", (20, 70), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 2)
                cv2.imshow('DETR', frame)

                # Mask to the low byte so the key comparison is platform independent.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Runs on 'q', end of stream, errors and KeyboardInterrupt alike,
            # so the camera and the window are always released.
            cap.release()
            cv2.destroyAllWindows()

    def train(self, config_file_path, epochs=1, imagesize=None, device="cpu", verbose=False):
        """Train ``self.model`` on a dataset.

        Args:
            config_file_path: Dataset config passed to Ultralytics as ``data``.
            epochs: Number of training epochs.
            imagesize: Training image size; ``None`` leaves Ultralytics' own default in place.
            device: Device passed to Ultralytics (e.g. "cpu", "mps").
            verbose: Forwarded to Ultralytics' ``verbose`` option.

        Returns:
            Whatever ``self.model.train`` returns.
        """
        # batch=-1 requests Ultralytics AutoBatch (see its training docs).
        train_kwargs = dict(data=config_file_path, epochs=epochs, device=device, batch=-1, verbose=verbose)
        # Only forward imgsz when set, so an explicit None does not replace the library default.
        if imagesize is not None:
            train_kwargs["imgsz"] = imagesize
        return self.model.train(**train_kwargs)

In [3]:
# Build the detector for OpenCV capture device 1; the cells below reuse this instance.
yolo_detector = YOLOClass(capture_index=1)

Using device: mps
Classes:  {0: 'cross', 1: 'skyward', 2: 't-pose'}


In [None]:
# Training is expensive (10 epochs), and this cell was not executed in the saved run.
# It is guarded so Restart & Run All does not start it; flip the flag to train
# before the live-detection cell.
RUN_TRAINING = False
if RUN_TRAINING:
    yolo_detector.train(config_file_path="data.yaml", epochs=10, imagesize=640, device="mps", verbose=True)

In [4]:
# Start the live detection loop; press 'q' in the 'DETR' window to stop.
yolo_detector.__call__()




0: 384x640 (no detections), 84.9ms
Speed: 2.2ms preprocess, 84.9ms inference, 0.3ms postprocess per image at shape (1, 3, 384, 640)
FPS: 4.08

0: 384x640 1 cross, 60.5ms
Speed: 1.2ms preprocess, 60.5ms inference, 0.5ms postprocess per image at shape (1, 3, 384, 640)
FPS: 11.38

0: 384x640 1 cross, 53.5ms
Speed: 1.4ms preprocess, 53.5ms inference, 0.3ms postprocess per image at shape (1, 3, 384, 640)
FPS: 13.47

0: 384x640 1 cross, 51.6ms
Speed: 1.2ms preprocess, 51.6ms inference, 0.5ms postprocess per image at shape (1, 3, 384, 640)
FPS: 13.64

0: 384x640 1 cross, 68.3ms
Speed: 1.5ms preprocess, 68.3ms inference, 0.3ms postprocess per image at shape (1, 3, 384, 640)
FPS: 10.42

0: 384x640 1 cross, 59.3ms
Speed: 1.3ms preprocess, 59.3ms inference, 0.3ms postprocess per image at shape (1, 3, 384, 640)
FPS: 14.07

0: 384x640 1 cross, 52.8ms
Speed: 1.8ms preprocess, 52.8ms inference, 0.3ms postprocess per image at shape (1, 3, 384, 640)
FPS: 12.81

0: 384x640 1 cross, 53.6ms
Speed: 1.2ms 

KeyboardInterrupt: 