# **Mean Average Precision (mAP)**
* This notebook follows the tutorial "Mean Average Precision (mAP) Explained and PyTorch Implementation".<br>Reference: https://www.youtube.com/watch?v=FppOzcDvaDI&list=PLhhyoLH6Ijfw0TpCTVTNk42NN08H6UvNq&index=4
* Extended by **Vigyannveshi** 

In [2]:
import torch as tr
from collections import Counter
from object_detection_necessities import intersection_over_union

In [None]:
def mean_average_precision(
        pred_boxes,true_boxes,iou_threshold=0.5,box_format="corners",num_classes=20
        ):
    """Compute the mean Average Precision (mAP) at a single IoU threshold.

    For every class index in ``range(num_classes)`` the detections are ranked
    by confidence, greedily matched against the ground-truth boxes of the same
    image, and the area under the resulting precision/recall curve is
    integrated with the trapezoidal rule. The per-class areas are averaged.

    Args:
        pred_boxes (list): predicted boxes, each laid out as
            ``[train_idx, class_pred, prob_score, x1, y1, x2, y2]``.
        true_boxes (list): ground-truth boxes in the same layout as
            ``pred_boxes``.
        iou_threshold (float): a detection only counts as a true positive
            when its best IoU is strictly greater than this value.
        box_format (str): forwarded unchanged to ``intersection_over_union``.
        num_classes (int): number of class indices to evaluate.

    Returns:
        A 0-dim tensor holding the mean of the per-class average precisions.
        Classes without any ground-truth box are left out of the mean; if no
        class has ground truth (or ``num_classes`` is 0) the result is 0.
    """
    average_precisions = []
    epsilon = 1e-6  # keeps the divisions below away from 0/0

    for c in range(num_classes):
        # Keep only the boxes that belong to the class being evaluated.
        detections = [box for box in pred_boxes if box[1] == c]
        ground_truths = [box for box in true_boxes if box[1] == c]

        total_true_bboxes = len(ground_truths)
        if total_true_bboxes == 0:
            # AP is undefined without ground truth; counting it as 0 would
            # unfairly lower the mean, so the class is skipped.
            continue

        # Group the ground truths by image once instead of re-filtering the
        # whole list for every detection. Order inside an image is preserved,
        # so indices line up with the "already matched" flags below.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)

        # One 0/1 flag per ground-truth box, e.g.
        # {0: tensor([0, 0, 0]), 1: tensor([0, 0, 0, 0, 0]), ...}
        matched = {
            img_idx: tr.zeros(len(boxes))
            for img_idx, boxes in gts_by_image.items()
        }

        # Highest-confidence detections get the first chance to claim a box.
        detections.sort(key=lambda x: x[2], reverse=True)

        TP = tr.zeros(len(detections))
        FP = tr.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            candidates = gts_by_image.get(detection[0], [])

            best_iou = 0
            # Reset per detection so a stale index from a previous detection
            # can never be reused when this image has no candidate.
            best_gt_index = None

            for idx, gt in enumerate(candidates):
                iou = intersection_over_union(
                    tr.tensor(detection[3:]),
                    tr.tensor(gt[3:]),
                    box_format=box_format,
                )
                if iou > best_iou:
                    best_iou = iou
                    best_gt_index = idx

            if best_gt_index is not None and best_iou > iou_threshold:
                flags = matched[detection[0]]
                if flags[best_gt_index] == 0:
                    # First detection to claim this ground-truth box.
                    TP[detection_idx] = 1
                    flags[best_gt_index] = 1
                else:
                    # Box was already claimed by a more confident detection.
                    FP[detection_idx] = 1
            else:
                FP[detection_idx] = 1

        # [1,1,0,1,0] --> [1,2,2,3,3]
        TP_cumsum = tr.cumsum(TP, dim=0)
        FP_cumsum = tr.cumsum(FP, dim=0)

        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = tr.divide(TP_cumsum, (TP_cumsum + FP_cumsum + epsilon))

        # Anchor the curve at (recall=0, precision=1) before integrating.
        # Float sentinels keep the dtype consistent with the float curves.
        precisions = tr.cat((tr.tensor([1.0]), precisions))
        recalls = tr.cat((tr.tensor([0.0]), recalls))

        average_precisions.append(tr.trapezoid(precisions, recalls))

    if not average_precisions:
        # Nothing to average: avoid dividing by zero.
        return tr.tensor(0.0)

    return sum(average_precisions) / len(average_precisions)