diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py
index f2d4d4eb337..36e99e6506d 100644
--- a/torchvision/models/detection/_utils.py
+++ b/torchvision/models/detection/_utils.py
@@ -470,7 +470,12 @@ def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
     return out_channels
 
 
-def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor:
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> int:
+    return v  # type: ignore[return-value]
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
     """
     ONNX spec requires the k-value to be less than or equal to the number of inputs along
     provided dim. Certain models use the number of elements along a particular axis instead of K
@@ -487,8 +492,10 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor:
         axis(int): Axis along which we retreive the input size.
 
     Returns:
-        min_kval (Tensor): Appropriately selected k-value.
+        min_kval (int): Appropriately selected k-value.
     """
+    if not torch.jit.is_tracing():
+        return min(orig_kval, input.size(axis))
     axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
     min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
-    return min_kval  # type: ignore[arg-type]
+    return _fake_cast_onnx(min_kval)
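
For context, here is a minimal sketch of how the clamped k-value from `_topk_min` feeds into `topk` in eager mode; the score tensor and requested k below are illustrative and not part of the patch. When tracing for ONNX export, the other branch keeps the clamp symbolic via `torch._shape_as_tensor`, and `_fake_cast_onnx` only convinces the type checker that the resulting 0-dim tensor can be used where an `int` is expected.

```python
import torch
from torchvision.models.detection import _utils as det_utils

# Scores for 300 candidate boxes; we ask for the top 400, which is
# more than the number of candidates available.
scores = torch.rand(300)

# In eager mode this is simply min(400, scores.size(0)) == 300.
# Under tracing / ONNX export, the size lookup stays in the graph,
# so the exported model handles a variable number of candidates.
k = det_utils._topk_min(scores, orig_kval=400, axis=0)

top_scores, top_idxs = scores.topk(k, dim=0)
print(top_scores.shape)  # torch.Size([300])
```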