# In [1]:
from typing import Tuple
import torch
from torch import Tensor
from torchvision.ops import box_iou


def _centered_xyxy(widthheight: Tensor) -> Tensor:
    """Build origin-centred boxes in (x1, y1, x2, y2) order from width/height rows.

    Only column 0 (width) and column 1 (height) of ``widthheight`` are read.
    """
    half_w = widthheight[:, 0] / 2
    half_h = widthheight[:, 1] / 2
    # Row layout [-w/2, -h/2, w/2, h/2] is the xyxy format box_iou expects.
    return torch.stack([-half_w, -half_h, half_w, half_h], dim=1)


def iou_argmax(widthheight1: Tensor,
               widthheight2: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Match boxes between 2 sets based on intersection-over-union (IOU) of their shapes.

    Every box is centred on the origin before comparison, so only width and
    height influence the IOU; any positional information is ignored.

    Args:
        widthheight1: M x 2 matrix of box width, box height
        widthheight2: N x 2 matrix of box width, box height

    Returns:
        A tuple ``(iou, best)``:
        ``iou`` is the M x N matrix of pairwise IOU values, and
        ``best`` is an M-vector giving, for each set 1 box, the index of the
        set 2 box with the highest IOU (ties resolved by ``torch.argmax``).

    Note:
        ``iou.argmax(dim=1)`` needs at least one set 2 box (N >= 1); torch
        raises an error when set 2 is empty.
        NOTE(review): a pair of zero-area boxes has zero union, so box_iou
        presumably yields NaN for it; verify whether callers can pass such boxes.
    """
    # Shared helper replaces the two copy-pasted conversions.
    boxes1 = _centered_xyxy(widthheight1)
    boxes2 = _centered_xyxy(widthheight2)
    iou = box_iou(boxes1, boxes2)  # expects xyxy format
    return iou, iou.argmax(dim=1)