From 8982ff6f3afa9031ddac80bad1402114df331486 Mon Sep 17 00:00:00 2001
From: neginraoof
Date: Thu, 2 Apr 2020 10:55:57 -0700
Subject: [PATCH 1/2] Fixing nms on boxes when no detection

---
 torchvision/__init__.py  | 3 ++-
 torchvision/ops/boxes.py | 2 ++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchvision/__init__.py b/torchvision/__init__.py
index ae81aaacb60..b10559cd022 100644
--- a/torchvision/__init__.py
+++ b/torchvision/__init__.py
@@ -1,5 +1,7 @@
 import warnings
 
+from .extension import _HAS_OPS
+
 from torchvision import models
 from torchvision import datasets
 from torchvision import ops
@@ -7,7 +9,6 @@
 from torchvision import utils
 from torchvision import io
 
-from .extension import _HAS_OPS
 import torch
 
 try:
diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index 44dee79497f..e664f46e0df 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -5,6 +5,7 @@
 from torch import Tensor
 
 
+@torch.jit.script
 def nms(boxes, scores, iou_threshold):
     # type: (Tensor, Tensor, float)
     """
@@ -36,6 +37,7 @@
     return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
 
 
+@torch.jit.script
 def batched_nms(boxes, scores, idxs, iou_threshold):
     # type: (Tensor, Tensor, Tensor, float)
     """

From 21f94fd7334f83c2aa8a9bdd4c32b2cb412185d6 Mon Sep 17 00:00:00 2001
From: neginraoof
Date: Wed, 8 Apr 2020 14:40:55 -0700
Subject: [PATCH 2/2] test

---
 torchvision/models/detection/transform.py | 33 ++++++++++++++---------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index 9b2ef009cb8..f542f8d755d 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -12,6 +12,25 @@
 from .roi_heads import paste_masks_in_image
 
 
+@torch.jit.script
+def compute_scale_factor(image, self_min_size, self_max_size):
+    h, w = image.shape[-2:]
+    im_shape = torch.tensor(image.shape[-2:])
+    min_size = float(torch.min(im_shape))
+    max_size = float(torch.max(im_shape))
+
+    # FIXME assume for now that testing uses the largest scale
+    size = float(self_min_size[-1])
+    scale_factor = size / min_size
+    if max_size * scale_factor > self_max_size:
+        scale_factor = self_max_size / max_size
+    return scale_factor
+    image = torch.nn.functional.interpolate(
+        image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0]
+
+    return image
+
+
 class GeneralizedRCNNTransform(nn.Module):
     """
     Performs input / target transformation before feeding the data to a GeneralizedRCNN
@@ -78,19 +97,7 @@ def torch_choice(self, l):
     def resize(self, image, target):
         # type: (Tensor, Optional[Dict[str, Tensor]])
         h, w = image.shape[-2:]
-        im_shape = torch.tensor(image.shape[-2:])
-        min_size = float(torch.min(im_shape))
-        max_size = float(torch.max(im_shape))
-        if self.training:
-            size = float(self.torch_choice(self.min_size))
-        else:
-            # FIXME assume for now that testing uses the largest scale
-            size = float(self.min_size[-1])
-        scale_factor = size / min_size
-        if max_size * scale_factor > self.max_size:
-            scale_factor = self.max_size / max_size
-        image = torch.nn.functional.interpolate(
-            image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0]
+        image = compute_scale_factor(image, self.min_size, self.max_size)
         if target is None:
             return image, target