In [None]:
import numpy as np

## Test Data

In [141]:
# Each fixture has one row per box: x1, y1, x2, y2, confidence. In the three
# two-row fixtures the first row is the same reference box [1, 1, 3, 3] (conf 0.95).

# Same coordinates as the reference, lower confidence -> IoU 1.
boxes_with_full_intersection = np.array(
    [[1, 1, 3, 3, 0.95],
     [1, 1, 3, 3, 0.93]]
)

# Second box lies entirely outside the reference box -> IoU 0.
boxes_with_no_intersection = np.array(
    [[1, 1, 3, 3, 0.95],
     [4, 4, 5, 5, 0.93]]
)

# Second box shifted right by 1: overlap 2, union 6 -> IoU 1/3.
boxes_with_one_third_intersection = np.array(
    [[1, 1, 3, 3, 0.95],
     [2, 1, 4, 3, 0.93]]
)

# Four heavily overlapping boxes with distinct confidences.
general_test_case = np.array(
    [[1, 1, 3, 3, 0.95],
     [1, 1, 3, 4, 0.93],
     [1, 0.9, 3.6, 3, 0.98],
     [1, 0.9, 3.5, 3, 0.97]]
)

## Function to compute intersection over union (IoU)

In [169]:
def get_iou(box, boxes, eps=1e-8):
    '''
    Compute the intersection over union (IoU) between one box and each of several boxes.

    Boxes are in corner format: columns 0-1 hold the top-left corner (x1, y1) and
    columns 2-3 the bottom-right corner (x2, y2). Any further columns (e.g. a
    confidence score in column 4) are ignored.

    box    - array of shape (>=4,), e.g. (5,)
    boxes  - array of shape (n, >=4), e.g. (n, 5)
    eps    - small constant added to the union so zero-area boxes do not divide by 0
    return - array of shape (n,): IoU of `box` against each row of `boxes`
    '''
    # Corners of the intersection rectangle: larger of the top-left coordinates,
    # smaller of the bottom-right ones. Scalars from `box` broadcast over the
    # (n,) columns of `boxes`.
    inter_x1 = np.maximum(box[0], boxes[:, 0])
    inter_y1 = np.maximum(box[1], boxes[:, 1])
    inter_x2 = np.minimum(box[2], boxes[:, 2])
    inter_y2 = np.minimum(box[3], boxes[:, 3])

    # Disjoint boxes give a negative extent; clamp to 0 so the intersection is 0.
    width = np.maximum(inter_x2 - inter_x1, 0)
    height = np.maximum(inter_y2 - inter_y1, 0)
    intersection = width * height

    area_of_box = (box[2] - box[0]) * (box[3] - box[1])
    area_of_boxes = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union = (area_of_boxes - intersection) + area_of_box

    return intersection / (union + eps)

In [170]:
def test_iou():
    '''
    Check get_iou on the hand-built two-box fixtures from the test-data cell.

    In each fixture the first row is used as the reference box, so the first
    IoU is the box against itself (1) and the second entry is the case under test.
    '''
    iou = get_iou(boxes_with_full_intersection[0], boxes_with_full_intersection)
    assert np.allclose(iou, [1, 1])

    iou = get_iou(boxes_with_no_intersection[0], boxes_with_no_intersection)
    assert np.allclose(iou, [1, 0])

    # Overlap 1 x 2 = 2, union 4 + 4 - 2 = 6 -> IoU 1/3.
    iou = get_iou(boxes_with_one_third_intersection[0], boxes_with_one_third_intersection)
    assert np.allclose(iou, [1, 1 / 3])

## Function to perform non-maximum suppression (NMS)

In [195]:
def nms(boxes, threshold):
    """
    Greedy non-maximum suppression.

    Repeatedly keeps the highest-confidence remaining box and discards every
    remaining box whose IoU with it is >= `threshold`.

    boxes     - array (or nested list) of shape (n, 5): x1, y1, x2, y2 in the
                corner format expected by get_iou, confidence in the last column
    threshold - IoU at or above which a lower-confidence box is suppressed
    return    - array of shape (k, boxes.shape[1]) with the kept rows ordered by
                decreasing confidence; k == 0 when `boxes` is empty
    """
    boxes = np.asarray(boxes)

    # Ascending *stable* sort: the best candidate is always at the end, and ties
    # in confidence are broken deterministically (the later row is taken first).
    remaining = np.argsort(boxes[:, -1], kind="stable")
    keep = []

    # Until every candidate has been either kept or suppressed.
    while remaining.size > 0:
        best = remaining[-1]
        keep.append(best)
        remaining = remaining[:-1]
        if remaining.size == 0:
            break

        iou = get_iou(boxes[best], boxes[remaining])
        # Only boxes overlapping the kept one by less than `threshold` survive.
        remaining = remaining[iou < threshold]

    # Gather with an explicit int index array so an empty result keeps its
    # 2-D shape (0, n_columns) instead of collapsing to (0,).
    return boxes[np.array(keep, dtype=int)]

In [196]:
thresh = 0.3


def test_nms():
    '''
    Check nms against the fixtures from the test-data cell.

    np.array_equal compares shapes as well as values, so a result with the
    wrong number of kept boxes fails. With `(a == b).all()` an empty or
    wrongly-shaped result could broadcast and pass vacuously.
    '''
    # Identical boxes: only the higher-confidence row 0 survives.
    kept = nms(boxes_with_full_intersection, thresh)
    assert np.array_equal(kept, boxes_with_full_intersection[[0]])

    # Disjoint boxes: both survive, already in decreasing-confidence order.
    kept = nms(boxes_with_no_intersection, thresh)
    assert np.array_equal(kept, boxes_with_no_intersection)

    # IoU 1/3 is above the 0.3 threshold, so the weaker box is suppressed.
    kept = nms(boxes_with_one_third_intersection, thresh)
    assert np.array_equal(kept, boxes_with_one_third_intersection[[0]])

    # Row 2 (conf 0.98) overlaps every other row with IoU > 0.5.
    kept = nms(general_test_case, 0.5)
    assert np.array_equal(kept, general_test_case[[2]])

    # At 0.8 only row 3 (IoU ~0.96 with row 2) is suppressed.
    kept = nms(general_test_case, 0.8)
    assert np.array_equal(kept, general_test_case[[2, 0, 1]])

In [199]:
# Run the NMS checks; any failing case raises AssertionError.
test_nms()