In [64]:
import os
import sys
import warnings

warnings.filterwarnings('ignore')

import cv2
import mmcv
import mmengine
import numpy as np
import torch
from mmdet.apis import inference_detector, init_detector

from files import FileName
from mmpose.evaluation.functional import nms
from mmpose.utils import adapt_mmdet_pipeline
from utils import calculate_iou, get_file_name

In [65]:
# Input media. The project root defaults to the author's checkout but can be
# overridden with the GSAT_ROOT environment variable so the notebook runs elsewhere.
# NOTE(review): img_path appears unused in this notebook section.
GSAT_ROOT = os.environ.get("GSAT_ROOT", "/home/ohwada/GSAT")
img_path = f"{GSAT_ROOT}/MMPE/examples/img2.jpg"
video_path = f"{GSAT_ROOT}/video/sample3.mp4"

In [66]:
# Detector files (RTMDet-m on COCO, per the file names). GSAT_ROOT overrides the project root.
# NOTE: the names are swapped relative to their contents -- `model_path` holds the
# mmdet *config* (.py) and `config_path` holds the *checkpoint* weights (.pth).
# init_detector(config, checkpoint, ...) receives them in that positional order, so
# loading is correct; the names are kept because other cells reference them.
GSAT_ROOT = os.environ.get("GSAT_ROOT", "/home/ohwada/GSAT")
_models_dir = f"{GSAT_ROOT}/MMPE/models"
model_path = f"{_models_dir}/rtmdet_m_8xb32-300e_coco.py"
config_path = f"{_models_dir}/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth"

In [67]:
# Detector / tracker tuning knobs.
det_cat_id: int = 0    # class index kept from the detector output (COCO-trained model; 0 is "person" in COCO order)
bbox_thr: float = 0.3  # detections must score strictly above this to be kept
iou_thr: float = 0.1   # minimum IoU with the previous frame's box needed to keep following a track
nms_thr: float = 0.3   # NMS IoU threshold; NOTE(review): not used in the cells shown here

In [68]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Build the detector on `device`. Positional order is (config, checkpoint):
# despite the names, `model_path` is the .py config and `config_path` the .pth weights.
detector = init_detector(
    model_path,
    config_path,
    device=device,
)
# Adapt the mmdet config's data pipeline for use alongside mmpose
# (see mmpose.utils.adapt_mmdet_pipeline).
detector.cfg = adapt_mmdet_pipeline(detector.cfg)

Using device: cuda
Loads checkpoint by local backend from path: /home/ohwada/GSAT/MMPE/models/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth
The model and loaded state dict do not match exactly

unexpected key in source state_dict: data_preprocessor.mean, data_preprocessor.std



In [69]:
# Grab the first frame of the video for a one-off detector check. The capture is
# released immediately because only this single frame is needed here.
cap = cv2.VideoCapture(video_path)
try:
    ret, img = cap.read()
finally:
    cap.release()
# cap.read() reports failure via `ret` (img is then None); fail loudly instead of
# passing None to the detector.
if not ret:
    raise RuntimeError(f"Could not read the first frame of {video_path}")

In [70]:
# Run the detector once on the frame read from the video in the previous cell,
# to inspect its raw output before running the full tracking loop.
det_result = inference_detector(detector, img)

In [71]:
# Keep only detections of class `det_cat_id` scoring strictly above `bbox_thr`.
# Each kept row is the box coordinates with the detection score appended as the last column.
pred_instance = det_result.pred_instances.cpu().numpy()
keep_mask = (pred_instance.labels == det_cat_id) & (pred_instance.scores > bbox_thr)
bboxes = np.hstack([pred_instance.bboxes, pred_instance.scores.reshape(-1, 1)])[keep_mask]

In [72]:
# Follow one person through the whole video and write an annotated copy to output.mp4.
# On frame `select_frame` the `conf_rank`-th most confident detection is chosen; on every
# later frame the detection with the highest IoU against the previous frame's box takes
# over, and the track is dropped once that best IoU falls below `iou_thr`.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    cap.release()
    raise RuntimeError(f"Cannot open video: {video_path}")

# Match the writer to the source's frame rate and size; some containers report 0 fps,
# in which case fall back to 30.
fps = cap.get(cv2.CAP_PROP_FPS) or 30
frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
out = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)

select_frame = 1  # 1-based index of the frame on which the target is chosen
conf_rank = 1     # which detection to pick on that frame: 1 = most confident, 2 = second, ...

frame_idx = 1
tracked_box = None  # (x1, y1, x2, y2) of the tracked person, or None when not tracking
score = None        # detection score belonging to tracked_box

try:
    while True:
        ret, img = cap.read()
        if not ret:
            break

        det_result = inference_detector(detector, img)

        # Rows of box coordinates + score, restricted to `det_cat_id` above `bbox_thr`.
        pred_instance = det_result.pred_instances.cpu().numpy()
        bboxes = np.concatenate((pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
        bboxes = bboxes[np.logical_and(pred_instance.labels == det_cat_id, pred_instance.scores > bbox_thr)]

        if frame_idx == select_frame:
            if len(bboxes) >= conf_rank:
                chosen = sorted(bboxes, key=lambda b: b[-1], reverse=True)[conf_rank - 1]
                tracked_box, score = chosen[:4], chosen[-1]
            else:
                print(f"Frame {frame_idx}: only {len(bboxes)} detection(s), cannot select rank {conf_rank}")
        elif tracked_box is not None:
            # Score every candidate against the PREVIOUS frame's box. tracked_box is only
            # reassigned after the loop; updating it mid-loop would compare later
            # candidates against an earlier candidate instead.
            best_iou, best_bbox = 0, None
            for bbox in bboxes:
                iou = calculate_iou(tracked_box, bbox[:4])
                if iou > best_iou:
                    best_iou, best_bbox = iou, bbox
            if best_bbox is None or best_iou < iou_thr:
                tracked_box, score = None, None
            else:
                # Carry the matched detection's own score so the label is correct.
                tracked_box, score = best_bbox[:4], best_bbox[-1]

        if tracked_box is not None:
            x1, y1, x2, y2 = map(int, tracked_box)
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img, f"Conf: {score:.3f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        out.write(img)

        frame_idx += 1
finally:
    # Close both handles even if inference fails mid-video.
    cap.release()
    out.release()