# Non Max Suppression Explained and PyTorch Implementation
- https://www.youtube.com/watch?v=YDkjWEN8jNA

## Cleaning up bounding boxes

### Notes
- IoU (intersection over union) threshold
- Do NMS separately for each class

- Discard all bounding boxes whose probability is at or below the probability threshold
- While there are any remaining boxes:

   - Pick the box with the largest probability
   - Discard any remaining box whose IoU with the box picked in the previous step is at or above the IoU threshold

- This is done separately for each class: boxes of different classes never suppress each other

In [3]:
import torch
from iou import intersection_over_union

In [9]:
def nms(bboxes, iou_threshold, threshold, box_format="corners"):
    """Apply per-class non-max suppression to a collection of predicted boxes.

    Each box is a sequence laid out as::

        [class_label, probability, c1, c2, c3, c4]

    e.g. ``[1, 0.7, x1, y1, x2, y2]`` for the default ``box_format="corners"``.
    The four coordinates are handed to ``intersection_over_union`` together
    with ``box_format``, which decides how they are interpreted.

    Args:
        bboxes: list (or tuple) of boxes in the layout above. The container
            itself is not modified; the returned list holds the same box
            objects that were passed in.
        iou_threshold: a box is suppressed when its IoU with an already kept,
            more confident box of the *same* class is >= this value.
        threshold: boxes whose probability is <= this value are dropped
            before suppression starts.
        box_format: forwarded unchanged to ``intersection_over_union``.

    Returns:
        A new list with the surviving boxes, ordered by descending
        probability (ties keep their input order, since ``sorted`` is stable).

    Raises:
        TypeError: if ``bboxes`` is not a list or tuple.
    """
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if not isinstance(bboxes, (list, tuple)):
        raise TypeError(
            f"bboxes must be a list of boxes, got {type(bboxes).__name__}"
        )

    # Confidence filter first, then visit candidates from most to least
    # confident so each chosen box is the best one still remaining.
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    bboxes_after_nms = []

    while candidates:
        chosen_box = candidates.pop(0)
        # Built once per chosen box instead of once per comparison below.
        chosen_coords = torch.tensor(chosen_box[2:])
        # Keep boxes of other classes unconditionally (the `or` short-circuits
        # so IoU is only computed for same-class boxes); drop same-class boxes
        # that overlap the chosen one too much.
        candidates = [
            box
            for box in candidates
            if box[0] != chosen_box[0]
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            ) < iou_threshold
        ]
        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms

In [11]:
import sys
import unittest
import torch

class TestNonMaxSuppression(unittest.TestCase):
    """Unit tests for ``nms`` on small hand-made box lists.

    Boxes are ``[class_label, probability, c1, c2, c3, c4]``. Every test uses
    ``box_format="midpoint"`` and a probability threshold of 0.2, so the last
    box of each input (probability 0.05) is always dropped before suppression.
    Output order is not what is being checked, so results are compared as
    multisets.
    """

    def setUp(self):
        # Each tN_boxes is an nms input and cN_boxes the boxes expected to
        # survive it.

        # Two overlapping class-1 boxes: the less confident one (0.8) is removed.
        self.t1_boxes = [
            [1, 1, 0.5, 0.45, 0.4, 0.5],
            [1, 0.8, 0.5, 0.5, 0.2, 0.4],
            [1, 0.7, 0.25, 0.35, 0.3, 0.1],
            [1, 0.05, 0.1, 0.1, 0.1, 0.1],
        ]
        self.c1_boxes = [
            [1, 1, 0.5, 0.45, 0.4, 0.5],
            [1, 0.7, 0.25, 0.35, 0.3, 0.1],
        ]

        # Same box coordinates as t1, but the overlapping box is class 2,
        # so it survives.
        self.t2_boxes = [
            [1, 1, 0.5, 0.45, 0.4, 0.5],
            [2, 0.9, 0.5, 0.5, 0.2, 0.4],
            [1, 0.8, 0.25, 0.35, 0.3, 0.1],
            [1, 0.05, 0.1, 0.1, 0.1, 0.1],
        ]
        self.c2_boxes = [
            [1, 1, 0.5, 0.45, 0.4, 0.5],
            [2, 0.9, 0.5, 0.5, 0.2, 0.4],
            [1, 0.8, 0.25, 0.35, 0.3, 0.1],
        ]

        # The most confident class-1 box (1.0) comes second in the input and
        # still wins; the class-2 box is unaffected.
        self.t3_boxes = [
            [1, 0.9, 0.5, 0.45, 0.4, 0.5],
            [1, 1, 0.5, 0.5, 0.2, 0.4],
            [2, 0.8, 0.25, 0.35, 0.3, 0.1],
            [1, 0.05, 0.1, 0.1, 0.1, 0.1],
        ]
        self.c3_boxes = [
            [1, 1, 0.5, 0.5, 0.2, 0.4],
            [2, 0.8, 0.25, 0.35, 0.3, 0.1],
        ]

        # All class 1; with the looser IoU threshold used in test_keep_on_iou,
        # nothing above the probability threshold is suppressed.
        self.t4_boxes = [
            [1, 0.9, 0.5, 0.45, 0.4, 0.5],
            [1, 1, 0.5, 0.5, 0.2, 0.4],
            [1, 0.8, 0.25, 0.35, 0.3, 0.1],
            [1, 0.05, 0.1, 0.1, 0.1, 0.1],
        ]
        self.c4_boxes = [
            [1, 0.9, 0.5, 0.45, 0.4, 0.5],
            [1, 1, 0.5, 0.5, 0.2, 0.4],
            [1, 0.8, 0.25, 0.35, 0.3, 0.1],
        ]

    def _check_nms(self, boxes, expected, iou_threshold):
        """Run ``nms`` with the shared settings and compare, ignoring order."""
        result = nms(
            boxes,
            threshold=0.2,
            iou_threshold=iou_threshold,
            box_format="midpoint",
        )
        # Same check as sorted(result) == sorted(expected), but a failure
        # reports which boxes differ instead of "False is not true".
        self.assertCountEqual(result, expected)

    def test_remove_on_iou(self):
        """A same-class box overlapping a more confident one is suppressed."""
        self._check_nms(self.t1_boxes, self.c1_boxes, iou_threshold=7 / 20)

    def test_keep_on_class(self):
        """An overlapping box of a different class is not suppressed."""
        self._check_nms(self.t2_boxes, self.c2_boxes, iou_threshold=7 / 20)

    def test_remove_on_iou_and_class(self):
        """The most confident box of a class wins regardless of input order."""
        self._check_nms(self.t3_boxes, self.c3_boxes, iou_threshold=7 / 20)

    def test_keep_on_iou(self):
        """A higher IoU threshold keeps the overlapping same-class boxes."""
        self._check_nms(self.t4_boxes, self.c4_boxes, iou_threshold=9 / 20)





In [12]:
# Run the test case defined above in-process.
# - argv=['']: unittest parses this list instead of the host process's
#   sys.argv (e.g. a notebook kernel's own flags).
# - exit=False: return instead of calling sys.exit().
# The returned TestProgram object is the cell's displayed value.
unittest.main(
    exit=False,
    verbosity=2,
    argv=[''],
)

test_keep_on_class (__main__.TestNonMaxSuppression) ... ok
test_keep_on_iou (__main__.TestNonMaxSuppression) ... ok
test_remove_on_iou (__main__.TestNonMaxSuppression) ... ok
test_remove_on_iou_and_class (__main__.TestNonMaxSuppression) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.019s

OK


<unittest.main.TestProgram at 0x7fccf57e2610>