In [1]:
import os

# Move the working directory up one level so the notebook behaves as if it were
# launched from the project's root (i.e. ./yolo_v1_taco).
# os.path.dirname is used instead of splitting on "/" so this also works with
# non-POSIX path separators and does not produce an empty path when the
# current directory sits directly under the filesystem root.
# NOTE: this cell is not idempotent -- every re-run climbs one more directory,
# so run it only once per kernel session.
os.chdir(os.path.dirname(os.getcwd()))
print(os.getcwd())

/Users/tonyavis/Main/AI_public_projects/object_detection/yolo_v1_taco


In [None]:
import torch
from utils.intersection_over_union import intersection_over_union

def vectorized_nms(t: torch.Tensor, config):
    """
    Perform greedy Non-Maximum Suppression (NMS), independently for each class.

    Args:
        t: Tensor of shape (N, 9) whose columns are
            [i, j, b, class_idx, pc, x, y, w, h].
        config: Object exposing `DEVICE` (device used for the empty result)
            and `IOU_THRESHOLD` (boxes whose IoU with a kept box is greater
            than or equal to this value are suppressed).

    Returns:
        Tensor of shape (M, 9) holding the surviving bboxes, grouped by
        ascending class_idx and ordered by descending pc inside each class.
        When nothing survives (e.g. `t` has no rows) an empty tensor with
        zero rows and the same number of columns as `t` is returned.
    """
    DEVICE, IOU_THRESHOLD = config.DEVICE, config.IOU_THRESHOLD

    # Bboxes that survive suppression, accumulated across all classes.
    output = []

    # --- 1: Loop through every distinct class_idx present in the input.
    for cls_idx in t[:, 3].unique():

        # --- 2: Gather all bboxes that share this class_idx.
        bboxes = t[t[:, 3] == cls_idx]

        # --- 3: Order them by confidence (pc), highest first, so that the head
        # of the queue is always the most confident remaining box. The sort is
        # stable, so input that is already ordered keeps its exact row order.
        _, order = torch.sort(bboxes[:, 4], descending=True, stable=True)
        bboxes = bboxes[order]

        # --- 4: Queue loop:
        #   1. bboxes = [box1, box2, box3, ...], all of the same class_idx
        #   2. chosen_bbox = box1, which is kept
        #   3. box1 is compared against all remaining boxes in one call
        #   4. boxes that overlap box1 too much are dropped from the queue
        #   5. repeat with the next box that survived the comparison
        while len(bboxes) > 0:
            chosen_bbox = bboxes[0]
            output.append(chosen_bbox)  # highest pc left in the queue: keep it

            # Last box in the queue: nothing left to compare against.
            if len(bboxes) == 1:
                break

            # Everything after the chosen box is still a candidate.
            rest = bboxes[1:]

            # --- 5: IoU of the chosen box against every remaining candidate.
            # Assumes the helper returns one IoU value per row of `rest`
            # -- TODO confirm against utils.intersection_over_union.
            iou = intersection_over_union(chosen_bbox=chosen_bbox, rest_bbox=rest)

            # --- 6: Keep only the candidates that do not overlap too much.
            bboxes = rest[iou < IOU_THRESHOLD]

    if len(output) == 0:
        print("OUTPUT LENGTH EMPTY")
        # torch.stack cannot handle an empty list, so return the empty result
        # explicitly instead of falling through.
        return torch.empty((0, t.shape[1]), dtype=t.dtype, device=DEVICE)
    return torch.stack(output, dim=0)

In [3]:
from types import SimpleNamespace

# Minimal stand-in for the project config: only the two attributes that
# vectorized_nms reads.
config = SimpleNamespace(DEVICE="cpu", IOU_THRESHOLD=0.5)

# Hand-built fixture. Column layout: [i, j, b, class_idx, pc, x, y, w, h].
# Within each class the rows are listed from highest to lowest pc.
test_tensor = torch.tensor([
    # class 1
    [0, 0, 0, 1, 0.95, 0.5, 0.5, 0.4, 0.4],    # kept
    [0, 0, 1, 1, 0.85, 0.52, 0.52, 0.4, 0.4],  # suppressed: overlaps the row above
    [0, 0, 0, 1, 0.30, 0.9, 0.9, 0.3, 0.3],    # kept: low overlap
    # class 2
    [0, 1, 1, 2, 0.88, 0.5, 0.5, 0.2, 0.2],    # kept: different class
    [0, 1, 0, 2, 0.70, 0.51, 0.51, 0.2, 0.2],  # suppressed: same class and overlaps
])

# Run NMS over the fixture.
filtered = vectorized_nms(test_tensor, config)

# Show the surviving boxes under a header line.
print("Filtered BBoxes:", filtered, sep="\n")

Filtered BBoxes:
tensor([[0.0000, 0.0000, 0.0000, 1.0000, 0.9500, 0.5000, 0.5000, 0.4000, 0.4000],
        [0.0000, 0.0000, 0.0000, 1.0000, 0.3000, 0.9000, 0.9000, 0.3000, 0.3000],
        [0.0000, 1.0000, 1.0000, 2.0000, 0.8800, 0.5000, 0.5000, 0.2000, 0.2000]])
