# In [None]:
import os
import numpy as np
import re
import torch
from ultralytics import YOLO
from ultralytics.utils.metrics import Metric, bbox_iou, ap_per_class
from sklearn.metrics import average_precision_score

def f1(p, r):
    """Return the F1 score (harmonic mean) of precision ``p`` and recall ``r``.

    A 1e-16 term is added to the denominator so that ``p == r == 0`` yields 0
    instead of a division-by-zero error.
    """
    denominator = p + r + 1e-16
    return (2 * p * r) / denominator

def load_labels(file_path):
    """Read a whitespace-separated label file into a list of float rows.

    Returns an empty list when ``file_path`` does not exist. Blank lines are
    skipped; every other line becomes one list holding a float per token.
    """
    if not os.path.exists(file_path):
        return []
    rows = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw in handle.readlines():
            stripped = raw.strip()
            if stripped:
                rows.append([float(token) for token in stripped.split()])
    return rows

def parse_frame_number(filename, prefix):
    """Extract the frame index from a name shaped ``<prefix><digits>.txt|.jpg``.

    ``prefix`` is matched literally (regex metacharacters are escaped). Returns
    the digits as an ``int``, or ``None`` when ``filename`` has any other shape.
    """
    pattern = rf'{re.escape(prefix)}(\d+)\.(txt|jpg)$'
    found = re.match(pattern, filename)
    return int(found.group(1)) if found else None

def get_all_labels(true_folder, pred_folder, image_folder, trajectory_prefix, frame_range=None):
    """Collect ground-truth and predicted labels for one trajectory, aligned by frame.

    The set of frames comes from the ``.jpg`` files in ``image_folder`` whose
    names start with ``trajectory_prefix``. When ``frame_range`` is given as
    ``(first, last)``, only frames with ``first <= frame <= last`` are kept.
    ``.txt`` label files in ``true_folder`` and ``pred_folder`` are loaded only
    for frames in that set.

    Returns:
        ``(all_true_labels, all_pred_labels)``: two equal-length lists ordered
        by ascending frame number. A frame with no label file contributes ``[]``.
    """

    def _frame_of(name):
        return parse_frame_number(name, trajectory_prefix)

    def _in_range(frame):
        return frame_range is None or frame_range[0] <= frame <= frame_range[1]

    # 1. The frames to evaluate are defined by the trajectory's images.
    frames = set()
    for name in os.listdir(image_folder):
        if not (name.endswith('.jpg') and name.startswith(trajectory_prefix)):
            continue
        frame = _frame_of(name)
        if frame is not None and _in_range(frame):
            frames.add(frame)

    def _labels_by_frame(folder):
        # Load every .txt label file in `folder` whose frame is in `frames`.
        collected = {}
        for name in os.listdir(folder):
            if not name.endswith('.txt'):
                continue
            frame = _frame_of(name)
            if frame is not None and frame in frames:
                collected[frame] = load_labels(os.path.join(folder, name))
        return collected

    # 2. Ground-truth labels, then 3. prediction labels.
    true_by_frame = _labels_by_frame(true_folder)
    pred_by_frame = _labels_by_frame(pred_folder)

    # 4. Align both sequences on the sorted frame list.
    ordered = sorted(frames)
    return (
        [true_by_frame.get(frame, []) for frame in ordered],
        [pred_by_frame.get(frame, []) for frame in ordered],
    )

def convert_to_yolo_format(labels, is_ts=False):
    """Normalize label rows to ``[cls, x, y, w, h, conf]`` lists.

    Rows with 6 values are copied as-is. Rows with 5 values (no confidence)
    get a confidence of ``1.0`` appended. Rows of any other length are dropped.
    ``is_ts`` is accepted for call-site compatibility but is not used.
    """
    converted = []
    for row in labels:
        width = len(row)
        if width == 6:
            converted.append(list(row))
        elif width == 5:
            converted.append(list(row) + [1.0])
    return converted

def calculate_yolo_metrics(true_folder, pred_folder, image_folder, trajectory_prefix, is_ts=False, frame_range=None):
    all_true_labels, all_pred_labels = get_all_labels(true_folder, pred_folder, image_folder, trajectory_prefix, frame_range)
    true_labels = [convert_to_yolo_format(l) for l in all_true_labels]
    pred_labels = [convert_to_yolo_format(l, is_ts=is_ts) for l in all_pred_labels]

    all_tp, all_pred_scores, all_pred_cls, all_target_cls = [], [], [], []

    for true, pred in zip(true_labels, pred_labels):
        if not true and not pred:
            continue
        true_boxes = torch.tensor([l[1:5] for l in true]) if true else torch.empty((0, 4))
        pred_boxes = torch.tensor([l[1:5] for l in pred]) if pred else torch.empty((0, 4))
        pred_scores = torch.tensor([l[5] for l in pred]) if pred else torch.empty(0)
        pred_cls = torch.zeros(len(pred))
        target_cls = torch.zeros(len(true))
        tp = torch.zeros((len(pred), 10), dtype=torch.float32)
        if true_boxes.size(0) > 0 and pred_boxes.size(0) > 0:
            ious = bbox_iou(pred_boxes, true_boxes, xywh=True)
            ious_max, _ = ious.max(1)
            for i in range(10):
                tp[:, i] = ious_max > (0.5 + 0.05 * i)
        all_tp.append(tp)
        all_pred_scores.append(pred_scores)
        all_pred_cls.append(pred_cls)
        all_target_cls.append(target_cls)

    all_tp = torch.cat(all_tp, dim=0) if all_tp else torch.empty((0, 10))
    all_pred_scores = torch.cat(all_pred_scores, dim=0) if all_pred_scores else torch.empty(0)
    all_pred_cls = torch.cat(all_pred_cls, dim=0) if all_pred_cls else torch.empty(0)
    all_target_cls = torch.cat(all_target_cls, dim=0) if all_target_cls else torch.empty(0)

    if all_tp.numel() == 0:
        return {'mp': [0], 'mr': [0], 'map50': [0], 'map': 0, 'map75': 0}

    results = ap_per_class(all_tp, all_pred_scores, all_pred_cls, all_target_cls, names={})
    mean_results = results[2:6]
    
    pr_auc = 0.0
    if all_tp.size(0) > 0 and all_pred_scores.numel() > 0:
        y_true = all_tp[:, 0].numpy()
        y_scores = all_pred_scores.numpy()
        try:
            pr_auc = average_precision_score(y_true, y_scores)
        except:
            pr_auc = 0.0
    
    return {
        'mp': mean_results[0],
        'mr': mean_results[1],
        'map50': mean_results[2],
        'map': results[5].mean(),
        'map75': results[5][:, 5].mean(),
        'auc': pr_auc
    }

# ========== Set trajectory information ==========
trajectory_name = '0619-3'
trajectory_prefix = trajectory_name + '_'

# ========== Path configuration ==========
true_folder = 'datasets/UAV/labels'
yolov12_pred_folder = f'predict_results/{trajectory_name}'
yolov12_ts_pred_folder = f'filter_results/{trajectory_name}'
image_folder = f'predict_datasets/{trajectory_name}'

# ========== Set the frame number range ==========
# Inclusive (first, last) frame numbers to evaluate; None evaluates every frame.
frame_range = (2500, 16000)


def _print_metrics(title, metrics):
    """Print one model's summary from a calculate_yolo_metrics result dict."""
    precision = metrics["mp"][0]
    recall = metrics["mr"][0]
    print(f'[{title}]')
    print('Precision:', precision)
    print('Recall:', recall)
    print('mAP50:', metrics["map50"][0])
    print('mAP:', metrics["map"])
    print('F1:', f1(precision, recall))
    # .get: a result dict without an 'auc' entry prints 0.0 instead of raising KeyError.
    print('AUC (PR-AUC):', metrics.get("auc", 0.0))


# ========== YOLOv12 ==========
metrics_yolo = calculate_yolo_metrics(
    true_folder, yolov12_pred_folder, image_folder, trajectory_prefix,
    is_ts=False, frame_range=frame_range
)
_print_metrics('YOLOv12', metrics_yolo)

# ========== YOLOv12-TS ==========
metrics_yolo_ts = calculate_yolo_metrics(
    true_folder, yolov12_ts_pred_folder, image_folder, trajectory_prefix,
    is_ts=True, frame_range=frame_range
)
_print_metrics('YOLOv12-TS', metrics_yolo_ts)