diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 51290d6e48c..bce7d046780 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -1,4 +1,12 @@ import torch +from torch import Tensor + + +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.dtype not in (torch.float32, torch.float64): + return t.float() + return t def generalized_box_iou_loss( @@ -34,6 +42,8 @@ def generalized_box_iou_loss( https://arxiv.org/abs/1902.09630 """ + boxes1 = _upcast(boxes1) + boxes2 = _upcast(boxes2) x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)