In [5]:
import torch

In [None]:
"""

Calculates intersection over union (IoU) between predicted and ground-truth boxes.
Parameters:
    boxes_pred (tensor): Predicted bounding boxes, shape (BATCH_SIZE, 4)
    boxes_labels (tensor): Ground-truth (correct) bounding boxes, shape (BATCH_SIZE, 4)
    box_format (str): "midpoint" if boxes are (x, y, w, h), "corners" if boxes are (x1, y1, x2, y2)

Returns:
    tensor: Intersection over union for every example, shape (BATCH_SIZE, 1)

"""

In [None]:
def _box_corners(boxes, box_format):
    """Return the (x1, y1, x2, y2) corner columns of ``boxes``, each keeping a trailing dim of size 1.

    Raises:
        ValueError: if ``box_format`` is not "midpoint" or "corners".
    """
    if box_format == "midpoint":
        # (x, y, w, h): (x, y) is the box centre, w and h are its width and height
        return (
            boxes[..., 0:1] - boxes[..., 2:3] / 2,
            boxes[..., 1:2] - boxes[..., 3:4] / 2,
            boxes[..., 0:1] + boxes[..., 2:3] / 2,
            boxes[..., 1:2] + boxes[..., 3:4] / 2,
        )
    if box_format == "corners":
        # (x1, y1, x2, y2): corner coordinates are given directly.
        # Slicing with 0:1 (rather than indexing with 0) keeps the last dim, e.g. (N, 1).
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]
    raise ValueError(
        f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
    )


def intersection_over_union(boxes_pred, boxes_labels, box_format="midpoint", eps=1e-6):
    """Compute intersection over union (IoU) between predicted and ground-truth boxes.

    Parameters:
        boxes_pred (torch.Tensor): predicted boxes, last dimension of size 4, e.g. (N, 4).
        boxes_labels (torch.Tensor): ground-truth boxes, same layout as ``boxes_pred``.
        box_format (str): "midpoint" if boxes are (x, y, w, h) with (x, y) the centre,
            or "corners" if boxes are (x1, y1, x2, y2).
        eps (float): small constant added to the denominator for numerical stability
            (avoids division by zero when both boxes have zero area).

    Returns:
        torch.Tensor: IoU for every box pair, with the last dimension of size 1, e.g. (N, 1).

    Raises:
        ValueError: if ``box_format`` is not "midpoint" or "corners".
    """
    box1_x1, box1_y1, box1_x2, box1_y2 = _box_corners(boxes_pred, box_format)
    box2_x1, box2_y1, box2_x2, box2_y2 = _box_corners(boxes_labels, box_format)

    # Corner coordinates of the intersection rectangle
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(min=0) turns negative extents into 0, covering boxes that do not intersect
    intersection = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)

    box1_area = ((box1_x2 - box1_x1) * (box1_y2 - box1_y1)).abs()
    box2_area = ((box2_x2 - box2_x1) * (box2_y2 - box2_y1)).abs()

    return intersection / (box1_area + box2_area - intersection + eps)



