diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py
index ef4f6550eef..f2d4d4eb337 100644
--- a/torchvision/models/detection/_utils.py
+++ b/torchvision/models/detection/_utils.py
@@ -468,3 +468,33 @@ def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
         model.train()
 
     return out_channels
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
+    """
+    ONNX spec requires the k-value to be less than or equal to the number of inputs along
+    provided dim. Certain models use the number of elements along a particular axis instead of K
+    if K exceeds the number of elements along that axis. Previously, python's min() function was
+    used to determine whether to use the provided k-value or the specified dim axis value.
+
+    However in cases where the model is being exported in tracing mode, python min() is
+    static causing the model to be traced incorrectly and eventually fail at the topk node.
+    In order to avoid this situation, in tracing mode, torch.min() is used instead. Outside of
+    tracing, python's min() is kept so that callers receive a plain int.
+
+    Args:
+        input (Tensor): The original input tensor.
+        orig_kval (int): The provided k-value.
+        axis (int): Axis along which we retrieve the input size.
+
+    Returns:
+        min_kval (int): Appropriately selected k-value. While tracing, this is a 0-dim
+            Tensor so that the value stays dynamic in the traced graph.
+    """
+    if not torch.jit.is_tracing():
+        # Not tracing: the size is a concrete int, so python's min() is exact and builds no tensors.
+        return min(orig_kval, input.size(axis))
+    # Tracing: compute the minimum with tensor ops so the comparison is recorded in the graph.
+    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+    return min_kval  # type: ignore[return-value]
diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py
index 91baf1d0b29..32a413e9cb1 100644
--- a/torchvision/models/detection/fcos.py
+++ b/torchvision/models/detection/fcos.py
@@ -501,7 +501,7 @@ def postprocess_detections(
                 topk_idxs = torch.where(keep_idxs)[0]
 
                 # keep only topk scoring predictions
-                num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]
 
diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index e5ced9870ba..1909f6a8b73 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -436,7 +436,7 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
                 topk_idxs = torch.where(keep_idxs)[0]
 
                 # keep only topk scoring predictions
-                num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]
 
diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py
index 15cec706fbb..1d63bcc8a54 100644
--- a/torchvision/models/detection/rpn.py
+++ b/torchvision/models/detection/rpn.py
@@ -1,7 +1,6 @@
-from typing import List, Optional, Dict, Tuple, cast
+from typing import List, Optional, Dict, Tuple
 
 import torch
-import torchvision
 from torch import nn, Tensor
 from torch.nn import functional as F
 from torchvision.ops import boxes as box_ops
@@ -13,17 +12,6 @@
 from .image_list import ImageList
 
 
-@torch.jit.unused
-def _onnx_get_num_anchors_and_pre_nms_top_n(ob: Tensor, orig_pre_nms_top_n: int) -> Tuple[int, int]:
-    from torch.onnx import operators
-
-    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
-    pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))
-
-    # for mypy we cast at runtime
-    return cast(int, num_anchors), cast(int, pre_nms_top_n)
-
-
 class RPNHead(nn.Module):
     """
     Adds a simple RPN Head with classification and regression heads
@@ -206,11 +194,8 @@ def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -
         r = []
         offset = 0
         for ob in objectness.split(num_anchors_per_level, 1):
-            if torchvision._is_tracing():
-                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
-            else:
-                num_anchors = ob.shape[1]
-                pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
+            num_anchors = ob.shape[1]
+            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
             _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
             r.append(top_n_idx + offset)
             offset += num_anchors
diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py
index 5778a07075d..08a9ed68e4e 100644
--- a/torchvision/models/detection/ssd.py
+++ b/torchvision/models/detection/ssd.py
@@ -407,7 +407,7 @@ def postprocess_detections(
                 box = boxes[keep_idxs]
 
                 # keep only topk scoring predictions
-                num_topk = min(self.topk_candidates, score.size(0))
+                num_topk = det_utils._topk_min(score, self.topk_candidates, 0)
                 score, idxs = score.topk(num_topk)
                 box = box[idxs]
 