diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 51290d6e48c..c320384793c 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -37,6 +37,13 @@ def generalized_box_iou_loss( x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + # degenerate boxes gives inf / nan results + # so do an early check + if not ((x2 >= x1) & (y2 >= y1)).all(): + raise ValueError("Some of the input boxes1 are invalid.") + if not ((x2g >= x1g) & (y2g >= y1g)).all(): + raise ValueError("Some of the input boxes2 are invalid.") + # Intersection keypoints xkis1 = torch.max(x1, x1g) ykis1 = torch.max(y1, y1g)