# 最大誤差（理論上の上限）は 512·√2 ≈ 724.1 ピクセル（一辺 512 px の画像の対角線長）

# In[ ]:
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import glob
import pandas as pd
import seaborn as sns

def parse_yolo_label(label_path):
    """Parse a YOLO-pose label file.

    Each line is read as ``class_id cx cy w h [x1 y1 v1 x2 y2 v2 ...]``.
    Lines with fewer than five fields (including blank lines) are skipped, and
    a trailing incomplete keypoint triple is ignored. Coordinates are returned
    exactly as written (no range checks are performed).

    Args:
        label_path: Path to the ``.txt`` label file.

    Returns:
        Tuple ``(class_ids, bboxes, keypoints_list)`` of equal-length lists:
        ``class_ids`` holds ints, ``bboxes`` holds ``[cx, cy, w, h]`` float
        lists, and ``keypoints_list`` holds one float array of shape
        ``(K, 3)`` per object -- ``(0, 3)`` when the line has no keypoints.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a field is not numeric.
    """
    class_ids = []
    bboxes = []
    keypoints_list = []

    with open(label_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue

            class_ids.append(int(parts[0]))
            bboxes.append([float(x) for x in parts[1:5]])

            # Only complete (x, y, visibility) triples after the box are used.
            n_triples = (len(parts) - 5) // 3
            keypoints = [
                # Visibility may be written as "2" or "2.000000"; accept both.
                [float(parts[i]), float(parts[i + 1]), int(float(parts[i + 2]))]
                for i in range(5, 5 + 3 * n_triples, 3)
            ]
            # reshape guarantees a 2-D (K, 3) array even when K == 0, so callers
            # can always index the visibility column with [:, 2].
            keypoints_list.append(np.array(keypoints, dtype=float).reshape(-1, 3))

    return class_ids, bboxes, keypoints_list

def calculate_keypoint_errors(pred_keypoints, gt_keypoints, visible_threshold=0):
    """Compute per-keypoint Euclidean distances between prediction and ground truth.

    A ground-truth keypoint takes part when its third column (visibility) is
    strictly greater than ``visible_threshold``. Visible keypoints whose index
    is beyond the end of ``pred_keypoints`` are skipped.

    Args:
        pred_keypoints: array indexed as ``[k, :2]`` for the predicted (x, y).
        gt_keypoints: array indexed as ``[k, :2]`` for (x, y) and ``[k, 2]``
            for visibility; both arrays must use the same coordinate units.
        visible_threshold: visibility values above this count as visible.

    Returns:
        Tuple ``(errors, n_visible)`` where ``errors`` is a list of
        ``{'keypoint_idx': k, 'error': distance}`` dicts and ``n_visible`` is
        the number of visible ground-truth keypoints (matched or not).
    """
    visible_idx = np.where(gt_keypoints[:, 2] > visible_threshold)[0]
    n_pred = len(pred_keypoints)

    errors = [
        {
            'keypoint_idx': k,
            # Plain L2 distance over the (x, y) columns.
            'error': np.sqrt(np.sum((pred_keypoints[k, :2] - gt_keypoints[k, :2]) ** 2)),
        }
        for k in visible_idx
        if k < n_pred
    ]
    return errors, len(visible_idx)

def _box_iou(box_a, box_b):
    """Return the IoU of two ``(x1, y1, x2, y2)`` boxes, or 0.0 when they do not overlap."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    inter_x1 = max(ax1, bx1)
    inter_y1 = max(ay1, by1)
    inter_x2 = min(ax2, bx2)
    inter_y2 = min(ay2, by2)
    if inter_x1 >= inter_x2 or inter_y1 >= inter_y2:
        return 0.0

    inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    return inter_area / (area_a + area_b - inter_area)


def _best_matching_prediction(gt_box, pred_boxes):
    """Return ``(index, iou)`` of the box in ``pred_boxes`` with the highest IoU
    against ``gt_box``, or ``(-1, 0)`` when no prediction overlaps it."""
    best_idx, best_iou = -1, 0
    for pred_idx, pred_box in enumerate(pred_boxes):
        iou = _box_iou(gt_box, pred_box)
        if iou > best_iou:
            best_idx, best_iou = pred_idx, iou
    return best_idx, best_iou


def evaluate_on_dataset(model, dataset_path, img_size=640, iou_threshold=0.5):
    """Run ``model`` on every labelled image of a split and collect keypoint errors.

    Expects ``dataset_path/images`` (``*.jpg``, ``*.jpeg``, ``*.png``) and
    ``dataset_path/labels`` (YOLO-pose ``.txt`` files with the same stem).
    Images without a label file, unreadable images, unparsable label files
    (reported with ``print``) and empty label files are skipped.

    Each ground-truth object is matched to the predicted box with the highest
    IoU; the match is accepted only when that IoU is strictly greater than
    ``iou_threshold``. For an accepted match, the pixel distance of every
    visible GT keypoint (visibility > 0) that also exists in the prediction is
    recorded. Matching is done independently per GT object, so one prediction
    can be matched to several GT objects.

    Args:
        model: Ultralytics YOLO pose model, called as ``model(img, verbose=False)``.
        dataset_path: split directory containing ``images/`` and ``labels/``.
        img_size: NOTE(review): currently unused -- it is not forwarded to the
            model call. Confirm intent before wiring it through (e.g. as
            ``imgsz=img_size``), since that can change the predictions.
        iou_threshold: exclusive minimum IoU for a GT/prediction match
            (default 0.5, the previously hard-coded value).

    Returns:
        Tuple ``(all_keypoint_errors, keypoint_errors_by_class, details_df,
        keypoint_stats)``:

        * ``all_keypoint_errors`` -- list of dicts with ``keypoint_idx``,
          ``error`` (pixels), ``image``, ``gt_class``, ``pred_class`` and ``iou``.
        * ``keypoint_errors_by_class`` -- the same dicts grouped by GT class id
          (always contains keys 0 and 1, plus any other class id seen).
        * ``details_df`` -- ``pandas.DataFrame`` built from the same dicts.
        * ``keypoint_stats`` -- dict with ``total_gt_keypoints``,
          ``total_visible_gt_keypoints``, ``total_detected_keypoints`` and
          ``class_stats`` (class id -> ``{'total', 'visible', 'detected'}``).
    """
    images_path = os.path.join(dataset_path, 'images')
    labels_path = os.path.join(dataset_path, 'labels')

    all_keypoint_errors = []
    keypoint_errors_by_class = {0: [], 1: []}
    results_details = []

    total_gt_keypoints = 0
    total_visible_gt_keypoints = 0
    total_detected_keypoints = 0
    class_keypoint_stats = {0: {'total': 0, 'visible': 0, 'detected': 0},
                            1: {'total': 0, 'visible': 0, 'detected': 0}}

    image_files = []
    for pattern in ('*.jpg', '*.jpeg', '*.png'):
        image_files.extend(glob.glob(os.path.join(images_path, pattern)))

    for img_path in tqdm(image_files):
        base_name = os.path.basename(img_path)
        name_without_ext = os.path.splitext(base_name)[0]
        label_file = os.path.join(labels_path, name_without_ext + '.txt')
        if not os.path.exists(label_file):
            continue

        img = cv2.imread(img_path)
        if img is None:
            continue
        height, width = img.shape[:2]

        try:
            gt_classes, gt_bboxes, gt_keypoints_list = parse_yolo_label(label_file)
        except Exception as e:
            print(f"Error parsing label file {label_file}: {e}")
            continue
        if not gt_classes:
            continue

        result = model(img, verbose=False)[0]
        # Predictions are only considered when the result carries keypoints.
        # Convert the boxes to numpy once per image, not once per GT object.
        if result.keypoints is not None and len(result.keypoints.data) > 0:
            pred_boxes = result.boxes.xyxy.cpu().numpy()
        else:
            pred_boxes = np.empty((0, 4))

        for gt_class, gt_bbox, gt_keypoints in zip(gt_classes, gt_bboxes, gt_keypoints_list):
            # Class ids outside the pre-seeded {0, 1} get fresh buckets
            # instead of raising KeyError.
            cls_stats = class_keypoint_stats.setdefault(
                gt_class, {'total': 0, 'visible': 0, 'detected': 0})
            cls_errors = keypoint_errors_by_class.setdefault(gt_class, [])

            # Guarantee a (K, 3) array so the visibility column can be indexed
            # even when the label line carried no keypoints.
            gt_keypoints = np.asarray(gt_keypoints, dtype=float).reshape(-1, 3)
            visible = gt_keypoints[:, 2] > 0
            n_visible = int(visible.sum())

            total_gt_keypoints += len(gt_keypoints)
            cls_stats['total'] += len(gt_keypoints)
            total_visible_gt_keypoints += n_visible
            cls_stats['visible'] += n_visible

            # Normalised YOLO (cx, cy, w, h) -> pixel (x1, y1, x2, y2).
            cx, cy, bw, bh = gt_bbox
            cx *= width
            cy *= height
            bw *= width
            bh *= height
            gt_box = (cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2)

            # Scale visible keypoints to pixels; invisible ones are never compared.
            gt_keypoints[visible, 0] *= width
            gt_keypoints[visible, 1] *= height

            best_match_idx, best_iou = _best_matching_prediction(gt_box, pred_boxes)
            if best_match_idx < 0 or best_iou <= iou_threshold:
                continue

            pred_class = int(result.boxes.cls[best_match_idx].item())
            pred_keypoints = result.keypoints.data[best_match_idx].cpu().numpy()
            errors, _ = calculate_keypoint_errors(pred_keypoints, gt_keypoints)

            total_detected_keypoints += len(errors)
            cls_stats['detected'] += len(errors)

            for error_info in errors:
                error_info.update(image=base_name, gt_class=gt_class,
                                  pred_class=pred_class, iou=best_iou)
                all_keypoint_errors.append(error_info)
                cls_errors.append(error_info)
                results_details.append(error_info)

    details_df = pd.DataFrame(results_details)

    keypoint_stats = {
        'total_gt_keypoints': total_gt_keypoints,
        'total_visible_gt_keypoints': total_visible_gt_keypoints,
        'total_detected_keypoints': total_detected_keypoints,
        'class_stats': class_keypoint_stats,
    }

    return all_keypoint_errors, keypoint_errors_by_class, details_df, keypoint_stats

def _class_label(class_names, cls):
    """Return the display name for ``cls``, or ``str(cls)`` when ``class_names`` has no entry for it."""
    try:
        return class_names[cls]
    except (KeyError, IndexError):
        return str(cls)


def _summarize_errors(values):
    """Return mean/median/std/min/max/count of ``values``; the statistics are NaN (count 0) when empty."""
    if not values:
        nan = float('nan')
        return {'mean': nan, 'median': nan, 'std': nan, 'min': nan, 'max': nan, 'count': 0}
    return {
        'mean': np.mean(values),
        'median': np.median(values),
        'std': np.std(values),
        'min': np.min(values),
        'max': np.max(values),
        'count': len(values),
    }


def visualize_error_distribution(all_errors, class_errors, class_names):
    """Plot keypoint-error distributions and return summary statistics.

    Draws a 2x2 figure (overall histogram, per-class histograms, per-keypoint
    box plot, per-class empirical CDF). The figure is saved to
    ``keypoint_error_distribution.png`` in the current directory and shown.
    Panels with no data are left empty instead of being drawn.

    Args:
        all_errors: list of error dicts, each with an ``'error'`` value (pixels).
        class_errors: mapping ``class_id -> list of error dicts`` (each with
            ``'error'`` and ``'keypoint_idx'``).
        class_names: ``class_id -> display name`` lookup; ids it does not
            contain are labelled ``str(class_id)``.

    Returns:
        dict with key ``'overall'`` and one key per non-empty class (its display
        name). Each value is a dict of ``mean``, ``median``, ``std``, ``min``,
        ``max`` and ``count``. With no errors at all, ``'overall'`` has NaN
        statistics and count 0 rather than raising.
    """
    overall_values = [e['error'] for e in all_errors]
    per_class_values = {
        cls: [e['error'] for e in errors]
        for cls, errors in class_errors.items()
        if errors
    }

    plt.figure(figsize=(15, 12))

    # Panel 1: histogram of every error.
    plt.subplot(2, 2, 1)
    if overall_values:
        sns.histplot(overall_values, kde=True)
    plt.xlabel('Error (pixels)')
    plt.ylabel('Frequency')
    plt.title('Overall Keypoint Error Distribution')
    plt.grid(True)

    # Panel 2: one histogram per class, overlaid.
    plt.subplot(2, 2, 2)
    for cls, values in per_class_values.items():
        sns.histplot(values, kde=True, label=f'Class {_class_label(class_names, cls)}')
    plt.xlabel('Error (pixels)')
    plt.ylabel('Frequency')
    plt.title('Keypoint Error Distribution by Class')
    plt.grid(True)
    if per_class_values:  # a legend with no labelled artists only emits a warning
        plt.legend()

    # Panel 3: error spread per keypoint index (all classes pooled).
    plt.subplot(2, 2, 3)
    error_df = pd.DataFrame(
        [
            {'Error': e['error'], 'Class': _class_label(class_names, cls), 'Keypoint': e['keypoint_idx']}
            for cls, errors in class_errors.items()
            for e in errors
        ],
        columns=['Error', 'Class', 'Keypoint'],
    )
    if not error_df.empty:
        sns.boxplot(x='Keypoint', y='Error', data=error_df)
    plt.xlabel('Keypoint Index')
    plt.ylabel('Error (pixels)')
    plt.title('Error Distribution by Keypoint Index')
    plt.grid(True)

    # Panel 4: empirical CDF per class.
    plt.subplot(2, 2, 4)
    for cls, values in per_class_values.items():
        sorted_errors = np.sort(values)
        n = len(sorted_errors)
        # (i + 1) / n is the standard empirical CDF: it reaches 1.0 at the
        # largest error and, unlike i / (n - 1), does not divide by zero
        # when a class has a single error.
        cdf = np.arange(1, n + 1) / n
        plt.plot(sorted_errors, cdf, label=f'Class {_class_label(class_names, cls)}')
    plt.xlabel('Error (pixels)')
    plt.ylabel('Cumulative Probability')
    plt.title('Cumulative Distribution of Keypoint Errors')
    plt.grid(True)
    if per_class_values:
        plt.legend()

    plt.tight_layout()
    plt.savefig('keypoint_error_distribution.png', dpi=300)
    plt.show()

    stats = {'overall': _summarize_errors(overall_values)}
    for cls, values in per_class_values.items():
        stats[_class_label(class_names, cls)] = _summarize_errors(values)
    return stats

def visualize_keypoint_detection_stats(keypoint_stats, class_names):
    """Plot keypoint counts and detection rates, then save and show the figure.

    Three panels are drawn on a 2x2 grid:
    overall Total/Visible/Detected counts, the same counts per class, and the
    detection rate (detected / visible, in percent) overall and per class.
    The figure is saved to ``keypoint_detection_stats.png`` in the current
    directory.

    Args:
        keypoint_stats: dict with ``total_gt_keypoints``,
            ``total_visible_gt_keypoints``, ``total_detected_keypoints`` and
            ``class_stats`` (class id -> ``{'total', 'visible', 'detected'}``).
        class_names: dict mapping class id -> display name; every key must be
            present in ``keypoint_stats['class_stats']``.

    Returns:
        dict with ``'overall_detection_rate'`` (percent) and
        ``'class_detection_rates'`` mapping display name -> rate (percent).
    """
    def pct(detected, visible):
        # max(..., 1) avoids dividing by zero when nothing was visible.
        return detected / max(visible, 1) * 100

    plt.figure(figsize=(15, 10))

    # --- Panel 1: dataset-wide counts ---
    plt.subplot(2, 2, 1)
    overall_counts = [
        keypoint_stats[key]
        for key in ('total_gt_keypoints', 'total_visible_gt_keypoints', 'total_detected_keypoints')
    ]
    plt.bar(['Total', 'Visible', 'Detected'], overall_counts)
    for pos, count in enumerate(overall_counts):
        plt.text(pos, count + 5, str(count), ha='center')
    plt.title('Overall Keypoint Statistics')
    plt.ylabel('Number of Keypoints')
    plt.grid(True, axis='y')

    # --- Panel 2: grouped bars per class ---
    plt.subplot(2, 2, 2)
    class_ids = sorted(class_names.keys())
    labels = [class_names[c] for c in class_ids]
    per_class = keypoint_stats['class_stats']
    positions = np.arange(len(labels))
    bar_w = 0.25

    series = [
        (-bar_w, [per_class[c]['total'] for c in class_ids], 'Total'),
        (0, [per_class[c]['visible'] for c in class_ids], 'Visible'),
        (bar_w, [per_class[c]['detected'] for c in class_ids], 'Detected'),
    ]
    for offset, heights, legend_label in series:
        plt.bar(positions + offset, heights, bar_w, label=legend_label)

    plt.title('Keypoint Statistics by Class')
    plt.ylabel('Number of Keypoints')
    plt.xticks(positions, labels)
    plt.legend()
    plt.grid(True, axis='y')

    # --- Panel 3: detection rates ---
    plt.subplot(2, 2, 3)
    overall_rate = pct(keypoint_stats['total_detected_keypoints'],
                       keypoint_stats['total_visible_gt_keypoints'])
    per_class_rates = [pct(per_class[c]['detected'], per_class[c]['visible']) for c in class_ids]

    rates = [overall_rate, *per_class_rates]
    plt.bar(['Overall', *labels], rates)
    for pos, value in enumerate(rates):
        plt.text(pos, value + 1, f'{value:.1f}%', ha='center')
    plt.title('Keypoint Detection Rate')
    plt.ylabel('Detection Rate (%)')
    plt.ylim(0, 110)
    plt.grid(True, axis='y')

    plt.tight_layout()
    plt.savefig('keypoint_detection_stats.png', dpi=300)
    plt.show()

    return {
        'overall_detection_rate': overall_rate,
        'class_detection_rates': dict(zip(labels, per_class_rates)),
    }

# Default checkpoint: the model used for the national-conference results.
DEFAULT_MODEL_PATH = '../../AI_aug_gen/__output__/06_YOLOv8_Pose/1211_leafBlock/yolov8n-pose/20241211-02304455_yolov8n-pose_1211_lb_500_n550_kpt/weights/best.pt'
# Alternative checkpoints previously evaluated (pass as ``main(model_path=...)``):
#   WiNF:
#     '../../AI_aug_gen/__output__/06_YOLOv8_Pose/cmp_cn_multicn/yolov8n-pose/20240726-16045940_yolov8n-pose_0719_multi_cn_n550_kpt/weights/best.pt'
#   Prior work 2:
#     '../../AI_aug_gen/__output__/06_YOLOv8_Pose/cmp_cn_multicn/yolov8n-pose/20240726-06285899_yolov8n-pose_0719_cn_n550_kpt/weights/best.pt'
DEFAULT_DATASET_PATH = '../../AI_aug_gen/__dataset__/06_YOLOv8_Pose/1211_leafBlock_500/1211_lb_500_n550_kpt/test/'
DEFAULT_CLASS_NAMES = {0: "クラス0", 1: "grape_cane"}


def _detection_rate(detected, visible):
    """Return detected / visible as a percentage; visible == 0 is treated as 1 to avoid division by zero."""
    return detected / max(visible, 1) * 100


def _print_error_stats(title, stats):
    """Print one block of keypoint-error summary statistics (values in pixels)."""
    print(f"{title}:")
    print(f"  Mean Error: {stats['mean']:.2f} pixels")
    print(f"  Median Error: {stats['median']:.2f} pixels")
    print(f"  Std Deviation: {stats['std']:.2f} pixels")
    print(f"  Min Error: {stats['min']:.2f} pixels")
    print(f"  Max Error: {stats['max']:.2f} pixels")
    print(f"  Count: {stats['count']} keypoints")


def main(model_path=DEFAULT_MODEL_PATH, dataset_path=DEFAULT_DATASET_PATH, class_names=None):
    """Evaluate a YOLO pose checkpoint on a dataset split and report keypoint errors.

    Runs ``evaluate_on_dataset`` and the two plotting helpers, prints detection
    and error statistics, and writes ``keypoint_errors_details.csv``,
    ``keypoint_error_stats.csv`` and ``keypoint_detection_stats.csv`` to the
    current directory.

    Args:
        model_path: path to the ``.pt`` checkpoint loaded with ``YOLO``.
        dataset_path: split directory containing ``images/`` and ``labels/``.
        class_names: dict mapping class id -> display name; defaults to a copy
            of ``DEFAULT_CLASS_NAMES``. Class ids missing from it are printed
            by their numeric id.
    """
    if class_names is None:
        class_names = dict(DEFAULT_CLASS_NAMES)

    model = YOLO(model_path)

    print('Evaluating model and calculating keypoint errors...')
    all_errors, class_errors, details_df, keypoint_stats = evaluate_on_dataset(model, dataset_path)
    details_df.to_csv('keypoint_errors_details.csv', index=False)
    print('Visualizing error distributions...')
    error_stats = visualize_error_distribution(all_errors, class_errors, class_names)
    print('Visualizing keypoint detection statistics...')
    detection_stats = visualize_keypoint_detection_stats(keypoint_stats, class_names)

    print("\n=== Keypoint Detection Statistics ===")
    print(f"Total ground truth keypoints: {keypoint_stats['total_gt_keypoints']}")
    print(f"Total visible ground truth keypoints: {keypoint_stats['total_visible_gt_keypoints']}")
    print(f"Total detected keypoints: {keypoint_stats['total_detected_keypoints']}")
    print(f"Overall detection rate: {detection_stats['overall_detection_rate']:.2f}%")

    print("\nClass-wise keypoint statistics:")
    for cls, stats in keypoint_stats['class_stats'].items():
        print(f"  {class_names.get(cls, cls)}:")
        print(f"    Total: {stats['total']}")
        print(f"    Visible: {stats['visible']}")
        print(f"    Detected: {stats['detected']}")
        print(f"    Detection rate: {_detection_rate(stats['detected'], stats['visible']):.2f}%")

    print("\n=== Keypoint Error Statistics ===")
    for category, stats in error_stats.items():
        if category == 'overall':
            _print_error_stats('Overall', stats)
        else:
            print()  # blank line between per-class blocks
            _print_error_stats(category, stats)

    error_stats_rows = [{'Category': category, **values} for category, values in error_stats.items()]
    pd.DataFrame(error_stats_rows).to_csv('keypoint_error_stats.csv', index=False)

    detection_stats_rows = [
        {
            'Category': 'Overall',
            'Total': keypoint_stats['total_gt_keypoints'],
            'Visible': keypoint_stats['total_visible_gt_keypoints'],
            'Detected': keypoint_stats['total_detected_keypoints'],
            'Detection_Rate': detection_stats['overall_detection_rate'],
        }
    ]
    class_stats = keypoint_stats['class_stats']
    for cls in sorted(class_names.keys()):
        # A class listed in class_names but never seen gets zero counts.
        cls_stats = class_stats.get(cls, {'total': 0, 'visible': 0, 'detected': 0})
        detection_stats_rows.append({
            'Category': class_names[cls],
            'Total': cls_stats['total'],
            'Visible': cls_stats['visible'],
            'Detected': cls_stats['detected'],
            'Detection_Rate': _detection_rate(cls_stats['detected'], cls_stats['visible']),
        })
    pd.DataFrame(detection_stats_rows).to_csv('keypoint_detection_stats.csv', index=False)

    print("\nEvaluation complete. Results saved to CSV files and plots.")

# Run the full evaluation only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()