In [None]:
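# One-time setup. Uncomment to run inside the notebook. `%cd` is required because
# `!cd` runs in a throwaway subshell and the directory change would not persist.
# %pip install torch torchvision
# !git clone https://github.com/ultralytics/yolov5
# %cd yolov5
# %pip install -r requirements.txt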


In [None]:
# Download the YOLOv5n weights into the current working directory (uncomment to run):
# !wget https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n.pt


In [None]:
# !python3 detect.py --weights yolov5n.pt --source 0  # Run on a USB or Pi camera

In [None]:
import torch
import cv2
from numpy import random

# Load the YOLOv5n ("nano") checkpoint through PyTorch Hub. The 'custom' entry
# point builds the model from the weights file named by `path`.
model = torch.hub.load(repo_or_dir='ultralytics/yolov5', model='custom', path='yolov5n.pt')

# COCO class IDs to keep. Edit this list to track different objects.
# The names below follow the 80-class COCO ordering; `model.names` is the
# authoritative ID -> name mapping for the loaded weights.
target_classes = [
    0,   # person
    1,   # bicycle
    2,   # car
    3,   # motorcycle
    5,   # bus
    7,   # truck
    9,   # traffic light
    10,  # fire hydrant
    12,  # parking meter
    15,  # cat
    16,  # dog
    20,  # elephant
    23,  # giraffe
    25,  # umbrella
    27,  # tie
    30,  # skis
    32,  # sports ball
    35,  # baseball glove
    38,  # tennis racket
    40,  # wine glass
]

# Function to filter detection results by target class IDs
def filter_classes(detections, target_classes):
    """Keep only the detections whose class ID is in ``target_classes``.

    Args:
        detections: Iterable of detection rows. Each row must be indexable so that
            ``row[5]`` is the class ID, e.g. YOLOv5 ``results.xyxy[0]`` rows laid
            out as (x1, y1, x2, y2, conf, cls).
        target_classes: Iterable of integer class IDs to keep. Any iterable works,
            including a one-shot iterator, because it is read exactly once.

    Returns:
        A new list of the matching rows, in their original order.
    """
    # Build the lookup set once. Each membership test is then O(1), instead of a
    # linear scan of the target list for every detection.
    wanted = set(target_classes)
    return [det for det in detections if int(det[5]) in wanted]


cap = cv2.VideoCapture(0)  # Use 0 for the default camera, or a path to a video file

# One colour per class ID, cached on first use. Seeding the generator with the
# class ID keeps a class's colour identical across frames and across runs,
# instead of re-randomising it every frame (which made the boxes flicker).
class_colors = {}


def _class_color(class_id):
    """Return a deterministic (B, G, R) colour tuple of Python ints for ``class_id``."""
    if class_id not in class_colors:
        rng = random.default_rng(class_id)
        # cv2 drawing functions get plain ints; `integers` excludes the upper
        # bound, so 256 gives the full 0..255 range.
        class_colors[class_id] = tuple(int(c) for c in rng.integers(0, 256, size=3))
    return class_colors[class_id]


try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break  # end of video file, or the camera stopped delivering frames

        # OpenCV frames are BGR. YOLOv5's hub model expects RGB for numpy input,
        # so reverse the channel axis for inference and keep drawing on `frame`.
        results = model(frame[..., ::-1])

        # One row per box: x1, y1, x2, y2, conf, cls. The hub model may run on a
        # GPU, and `.numpy()` fails on CUDA tensors, so move the tensor to the CPU first.
        detections = results.xyxy[0].cpu().numpy()

        # Keep only the classes listed in `target_classes`.
        filtered_detections = filter_classes(detections, target_classes)

        # Draw each remaining box and its "name confidence" label.
        for det in filtered_detections:
            x1, y1, x2, y2, conf, cls = det
            class_id = int(cls)
            color = _class_color(class_id)
            label = f"{model.names[class_id]} {conf:.2f}"
            cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            # Clamp the text baseline so labels of boxes touching the top edge stay visible.
            text_origin = (int(x1), max(int(y1) - 10, 10))
            cv2.putText(frame, label, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        # Show the annotated frame in a native OpenCV window (needs a local display).
        cv2.imshow('Filtered Detections', frame)

        # Poll the GUI for 1 ms; pressing 'q' in the window ends the loop.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Runs even on an exception or kernel interrupt, so the camera and window are released.
    cap.release()
    cv2.destroyAllWindows()
