In [1]:
import torch
import cv2
from torchvision.models.detection import ssd300_vgg16
from torchvision.transforms import functional as F

In [None]:
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): num_classes is not referenced by any cell visible here, and it
# is not passed to the model constructor below, so the pretrained detector
# keeps whatever label set its checkpoint ships with. Confirm whether
# "20 classes + background" is still the intent.
num_classes = 21

# Build SSD300-VGG16 with pretrained weights and place it on the chosen device.
# NOTE(review): newer torchvision releases replace `pretrained=` with a
# `weights=` argument -- verify against the installed version.
model = ssd300_vgg16(pretrained=True).to(device)

# Switch to evaluation mode; as the last expression this is also the value the
# cell displays.
model.eval()

In [3]:
def detect_objects(frame, model, device, threshold=0.5):
    """Run the detector on one frame and draw the confident detections on it.

    Args:
        frame: Image array in OpenCV's BGR channel order. It is annotated
            IN PLACE; pass ``frame.copy()`` to keep the original pixels.
        model: Callable that takes a batched image tensor and returns one dict
            per image with the keys ``'boxes'``, ``'labels'`` and ``'scores'``.
        device: Torch device the input tensor is moved to before inference.
        threshold: Only detections whose score is strictly greater than this
            value are drawn.

    Returns:
        The same ``frame`` object, with a green rectangle and a
        ``"<label id>: <score>"`` caption drawn for every kept detection.
    """
    # OpenCV delivers BGR; the tensor conversion below is fed RGB.
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Add a leading batch dimension and move the tensor to the target device.
    frame_tensor = F.to_tensor(frame_rgb).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(frame_tensor)[0]

    # Keep only detections above the score threshold.
    keep = prediction['scores'] > threshold

    # Convert each result tensor to plain Python values once, instead of doing
    # a device-to-host transfer for every box inside the drawing loop.
    # .int() truncates toward zero, matching int() on the float coordinates,
    # and .tolist() yields built-in ints/floats that OpenCV accepts directly.
    boxes = prediction['boxes'][keep].int().tolist()
    label_ids = prediction['labels'][keep].tolist()
    scores = prediction['scores'][keep].tolist()

    green = (0, 255, 0)
    for (xmin, ymin, xmax, ymax), label_id, score in zip(boxes, label_ids, scores):
        cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), green, 2)

        # Caption sits just above the box; clamp so it stays inside the frame
        # when the box touches the top edge.
        text_y = max(ymin - 10, 15)
        caption = f'{label_id}: {score:.2f}'
        cv2.putText(frame, caption, (xmin, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, green, 2)

    return frame

In [None]:
import cv2

# Set the path to your video
video_path = "sample_video.mp4"
cap = cv2.VideoCapture(video_path)

# Fail loudly if the video cannot be opened. Raising (instead of calling
# exit()) stops this cell without shutting down the notebook kernel.
if not cap.isOpened():
    cap.release()
    raise RuntimeError(f"Error: Unable to open video file: {video_path!r}")

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # cap.read() returns ret == False both at the normal end of the
            # video and on a read failure; either way there is nothing more
            # to process.
            print("End of video stream (or frame could not be read).")
            break

        # Detect objects in the frame
        output_frame = detect_objects(frame, model, device, threshold=0.5)

        # Display the frame with detected objects
        cv2.imshow("Real-Time Object Detection", output_frame)

        # Break on pressing 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always release the capture and close the window, even if inference
    # raised or the kernel was interrupted mid-loop.
    cap.release()
    cv2.destroyAllWindows()
