# In [None]:
import cv2
import os
import torch
import numpy as np
from ultralytics import YOLO

# Trained weights for the segmentation model and the crop classifier
segmentation_model = YOLO("runs/segment/yolov8s_on_groundedsam_selected/weights/best.pt")
classification_model = YOLO("runs/classify/train/weights/best.pt")

# Input videos are read from the first folder; annotated videos go to the second
input_videos_dir = "../../DATA/Sampled_Test_Videos/"
output_videos_dir = "results/segmentation_classification_videos"
os.makedirs(output_videos_dir, exist_ok=True)  # no error if it already exists

# Number of frames (and of crops) handed to a model per call
batch_size = 16

# Classify a list of image crops in fixed-size chunks
def classify_objects(crops, model):
    """Run ``model`` over ``crops`` in chunks and collect every prediction.

    Args:
        crops: Sequence of inputs to classify; it is sliced into consecutive
            chunks of at most ``batch_size`` items (module-level setting).
        model: Callable invoked once per chunk as
            ``model(chunk, imgsz=224, verbose=False, device=0)``; it must
            return an iterable of per-item results.

    Returns:
        A flat list with the results of all chunks, in input order. An empty
        ``crops`` yields an empty list without calling ``model``.
    """
    collected = []
    chunk_starts = range(0, len(crops), batch_size)
    for start in chunk_starts:
        chunk = crops[start:start + batch_size]
        collected.extend(model(chunk, imgsz=224, verbose=False, device=0))
    return collected

# Process each video: segment every frame, classify each detected region,
# draw the box and label, and write the annotated frames to a new video.
def _annotate_frames(frames):
    """Segment and classify a batch of frames, drawing the results in place.

    Args:
        frames: Non-empty list of frames as returned by
            ``cv2.VideoCapture.read``. Each frame is modified in place with a
            green bounding box and a "<label> (<confidence>)" caption for
            every detection whose crop is non-empty.

    Note:
        Segmentation masks are not rendered; only boxes and labels are drawn.
    """
    # A batch is passed as a list of per-frame arrays rather than one stacked
    # 4-D array, which is the batched NumPy input form Ultralytics documents.
    results = segmentation_model(frames, device=0, imgsz=(640, 640), verbose=False)

    for frame, result in zip(frames, results):
        boxes = result.boxes
        # Frames without detections are left untouched. ``result.masks`` is
        # deliberately not read: it is None when nothing was detected.
        if boxes is None or len(boxes) == 0:
            continue

        frame_h, frame_w = frame.shape[:2]
        regions = []
        crops = []
        for box in boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            # Clamp to the frame so a negative coordinate cannot wrap around
            # when used as a slice index.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, frame_w), min(y2, frame_h)
            crop = frame[y1:y2, x1:x2]
            if crop.size > 0:
                # Keep the box next to its crop so classifications stay
                # aligned even when degenerate boxes are skipped.
                regions.append((x1, y1, x2, y2))
                crops.append(crop)

        if not crops:
            continue

        # All crops are classified before anything is drawn, so the drawing
        # below cannot leak into the classifier input.
        classifications = classify_objects(crops, classification_model)

        for cls_result, (x1, y1, x2, y2) in zip(classifications, regions):
            pred_class = cls_result.probs.top1
            pred_label = classification_model.names[pred_class]
            pred_conf = cls_result.probs.top1conf.item()

            # Overlay bounding box and classification result
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Keep the caption inside the frame for boxes touching the top edge
            label_y = max(y1 - 10, 15)
            cv2.putText(frame, f"{pred_label} ({pred_conf:.2f})", (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)


def _annotate_and_write(frames, writer):
    """Annotate ``frames`` and append them, in order, to ``writer``."""
    _annotate_frames(frames)
    for annotated in frames:
        writer.write(annotated)


for video_name in os.listdir(input_videos_dir):
    video_path = os.path.join(input_videos_dir, video_name)

    # Open video file; skip entries that are not readable videos
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        print(f"Skipping {video_name}: could not be opened as a video.")
        continue

    out = None
    try:
        # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would
        # change the duration of the output video.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also catches NaN
            fps = 30.0  # fallback when the container reports no frame rate
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Define codec and create a VideoWriter
        output_video_path = os.path.join(output_videos_dir, f"output_{video_name}")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)

            # Process in batches
            if len(frames) >= batch_size:
                _annotate_and_write(frames, out)
                frames = []  # Clear batch

        # The last, partial batch gets the same treatment as full batches
        if frames:
            _annotate_and_write(frames, out)
    finally:
        # Release resources even if inference or writing failed
        cap.release()
        if out is not None:
            out.release()

print("All videos processed and saved.")