In [1]:
import torch
from iou import intersection_over_union

In [None]:
# Function for Non-Max Suppression (NMS)
# Keeps the highest-probability bounding box among overlapping boxes of the same class

"""
Parameters:
    bboxes (list): list of bounding boxes as [[1, 0.9, x1, y1, x2, y2], [...], ...]
        where 1: class of the object enclosed by the bounding box
            0.9: predicted probability (confidence) of that bounding box
            the remaining four values: coordinates of the bounding box
            (corner points when box_format="corners")

    iou_threshold (float in [0, 1]): IoU threshold hyperparameter; a box of the same
        class whose IoU with an already-kept box is at least this value is suppressed
    threshold (float): probability threshold; only boxes with probability strictly
        above it are considered
    box_format (str): "corners" or "midpoint"

"""

def nms(bboxes, iou_threshold, threshold, box_format="corners", *, iou_fn=None):
    """Apply per-class Non-Max Suppression to a list of predicted boxes.

    Boxes whose score does not exceed ``threshold`` are discarded first. The
    remaining boxes are visited in descending score order. Each visited box is
    kept, and every lower-scoring box of the *same* class whose IoU with it is
    ``>= iou_threshold`` is removed. Boxes of different classes never suppress
    each other.

    Parameters:
        bboxes (list): boxes as ``[class_id, score, c1, c2, c3, c4]``, e.g.
            ``[[1, 0.9, x1, y1, x2, y2], ...]``. The four trailing values are
            the box coordinates, interpreted according to ``box_format``.
        iou_threshold (float): IoU at or above which a same-class box is
            suppressed; expected in [0, 1].
        threshold (float): score threshold (exclusive); only boxes with
            ``score > threshold`` are considered.
        box_format (str): "corners" or "midpoint"; forwarded unchanged to the
            IoU function.
        iou_fn (callable, optional): overlap function, called as
            ``iou_fn(tensor_a, tensor_b, box_format=box_format)``. Its result
            must be comparable with ``iou_threshold`` via ``<``. Defaults to
            ``intersection_over_union`` from the ``iou`` module.

    Returns:
        list: the kept boxes (the original box objects, not copies), ordered
        by descending score. The input list itself is not modified.

    Raises:
        AssertionError: if ``bboxes`` is not a list.
    """
    assert isinstance(bboxes, list), "bboxes must be a list"

    if iou_fn is None:
        iou_fn = intersection_over_union

    # Drop low-confidence boxes, then order the rest by score, highest first.
    # sorted() is stable, so equal-score boxes keep their input order.
    remaining = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )

    kept = []
    while remaining:
        chosen_box = remaining.pop(0)  # highest-scoring box still in play
        # Built once per chosen box instead of once per comparison.
        chosen_coords = torch.tensor(chosen_box[2:])

        # A box survives if it belongs to another class (NMS is per-class),
        # or if it overlaps the chosen box by less than iou_threshold.
        remaining = [
            box
            for box in remaining
            if box[0] != chosen_box[0]
            or iou_fn(chosen_coords, torch.tensor(box[2:]), box_format=box_format)
            < iou_threshold
        ]
        kept.append(chosen_box)

    return kept