# In [None]:
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import glob
import pandas as pd

def parse_yolo_label(label_path):
    """Parse a YOLO-format pose label file.

    Every line with at least five whitespace-separated fields is read as
    ``class_id bx by bw bh [kx ky kv ...]``: an integer class id, four
    bounding-box floats, then zero or more keypoint triples. Lines with
    fewer than five fields (including blank lines) are skipped, and trailing
    fields that do not form a complete triple are ignored. Values are
    returned exactly as stored in the file; no scaling is applied here.

    Args:
        label_path: Path of the label text file.

    Returns:
        A tuple ``(class_ids, bboxes, keypoints_list)`` of equal-length lists:
        ``class_ids`` holds ints, ``bboxes`` holds 4-element float lists, and
        ``keypoints_list`` holds float arrays of shape ``(K, 3)``. ``K`` is 0
        for an instance without keypoints, so callers can always index
        ``[:, 2]`` safely.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be converted to a number.
    """
    class_ids = []
    bboxes = []
    keypoints_list = []

    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue

            class_id = int(parts[0])
            bbox = [float(value) for value in parts[1:5]]

            kp_fields = parts[5:]
            # The visibility flag is parsed as float: the array is float
            # anyway, and this accepts flags written as "2.0" or "1.000000".
            keypoints = [
                [float(kp_fields[k]), float(kp_fields[k + 1]), float(kp_fields[k + 2])]
                for k in range(0, len(kp_fields) - 2, 3)
            ]

            class_ids.append(class_id)
            bboxes.append(bbox)
            # reshape keeps the (K, 3) layout even when K == 0; a bare
            # np.array([]) would be 1-D and break column indexing downstream.
            keypoints_list.append(np.array(keypoints, dtype=float).reshape(-1, 3))

    return class_ids, bboxes, keypoints_list

def calculate_pcp(pred_keypoints, gt_keypoints, visible_threshold=0, threshold=0.5):
    """Score predicted keypoints against ground truth with a PCP-style measure.

    A "part" is the segment between two consecutive visible ground-truth
    keypoints (visible means column 2 is greater than ``visible_threshold``).
    A part counts as correct when, for both of its endpoints, the prediction
    error divided by the largest distance between any two visible
    ground-truth keypoints is at most ``threshold`` and the raw error is at
    most half the ground-truth part length.

    Args:
        pred_keypoints: 2-D array of predictions; columns 0-1 are x and y.
        gt_keypoints: 2-D array of ground truth; columns 0-1 are x and y,
            column 2 is the visibility value.
        visible_threshold: Visibility must exceed this to be used.
        threshold: Upper bound on the scale-normalised endpoint error.

    Returns:
        ``(pcp, correct_parts, valid_parts)``. Parts whose endpoints have no
        matching row in ``pred_keypoints`` stay in ``valid_parts`` but are
        never correct. ``(0, 0, 0)`` is returned when fewer than two
        keypoints are visible or all visible keypoints coincide.
    """
    def _distance(point_a, point_b):
        return np.sqrt(((point_a - point_b) ** 2).sum())

    nothing_to_score = (0, 0, 0)

    visible = np.where(gt_keypoints[:, 2] > visible_threshold)[0]
    if len(visible) < 2:
        return nothing_to_score

    # Consecutive visible keypoints form the parts to be checked.
    parts = list(zip(visible[:-1], visible[1:]))
    if not parts:
        return nothing_to_score

    # Normalisation scale: widest spread among the visible GT keypoints.
    scale = 0
    for pos, first in enumerate(visible):
        for second in visible[pos + 1:]:
            scale = max(scale, _distance(gt_keypoints[first, :2], gt_keypoints[second, :2]))
    if scale == 0:
        return nothing_to_score

    n_pred = len(pred_keypoints)
    hits = 0
    for start, end in parts:
        if start >= n_pred or end >= n_pred:
            # No prediction for this endpoint: the part stays valid but wrong.
            continue

        err_start = _distance(pred_keypoints[start, :2], gt_keypoints[start, :2])
        err_end = _distance(pred_keypoints[end, :2], gt_keypoints[end, :2])
        half_length = _distance(gt_keypoints[start, :2], gt_keypoints[end, :2]) / 2

        within_scale = err_start / scale <= threshold and err_end / scale <= threshold
        within_part = err_start <= half_length and err_end <= half_length
        if within_scale and within_part:
            hits += 1

    total = len(parts)
    return hits / max(total, 1), hits, total

# def calculate_pcp(pred_keypoints, gt_keypoints, visible_threshold=0, threshold=0.5):
#     visible_indices = np.where(gt_keypoints[:, 2] > visible_threshold)[0]
    
#     if len(visible_indices) < 2:
#         return 0, 0, 0 

#     keypoint_pairs = []
#     for i in range(len(visible_indices)-1):
#         keypoint_pairs.append((visible_indices[i], visible_indices[i+1]))
    
#     if not keypoint_pairs:
#         return 0, 0, 0  
#     max_dist = 0
#     for i in range(len(visible_indices)):
#         for j in range(i+1, len(visible_indices)):
#             idx1, idx2 = visible_indices[i], visible_indices[j]
#             dist = np.sqrt(((gt_keypoints[idx1, :2] - gt_keypoints[idx2, :2]) ** 2).sum())
#             max_dist = max(max_dist, dist)
    
#     if max_dist == 0:
#         return 0, 0, 0  
#     correct_parts = 0
#     valid_parts = len(keypoint_pairs)  
#     for joint1, joint2 in keypoint_pairs:
#         if joint1 >= len(pred_keypoints) or joint2 >= len(pred_keypoints):
#             continue
            
#         dist1 = np.sqrt(((pred_keypoints[joint1, :2] - gt_keypoints[joint1, :2]) ** 2).sum())
#         dist2 = np.sqrt(((pred_keypoints[joint2, :2] - gt_keypoints[joint2, :2]) ** 2).sum())
      
#         norm_dist1 = dist1 / max_dist
#         norm_dist2 = dist2 / max_dist
        
#         if norm_dist1 <= threshold and norm_dist2 <= threshold:
#             correct_parts += 1

#     pcp = correct_parts / max(valid_parts, 1) 
#     return pcp, correct_parts, valid_parts

def _yolo_bbox_to_xyxy(bbox, width, height):
    """Scale a (cx, cy, w, h) box by the image size and return (x1, y1, x2, y2)."""
    x_center, y_center, box_w, box_h = bbox
    x_center *= width
    y_center *= height
    box_w *= width
    box_h *= height
    return (x_center - box_w / 2, y_center - box_h / 2,
            x_center + box_w / 2, y_center + box_h / 2)


def _box_iou(box_a, box_b):
    """Return the intersection-over-union of two (x1, y1, x2, y2) boxes."""
    inter_x1 = max(box_a[0], box_b[0])
    inter_y1 = max(box_a[1], box_b[1])
    inter_x2 = min(box_a[2], box_b[2])
    inter_y2 = min(box_a[3], box_b[3])
    if inter_x1 >= inter_x2 or inter_y1 >= inter_y2:
        return 0.0

    inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter_area / (area_a + area_b - inter_area)


def _best_iou_match(gt_box, pred_boxes):
    """Return (index, iou) of the predicted box overlapping gt_box most, or (-1, 0)."""
    best_idx, best_iou = -1, 0
    for pred_idx, pred_box in enumerate(pred_boxes):
        iou = _box_iou(gt_box, pred_box)
        if iou > best_iou:
            best_idx, best_iou = pred_idx, iou
    return best_idx, best_iou


def evaluate_on_dataset(model, dataset_path, img_size=640, threshold=0.5,
                        iou_threshold=0.5):
    """Compute PCP scores of a pose model over an image/label dataset.

    Images (``*.jpg``, ``*.jpeg``, ``*.png``) are read from
    ``<dataset_path>/images`` and paired with ``<dataset_path>/labels/<stem>.txt``.
    Each ground-truth instance is matched to the predicted box with the
    highest IoU; if that IoU exceeds ``iou_threshold`` the matched keypoints
    are scored with ``calculate_pcp``.

    Only matched instances contribute. Images without a label file,
    unreadable images, unparsable labels, images with no predictions and
    ground-truth instances without a sufficiently overlapping prediction are
    all skipped and do not lower the score.

    Args:
        model: Callable used as ``model(img, verbose=False)``; element 0 of
            its result must expose ``boxes.xyxy``, ``boxes.cls`` and
            ``keypoints.data`` tensors.
            NOTE(review): predicted coordinates are assumed to be in pixels
            of the input image, matching the scaled ground truth — confirm.
        dataset_path: Directory containing ``images`` and ``labels``.
        img_size: Accepted for interface compatibility; currently not
            forwarded to the model.
        threshold: Normalised-distance threshold passed to ``calculate_pcp``.
        iou_threshold: A ground-truth box is matched only when the best IoU
            is strictly greater than this value.

    Returns:
        ``(overall_pcp, pcp_scores, class_pcps, details_df)``:
        the ratio of correct to valid parts over all matched instances, the
        list of per-instance PCP values, a dict mapping class id to
        ``{'pcp', 'mean_pcp', 'count'}`` (classes 0 and 1 are always present,
        others appear when encountered), and a DataFrame with one row per
        scored instance.
    """
    images_path = os.path.join(dataset_path, 'images')
    labels_path = os.path.join(dataset_path, 'labels')

    total_correct_parts = 0
    total_valid_parts = 0
    pcp_scores = []
    class_stats = {cls: {'correct': 0, 'valid': 0, 'scores': []} for cls in (0, 1)}
    results_details = []

    # Sorted so that the row order of the detail table is reproducible.
    image_files = sorted(
        path
        for pattern in ('*.jpg', '*.jpeg', '*.png')
        for path in glob.glob(os.path.join(images_path, pattern))
    )

    for img_path in tqdm(image_files):
        base_name = os.path.basename(img_path)
        label_file = os.path.join(labels_path, os.path.splitext(base_name)[0] + '.txt')
        if not os.path.exists(label_file):
            continue

        img = cv2.imread(img_path)
        if img is None:
            continue
        height, width = img.shape[:2]

        try:
            gt_classes, gt_bboxes, gt_keypoints_list = parse_yolo_label(label_file)
        except (OSError, ValueError) as e:
            print(f"Error parsing label file {label_file}: {e}")
            continue

        if not gt_classes:
            continue

        result = model(img, verbose=False)[0]
        if len(result.keypoints.data) == 0:
            continue

        # Convert the predicted boxes once per image rather than once per GT.
        pred_boxes = result.boxes.xyxy.cpu().numpy()

        for gt_class, gt_bbox, gt_keypoints in zip(gt_classes, gt_bboxes, gt_keypoints_list):
            gt_keypoints = np.asarray(gt_keypoints)
            if gt_keypoints.ndim != 2 or gt_keypoints.shape[1] < 3:
                # Instance labelled with a box only: there is nothing to score.
                continue

            gt_box = _yolo_bbox_to_xyxy(gt_bbox, width, height)

            # Scale visible keypoints to pixels on a copy so the parsed
            # label data is left untouched.
            gt_pixels = gt_keypoints.astype(float)
            visible = gt_pixels[:, 2] > 0
            gt_pixels[visible, 0] *= width
            gt_pixels[visible, 1] *= height

            best_match_idx, best_iou = _best_iou_match(gt_box, pred_boxes)
            if best_match_idx < 0 or not best_iou > iou_threshold:
                continue

            pred_class = int(result.boxes.cls[best_match_idx].cpu().numpy().item())
            pred_keypoints = result.keypoints.data[best_match_idx].cpu().numpy()

            # Compute PCP for the matched instance.
            pcp, correct, valid = calculate_pcp(pred_keypoints, gt_pixels, threshold=threshold)
            if valid <= 0:
                continue

            total_correct_parts += correct
            total_valid_parts += valid
            pcp_scores.append(pcp)

            # Class ids outside the pre-seeded {0, 1} get an entry on demand
            # instead of raising KeyError.
            stats = class_stats.setdefault(
                gt_class, {'correct': 0, 'valid': 0, 'scores': []})
            stats['correct'] += correct
            stats['valid'] += valid
            stats['scores'].append(pcp)

            # Keep the per-instance details.
            results_details.append({
                'image': base_name,
                'gt_class': gt_class,
                'pred_class': pred_class,
                'iou': best_iou,
                'pcp': pcp,
                'correct_parts': correct,
                'valid_parts': valid
            })

    overall_pcp = total_correct_parts / max(total_valid_parts, 1)
    class_pcps = {
        cls: {
            'pcp': stats['correct'] / max(stats['valid'], 1),
            'mean_pcp': np.mean(stats['scores']) if stats['scores'] else 0,
            'count': len(stats['scores'])
        }
        for cls, stats in class_stats.items()
    }

    details_df = pd.DataFrame(results_details)

    return overall_pcp, pcp_scores, class_pcps, details_df

def main():
    """Evaluate the trained pose model at several PCP thresholds.

    For every threshold the per-instance details are written to
    ``pcp_details_threshold_<t>.csv`` and a summary is printed. Afterwards a
    2x2 figure (overall curves, per-class PCP, per-class mean PCP, sample
    counts at the first threshold) is saved to ``pcp_results.png`` and shown,
    and the per-threshold summary table is written to ``pcp_summary.csv``.
    """
    weights_file = '../../AI_aug_gen/__output__/06_YOLOv8_Pose/1211_leafBlock/yolov8n-pose/20241211-02304906_yolov8n-pose_1211_lb_1000_n1050_kpt/weights/best.pt'
    test_split_dir = '../../AI_aug_gen/__dataset__/06_YOLOv8_Pose/1211_leafBlock_500/1211_lb_500_n550_kpt/test/'
    model = YOLO(weights_file)
    thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
    class_names = {0: "クラス0", 1: "grape_cane"}

    all_results = {}

    def class_field(t, cls, field):
        # Per-class value at threshold t, 0 when the class was not reported.
        return all_results[t]['class_pcps'].get(cls, {}).get(field, 0)

    def finish_axes(xlabel, ylabel, title):
        # Shared decoration of the three line plots.
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True)
        plt.legend()

    def plot_per_class(field, marker, metric_label):
        # One curve per class that has at least one scored sample.
        for cls in class_names:
            if any(class_field(t, cls, 'count') > 0 for t in thresholds):
                plt.plot(thresholds,
                         [class_field(t, cls, field) for t in thresholds],
                         marker=marker,
                         label=f'Class {class_names[cls]} {metric_label}')

    # --- evaluation ------------------------------------------------------
    for threshold in thresholds:
        print(f'Evaluating with threshold: {threshold}')
        overall_pcp, pcp_scores, class_pcps, details_df = evaluate_on_dataset(
            model, test_split_dir, threshold=threshold)

        details_df.to_csv(f'pcp_details_threshold_{threshold}.csv', index=False)

        summary = {
            'overall_pcp': overall_pcp,
            'mean_pcp': np.mean(pcp_scores) if pcp_scores else 0,
            'median_pcp': np.median(pcp_scores) if pcp_scores else 0,
            'num_samples': len(pcp_scores),
            'class_pcps': class_pcps
        }
        all_results[threshold] = summary

        print(f'  Overall PCP: {overall_pcp:.4f}')
        print(f'  Mean PCP: {summary["mean_pcp"]:.4f}')
        print(f'  Median PCP: {summary["median_pcp"]:.4f}')
        print(f'  Samples: {summary["num_samples"]}')

        for cls, stats in class_pcps.items():
            print(f'  Class {class_names[cls]}:')
            print(f'    PCP: {stats["pcp"]:.4f}')
            print(f'    Mean PCP: {stats["mean_pcp"]:.4f}')
            print(f'    Count: {stats["count"]}')

    # --- plots -----------------------------------------------------------
    plt.figure(figsize=(15, 10))

    plt.subplot(2, 2, 1)
    overall_curves = (
        ('overall_pcp', 'o', 'Overall PCP'),
        ('mean_pcp', 's', 'Mean PCP'),
        ('median_pcp', '^', 'Median PCP'),
    )
    for key, marker, label in overall_curves:
        plt.plot(thresholds, [all_results[t][key] for t in thresholds],
                 marker=marker, label=label)
    finish_axes('Threshold', 'PCP', 'Overall PCP vs Threshold')

    plt.subplot(2, 2, 2)
    plot_per_class('pcp', 'o', 'PCP')
    finish_axes('Threshold', 'PCP', 'Class-wise PCP vs Threshold')

    plt.subplot(2, 2, 3)
    plot_per_class('mean_pcp', 's', 'Mean PCP')
    finish_axes('Threshold', 'Mean PCP', 'Class-wise Mean PCP vs Threshold')

    plt.subplot(2, 2, 4)
    first_threshold = thresholds[0]
    plt.bar(range(len(class_names)),
            [class_field(first_threshold, cls, 'count') for cls in class_names],
            tick_label=list(class_names.values()))
    plt.xlabel('Class')
    plt.ylabel('Number of Samples')
    plt.title('Number of Evaluated Samples per Class')
    plt.grid(True, axis='y')

    plt.tight_layout()
    plt.savefig('pcp_results.png', dpi=300)
    plt.show()

    # --- summary table ---------------------------------------------------
    summary_rows = []
    for t in thresholds:
        entry = all_results[t]
        row = {
            'Threshold': t,
            'Overall_PCP': entry['overall_pcp'],
            'Mean_PCP': entry['mean_pcp'],
            'Median_PCP': entry['median_pcp'],
            'Samples': entry['num_samples']
        }
        per_class = entry['class_pcps']
        for cls, cls_name in class_names.items():
            if cls not in per_class:
                continue
            stats = per_class[cls]
            row[f'Class_{cls_name}_PCP'] = stats['pcp']
            row[f'Class_{cls_name}_Mean_PCP'] = stats['mean_pcp']
            row[f'Class_{cls_name}_Count'] = stats['count']
        summary_rows.append(row)

    pd.DataFrame(summary_rows).to_csv('pcp_summary.csv', index=False)
    print("Evaluation complete. Results saved to CSV files and plots.")

# Run the evaluation only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()