# In [2]:
import torch

# In [4]:
def _box_corners(boxes, box_format):
    """Return the (x1, y1, x2, y2) corner slices of `boxes`, each keeping a trailing dim of 1."""
    if box_format == "midpoint":
        # (x, y) is the box centre, so each corner is half a side away from it.
        half_w = boxes[..., 2:3] / 2
        half_h = boxes[..., 3:4] / 2
        return (
            boxes[..., 0:1] - half_w,
            boxes[..., 1:2] - half_h,
            boxes[..., 0:1] + half_w,
            boxes[..., 1:2] + half_h,
        )
    if box_format == "corners":
        return (
            boxes[..., 0:1],
            boxes[..., 1:2],
            boxes[..., 2:3],
            boxes[..., 3:4],
        )
    raise ValueError(
        f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
    )


def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Calculates intersection over union (IoU) between pairs of bounding boxes.

    Parameters:
        boxes_preds (tensor): Predictions of Bounding Boxes; the last
            dimension holds the four coordinates, e.g. (BATCH_SIZE, 4)
        boxes_labels (tensor): Correct labels of Bounding Boxes, same layout
            as boxes_preds
        box_format (str): "midpoint" if boxes are (x, y, w, h) with (x, y)
            the box centre, "corners" if boxes are (x1, y1, x2, y2)

    Returns:
        tensor: Intersection over union for all examples; the last dimension
            has size 1, e.g. (BATCH_SIZE, 1)

    Raises:
        ValueError: if box_format is neither "midpoint" nor "corners"
    """
    pred_x1, pred_y1, pred_x2, pred_y2 = _box_corners(boxes_preds, box_format)
    label_x1, label_y1, label_x2, label_y2 = _box_corners(boxes_labels, box_format)

    # The overlap rectangle starts at the larger of the two top-left corners
    # and ends at the smaller of the two bottom-right corners.
    inter_x1 = torch.max(pred_x1, label_x1)
    inter_y1 = torch.max(pred_y1, label_y1)
    inter_x2 = torch.min(pred_x2, label_x2)
    inter_y2 = torch.min(pred_y2, label_y2)

    # .clamp(0) is for the case when they do not intersect: a negative
    # width or height becomes zero, giving a zero intersection area.
    intersection = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0)

    # abs() keeps the areas non-negative even if a box has swapped corners.
    preds_area = abs((pred_x2 - pred_x1) * (pred_y2 - pred_y1))
    labels_area = abs((label_x2 - label_x1) * (label_y2 - label_y1))

    # The small epsilon avoids division by zero when both boxes are degenerate.
    return intersection / (preds_area + labels_area - intersection + 1e-6)

# In [3]:
def non_max_suppression(
    bboxes,
    IOU_threshold,
    prob_threshold,
    box_format="corners"):
    """
        Performs non-maximum suppression (NMS) on a list of predicted boxes.

    Parameters:
        bboxes (python:list) : predicted bounding boxes [ [1, 0.9, x1, y1, x2, y2], [etc..], [etc..], etc..]
            the 1 represents the class id, example: 1 means its a car
            0.9 represents the probability

        IOU_threshold (float) : the iou threshold when comparing bounding boxes for NMS

        prob_threshold (float) : the threshold to remove bounding boxes with a low confidence score

        box_format (str) : passed through to intersection_over_union to say how the
            four coordinates of each box are laid out ("corners" or "midpoint")

    Returns:
        python:list : the boxes that survive NMS, ordered from highest to
            lowest probability; the input list itself is not modified
    """

    # isinstance (rather than an exact type comparison) also accepts list subclasses.
    assert isinstance(bboxes, list), "bboxes must be a list of boxes"

    # Drop low-confidence boxes, then order the rest so that the most
    # confident box is always considered first.
    candidates = sorted(
        (box for box in bboxes if box[1] > prob_threshold),
        key=lambda box: box[1],
        reverse=True,
    )

    bboxes_after_nms = []

    while candidates:
        # The head of the queue is the most confident remaining box: keep it.
        chosen_box, *remaining = candidates

        # Coordinates of the chosen box do not change while filtering, so
        # build the tensor once instead of once per comparison.
        chosen_coords = torch.tensor(chosen_box[2:])

        # A remaining box is kept when it belongs to a different class (IOU is
        # only meaningful between boxes of the same class) or when it overlaps
        # the chosen box by less than the threshold.
        candidates = [
            box for box in remaining
            if box[0] != chosen_box[0]
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            ) < IOU_threshold
        ]

        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms
        