In [None]:
import torch

In [20]:
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
  """Compute the intersection-over-union (IoU) of predicted vs. label boxes.

  Args:
    boxes_preds: tensor whose last dimension holds the box coordinates
      (indices 0-3 are read). Leading dimensions are broadcast against
      ``boxes_labels`` by the element-wise ``torch.max``/``torch.min``.
    boxes_labels: tensor of label boxes in the same format as ``boxes_preds``.
    box_format: ``"midpoint"`` for (x_center, y_center, width, height) or
      ``"corners"`` for (x1, y1, x2, y2).

  Returns:
    Tensor of IoU values. The last dimension is kept with size 1 because the
    coordinates are sliced with ``0:1``-style ranges (a single 1-D box pair
    gives shape ``(1,)``).

  Raises:
    ValueError: if ``box_format`` is neither ``"midpoint"`` nor ``"corners"``.
  """
  if box_format == "midpoint":
    # Convert (cx, cy, w, h) to corners. Each box must use its OWN width and
    # height; boxes_labels is not assumed to share the prediction's size.
    box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
    box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
    box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
    box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
    box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
    box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
    box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
    box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
  elif box_format == "corners":
    box1_x1 = boxes_preds[..., 0:1]
    box1_y1 = boxes_preds[..., 1:2]
    box1_x2 = boxes_preds[..., 2:3]
    box1_y2 = boxes_preds[..., 3:4]

    box2_x1 = boxes_labels[..., 0:1]
    box2_y1 = boxes_labels[..., 1:2]
    box2_x2 = boxes_labels[..., 2:3]
    box2_y2 = boxes_labels[..., 3:4]
  else:
    # Without this, an unknown format surfaces later as a confusing
    # UnboundLocalError on box1_x1.
    raise ValueError(
        f'box_format must be "midpoint" or "corners", got {box_format!r}'
    )

  # Corners of the intersection rectangle.
  x1 = torch.max(box1_x1, box2_x1)
  y1 = torch.max(box1_y1, box2_y1)
  x2 = torch.min(box1_x2, box2_x2)
  y2 = torch.min(box1_y2, box2_y2)

  # clamp(0) makes non-overlapping boxes contribute zero intersection
  # instead of a negative "area".
  intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
  box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
  box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

  # The small epsilon guards against division by zero for degenerate boxes.
  return intersection / (box1_area + box2_area - intersection + 1e-6)


In [23]:
t1_box1 = torch.tensor([0.8, 0.1, 0.2, 0.2])
t1_box2 = torch.tensor([0.9, 0.2, 0.2, 0.2])
t1_correct_iou = 1 / 7

iou = intersection_over_union(t1_box1, t1_box2, box_format="midpoint")
# Verify instead of eyeballing. The IoU denominator's epsilon and float32
# rounding make the result only approximately 1/7, hence the tolerance.
assert abs(iou.item() - t1_correct_iou) < 1e-4, (iou, t1_correct_iou)
iou, t1_correct_iou

(tensor([0.1429]), 0.14285714285714285)

In [22]:
t6_box1 = torch.tensor([2, 2, 6, 6])
t6_box2 = torch.tensor([4, 4, 7, 8])
t6_correct_iou = 4 / 24

iou = intersection_over_union(t6_box1, t6_box2, box_format="corners")
# Verify instead of eyeballing. The IoU denominator's epsilon makes the
# result only approximately 4/24, hence the tolerance.
assert abs(iou.item() - t6_correct_iou) < 1e-4, (iou, t6_correct_iou)
iou, t6_correct_iou

(tensor([0.1667]), 0.16666666666666666)