From 2f8089f0e2ae374b91f7b596b4a76a44bdbad838 Mon Sep 17 00:00:00 2001 From: Yonghye Kwon Date: Sat, 26 Mar 2022 22:24:33 +0900 Subject: [PATCH 1/4] call _upcast to consider overflow giou loss is weak at overflow problem because it computes area of box. https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py#L226-L242 --- torchvision/ops/giou_loss.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 51290d6e48c..54fddd927f0 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -1,4 +1,13 @@ import torch +from torch import Tensor + + +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() def generalized_box_iou_loss( @@ -34,6 +43,8 @@ def generalized_box_iou_loss( https://arxiv.org/abs/1902.09630 """ + boxes1 = _upcast(boxes1) + boxes2 = _upcast(boxes2) x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) From dc2d754410e25557edc233c54b99b9f3247684fb Mon Sep 17 00:00:00 2001 From: Yonghye Kwon Date: Sun, 27 Mar 2022 17:37:48 +0900 Subject: [PATCH 2/4] cast datatype to float --- torchvision/ops/giou_loss.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 54fddd927f0..8ec38cded98 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -4,10 +4,9 @@ def _upcast(t: Tensor) -> Tensor: # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if t.is_floating_point(): - return t if t.dtype in (torch.float32, torch.float64) else t.float() - else: - return t if t.dtype in (torch.int32, torch.int64) else t.int() + if not t.dtype in (torch.float32, torch.float64) + return t.float() + return t def generalized_box_iou_loss( From e6a7769b660e8133cfd221cfa15b9b64ef4f8bdc Mon Sep 17 00:00:00 2001 From: Yonghye Kwon Date: Sun, 27 Mar 2022 17:40:56 +0900 Subject: [PATCH 3/4] add ":" to if --- torchvision/ops/giou_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 8ec38cded98..3f79487b2f1 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -4,7 +4,7 @@ def _upcast(t: Tensor) -> Tensor: # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if not t.dtype in (torch.float32, torch.float64) + if not t.dtype in (torch.float32, torch.float64): return t.float() return t From ad704edb43282bd17eff04e62124c67b4a821c39 Mon Sep 17 00:00:00 2001 From: Yonghye Kwon Date: Sun, 27 Mar 2022 18:37:08 +0900 Subject: [PATCH 4/4] lint Test for membership should be 'not in' (E713) https://www.flake8rules.com/rules/E713.html --- torchvision/ops/giou_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 3f79487b2f1..bce7d046780 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -4,7 +4,7 @@ def _upcast(t: Tensor) -> Tensor: # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if not t.dtype in (torch.float32, torch.float64): + if t.dtype not in (torch.float32, torch.float64): return t.float() return t