Skip to content
13 changes: 10 additions & 3 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,12 @@ def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
return out_channels


def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor:
@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> int:
return v # type: ignore[return-value]


def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
"""
ONNX spec requires the k-value to be less than or equal to the number of inputs along
provided dim. Certain models use the number of elements along a particular axis instead of K
Expand All @@ -487,8 +492,10 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> Tensor:
axis(int): Axis along which we retreive the input size.

Returns:
min_kval (Tensor): Appropriately selected k-value.
min_kval (int): Appropriately selected k-value.
"""
if not torch.jit.is_tracing():
return min(orig_kval, input.size(axis))
axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
return min_kval # type: ignore[arg-type]
return _fake_cast_onnx(min_kval)