In [150]:
import numpy
import torch

## Non-Max Suppression (NMS)

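#### 1) Discard all bounding boxes whose probability is <code>< probability threshold</code>
#### 2) Repeatedly keep the remaining box with the highest probability and remove every other box whose <code>IoU</code> with it is <code>>= threshold</code> (typically 0.5 or 0.6)
(We do this separately for each class: boxes of different classes never suppress each other.)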

### IoU (Intersection over Union)

In [151]:
def IoU(boxes_preds, boxes_labels, box_format="midpoint", verbose=True):
    """Compute Intersection over Union (IoU) between predicted and label boxes.

    Parameters
    ----------
    boxes_preds : torch.Tensor
        Predicted boxes with the 4 coordinates in the last dimension,
        e.g. shape ``(4,)`` for one box or ``(N, 4)`` for N boxes.
    boxes_labels : torch.Tensor
        Reference boxes, same layout as ``boxes_preds``.
    box_format : str
        ``"midpoint"`` for ``[x_center, y_center, w, h]`` or
        ``"corners"`` for ``[x1, y1, x2, y2]``.
    verbose : bool
        If True (the default), print the intersection corners, intersection,
        union and IoU values.

    Returns
    -------
    float or torch.Tensor
        A Python float when a single pair of boxes is compared; otherwise a
        tensor with a trailing dimension of size 1 (one IoU per box pair).

    Raises
    ------
    ValueError
        If ``box_format`` is neither ``"midpoint"`` nor ``"corners"``.
    """
    if box_format == "midpoint":
        # [x, y, w, h] -> corners: centre -/+ half the size along each axis.
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2  # N x 1
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        # [x1, y1, x2, y2] -- already corner coordinates.
        box1_x1 = boxes_preds[..., 0:1]  # N x 1
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    # Corners of the overlap rectangle.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # .clamp(0) handles boxes that do NOT intersect: a negative side length becomes 0.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
    union = box1_area + box2_area - intersection

    # Small epsilon avoids division by zero for degenerate (zero-area) boxes.
    iou = intersection / (union + 1e-6)

    if verbose:
        def _show(t):
            # .item() only works for single-element tensors; fall back to a list for batches.
            return t.item() if t.numel() == 1 else t.tolist()

        print("\n--------------IOU-------------")
        print("intersection coordinates: [", _show(x1), _show(y1), _show(x2), _show(y2), "]")
        print("intersection:", _show(intersection))
        print("union:", _show(union))
        print("iou:", _show(iou))
        print("-" * 30, '\n')

    # A single pair keeps the scalar return callers compare against thresholds;
    # batched inputs return the per-pair tensor.
    return iou.item() if iou.numel() == 1 else iou

### Non-Max Suppression

In [152]:
def nms(bboxes, iou_threshold, prob_threshold, box_format="corners"):
    """Greedy, per-class Non-Max Suppression over a list of boxes.

    Each box is a list ``[class, Pc, c0, c1, c2, c3]``: ``Pc`` is the box
    probability and the last four values are coordinates in ``box_format``
    (forwarded unchanged to ``IoU``).

    Boxes with ``Pc < prob_threshold`` are dropped first. The rest are visited
    from highest to lowest ``Pc``. Each visited box is kept, and every
    remaining box of the same class whose IoU with it is ``>= iou_threshold``
    is discarded. Boxes of a different class are never suppressed.

    Progress is printed to stdout. An ``AssertionError`` is raised if
    ``bboxes`` is not a plain ``list``.

    Returns the kept boxes in descending ``Pc`` order.
    """
    print("IoU threshold:", iou_threshold)
    print("prob threshold:", prob_threshold, '\n')

    # bboxes = [[class, Pc, x1, y1, x2, y2],[...],....]
    assert type(bboxes) is list

    def _survives(candidate, anchor):
        # Only boxes sharing the anchor's class can be suppressed by it.
        if candidate[0] != anchor[0]:
            return True
        overlap = IoU(torch.tensor(anchor[2:]), torch.tensor(candidate[2:]), box_format=box_format)
        return overlap < iou_threshold

    # Step 1: keep sufficiently confident boxes, most confident first.
    candidates = list(filter(lambda box: box[1] >= prob_threshold, bboxes))
    candidates.sort(key=lambda box: box[1], reverse=True)

    print("Bounding Boxes after dropping <prob_threshold:")
    for box in candidates:
        print(box)
    print('\n')

    # Step 2: repeatedly keep the best remaining box and prune its overlaps.
    kept = []
    while candidates:
        anchor = candidates.pop(0)
        # len(kept) counts how many boxes have already been selected.
        print(f'Highest Prob Box {len(kept)}: {anchor}')
        candidates = [box for box in candidates if _survives(box, anchor)]
        kept.append(anchor)

    return kept

### Test

In [156]:
# Four boxes: three of class 1 and one of class 2, all with Pc >= 0.4.
bboxes = [[1, 0.9, 0, 0, 2, 3], [1, 0.7, 0, 1, 2, 3], [2, 0.8, 1, 0, 2, 3], [1, 0.4, 1, 0, 2, 3]]
bboxes_after_nms = nms(bboxes, iou_threshold=0.49, prob_threshold=0.4)
print("\nBounding Boxes from Non-Max Supression", bboxes_after_nms)

# The 0.7 and 0.4 class-1 boxes overlap the 0.9 box with IoU ~0.67 and ~0.4999
# (both >= 0.49), so only the top class-1 box and the class-2 box survive.
assert bboxes_after_nms == [[1, 0.9, 0, 0, 2, 3], [2, 0.8, 1, 0, 2, 3]]

IoU threshold: 0.49
prob threshold: 0.4 

Bounding Boxes after dropping <prob_threshold:
[1, 0.9, 0, 0, 2, 3]
[2, 0.8, 1, 0, 2, 3]
[1, 0.7, 0, 1, 2, 3]
[1, 0.4, 1, 0, 2, 3]


Highest Prob Box 0: [1, 0.9, 0, 0, 2, 3]

--------------IOU-------------
intersection coordinates: [ 0 1 2 3 ]
intersection: 4
union: 6
iou: 0.666666567325592
------------------------------ 


--------------IOU-------------
intersection coordinates: [ 1 0 2 3 ]
intersection: 3
union: 6
iou: 0.49999991059303284
------------------------------ 

Highest Prob Box 1: [2, 0.8, 1, 0, 2, 3]

Bounding Boxes from Non-Max Supression [[1, 0.9, 0, 0, 2, 3], [2, 0.8, 1, 0, 2, 3]]
