Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
bfb8321
NMS implementation for CPU
Mar 23, 2026
8875875
add tests
Mar 23, 2026
ae5fb41
Merge remote-tracking branch 'origin/main' into rotated-NMS
Mar 23, 2026
0e924ed
fuse two implementations
Mar 24, 2026
667e1b4
Remove Detectron2 license header from nms_kernel.cpp since the NMS al…
Mar 24, 2026
f01ca8b
revert the unnecessary linting
Mar 24, 2026
40adc7f
Preserve original IoU computation in fused NMS implementation
Mar 24, 2026
eb4aefe
Add box caching in fused NMS to match original memory access pattern
Mar 24, 2026
5499015
fix the test failure
Mar 24, 2026
c940345
remove the torchscript test and err_msg
Mar 25, 2026
17a54cc
Remove the fmt parameter from nms function
Mar 25, 2026
54aee4c
fix the test failures
Mar 25, 2026
c5cf1c4
Reuse TestNMS._reference_nms in TestNMSRotated
Mar 27, 2026
9c96e26
address the comment
Mar 27, 2026
4a993c8
change the variable name
Mar 27, 2026
84e960c
address more comments on the file torchvision/csrc/ops/cpu/nms_kernel…
Mar 27, 2026
9ef0c0b
add batched_nms and the test
Mar 27, 2026
4da74a8
Rename _reference_nms to _reference_aligned_nms
Mar 31, 2026
2ea193b
Merge TestNMSRotated into TestNMS
Mar 31, 2026
c249d35
combine tests, refactor duplicated parts and add a test for NMS with …
Apr 1, 2026
50afe5c
add test_nms_rotated_specific_angles
Apr 1, 2026
8259471
parametrize the existing test_batched_nms_implementations test over t…
Apr 1, 2026
bdd502d
address the comments by using the torch.testing.assert_close and remo…
Apr 3, 2026
7f5d72b
parametrizing the angles for the test_batched_nms_rotated function
Apr 3, 2026
aea073c
add explanation for angle = 90 and move it out of the _create_rotated…
Apr 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 91 additions & 9 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ def test_is_leaf_node(self, device):


class TestNMS:
def _reference_nms(self, boxes, scores, iou_threshold):
def _reference_aligned_nms(cls, boxes, scores, iou_threshold):
"""
Args:
boxes: boxes in corner-form
Expand Down Expand Up @@ -818,15 +818,15 @@ def test_nms_ref(self, iou, seed):
torch.random.manual_seed(seed)
err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self._reference_nms(boxes, scores, iou)
keep_ref = self._reference_aligned_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou)
torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))

def test_nms_input_errors(self):
with pytest.raises(RuntimeError):
ops.nms(torch.rand(4), torch.rand(3), 0.5)
with pytest.raises(RuntimeError):
ops.nms(torch.rand(3, 5), torch.rand(3), 0.5)
with pytest.raises((RuntimeError, ValueError)):
ops.nms(torch.rand(3, 6), torch.rand(3), 0.5)
with pytest.raises(RuntimeError):
ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -920,19 +920,23 @@ def test_nms_float16(self, device):
assert_equal(keep32, keep16)

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("rotated", (False, True))
@pytest.mark.opcheck_only_one()
def test_batched_nms_implementations(self, seed):
def test_batched_nms_implementations(self, seed, rotated):
"""Make sure that both implementations of batched_nms yield identical results"""
torch.random.manual_seed(seed)

num_boxes = 1000
iou_threshold = 0.9

boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2
assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2
if rotated:
_, boxes, scores = self._create_rotated_boxes(num_boxes)
else:
boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2
assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2
scores = torch.rand(num_boxes)

scores = torch.rand(num_boxes)
idxs = torch.randint(0, 4, size=(num_boxes,))
keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
Expand All @@ -945,6 +949,84 @@ def test_batched_nms_implementations(self, seed):
empty = torch.empty((0,), dtype=torch.int64)
torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))

def _create_rotated_boxes(self, N, angle=0, device="cpu"):
    """Build N random boxes in both corner form and rotated form.

    Returns a tuple ``(xyxy_boxes, rotated_boxes, scores)`` where
    ``rotated_boxes`` is the (cx, cy, w, h, angle) representation of the same
    geometry, with a constant ``angle`` in the fifth column.
    """
    # Random corner-form boxes; adding the top-left corner to the second pair
    # guarantees x1 <= x2 and y1 <= y2.
    xyxy = torch.rand(N, 4, device=device) * 200
    xyxy[:, 2:] += xyxy[:, :2]
    scores = torch.rand(N, device=device)
    # Convert to center form and append the constant angle as a fifth column.
    centered = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")
    angles = torch.full((N, 1), angle, dtype=centered.dtype, device=device)
    rotated = torch.cat((centered, angles), dim=1)
    return xyxy, rotated, scores

@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("angle", (0, 90, 180))
def test_nms_rotated(self, iou, angle):
torch.manual_seed(0)
N = 1000
boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle)
if angle == 90:
# widths and heights are intentionally swapped here for 90 degrees case
# so that the reference horizontal nms could be used
rotated_boxes[:, 2], rotated_boxes[:, 3] = (
rotated_boxes[:, 3].clone(),
rotated_boxes[:, 2].clone(),
)
keep_ref = self._reference_aligned_nms(boxes, scores, iou)
keep = ops.nms(rotated_boxes, scores, iou)
torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0)
keep_non_rotated = ops.nms(boxes, scores, iou)
torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0)

@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("angle", (0, 90, 180))
def test_batched_nms_rotated(self, iou, angle):
torch.manual_seed(0)
N = 2000
num_classes = 50
boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle)
if angle == 90:
# widths and heights are intentionally swapped here for 90 degrees case
# so that the reference horizontal nms could be used
rotated_boxes[:, 2], rotated_boxes[:, 3] = (
rotated_boxes[:, 3].clone(),
rotated_boxes[:, 2].clone(),
)
idxs = torch.randint(0, num_classes, (N,))
backup = rotated_boxes.clone()
keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou)
keep = ops.batched_nms(rotated_boxes, scores, idxs, iou)
torch.testing.assert_close(rotated_boxes, backup)
torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0)

@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_nms_rotated_different_angles(self, iou):
torch.manual_seed(0)
N = 1000
_, rotated_boxes, scores = self._create_rotated_boxes(N)
rotated_boxes[:, 4] = torch.rand(N) * 360
keep = ops.nms(rotated_boxes, scores, iou)
assert keep.dtype == torch.int64
assert keep.dim() == 1
assert keep.numel() <= N
assert (keep >= 0).all() and (keep < N).all()
assert (scores[keep][:-1] >= scores[keep][1:]).all()

def test_nms_rotated_specific_angles(self):
    """Hand-crafted case: a rotated box overlapping the top-scoring box is
    suppressed, while a distant box survives."""
    boxes = torch.tensor(
        [
            [0, 0, 10, 10, 0],  # highest score -> always kept
            [0, 0, 10, 10, 45],  # same center, rotated -> suppressed by box 0
            [100, 100, 10, 10, 30],  # far away -> kept
        ],
        dtype=torch.float32,
    )
    scores = torch.tensor([0.9, 0.8, 0.7])
    kept = set(ops.nms(boxes, scores, iou_threshold=0.5).tolist())
    assert 0 in kept
    assert 1 not in kept
    assert 2 in kept


optests.generate_opcheck_tests(
testcase=TestNMS,
Expand Down
10 changes: 10 additions & 0 deletions torchvision/_autograd_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,15 @@ def _autocast_nms(dets, scores, iou_threshold):
)


def _autocast_nms_rotated(dets, scores, iou_threshold):
    """Autocast wrapper for ``torchvision.nms_rotated``.

    Casts the inputs via ``_autocast_cast`` and re-dispatches with all autocast
    keys excluded, so the underlying kernel runs exactly once at full precision.
    """
    with torch._C._ExcludeDispatchKeyGuard(_all_autocast_keys):
        # Cast inside the guard so the casts themselves bypass autocast too.
        boxes = _autocast_cast(dets)
        box_scores = _autocast_cast(scores)
        return torch.ops.torchvision.nms_rotated(boxes, box_scores, iou_threshold)


def _autocast_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
orig_dtype = input.dtype
with torch._C._ExcludeDispatchKeyGuard(_all_autocast_keys):
Expand Down Expand Up @@ -358,6 +367,7 @@ def _autocast_deform_conv2d(
# nms and roi_align: registered for all autocast device types
for _key in ("AutocastCUDA", "AutocastCPU", "AutocastXPU"):
_autocast_lib.impl("nms", _autocast_nms, _key)
_autocast_lib.impl("nms_rotated", _autocast_nms_rotated, _key)
_autocast_lib.impl("roi_align", _autocast_roi_align, _key)

# Other ops: CUDA autocast only
Expand Down
14 changes: 14 additions & 0 deletions torchvision/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,20 @@ def meta_nms(dets, scores, iou_threshold):
return dets.new_empty(num_to_keep, dtype=torch.long)


@torch.library.register_fake("torchvision::nms_rotated")
def meta_nms_rotated(dets, scores, iou_threshold):
    """Fake (meta) kernel for ``torchvision::nms_rotated``.

    Validates the ``(N, 5)`` boxes / ``(N,)`` scores shapes with the same
    messages as the CPU kernel, then returns an int64 tensor whose length is an
    unbacked symbolic int, since the number of kept boxes is data-dependent.
    """
    torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
    torch._check(dets.size(1) == 5, lambda: f"boxes should have 5 elements in dimension 1, got {dets.size(1)}")
    # "D" suffix added to match the CPU kernel's message for the same check.
    torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}D")
    torch._check(
        dets.size(0) == scores.size(0),
        lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
    )
    # The kept count is unknown at trace time; model it symbolically.
    ctx = torch._custom_ops.get_ctx()
    num_to_keep = ctx.create_unbacked_symint()
    return dets.new_empty(num_to_keep, dtype=torch.long)


@register_meta("deform_conv2d")
def meta_deform_conv2d(
input,
Expand Down
145 changes: 116 additions & 29 deletions torchvision/csrc/ops/cpu/nms_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/ATen.h>
#include <torch/library.h>

#include "../box_iou_rotated_utils.h"

namespace vision {
namespace ops {

namespace {

template <typename scalar_t>
template <typename scalar_t, typename IoUFunc>
at::Tensor nms_kernel_impl(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
double iou_threshold,
IoUFunc iou_func) {
TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
TORCH_CHECK(
Expand All @@ -21,13 +30,6 @@ at::Tensor nms_kernel_impl(
return at::empty({0}, dets.options().dtype(at::kLong));
}

auto x1_t = dets.select(1, 0).contiguous();
auto y1_t = dets.select(1, 1).contiguous();
auto x2_t = dets.select(1, 2).contiguous();
auto y2_t = dets.select(1, 3).contiguous();

at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);

auto order_t = std::get<1>(
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));

Expand All @@ -38,11 +40,6 @@ at::Tensor nms_kernel_impl(
auto suppressed = suppressed_t.data_ptr<uint8_t>();
auto keep = keep_t.data_ptr<int64_t>();
auto order = order_t.data_ptr<int64_t>();
auto x1 = x1_t.data_ptr<scalar_t>();
auto y1 = y1_t.data_ptr<scalar_t>();
auto x2 = x2_t.data_ptr<scalar_t>();
auto y2 = y2_t.data_ptr<scalar_t>();
auto areas = areas_t.data_ptr<scalar_t>();

int64_t num_to_keep = 0;

Expand All @@ -52,26 +49,16 @@ at::Tensor nms_kernel_impl(
continue;
}
keep[num_to_keep++] = i;
auto ix1 = x1[i];
auto iy1 = y1[i];
auto ix2 = x2[i];
auto iy2 = y2[i];
auto iarea = areas[i];

iou_func.set_box(i);

for (int64_t _j = _i + 1; _j < ndets; _j++) {
auto j = order[_j];
if (suppressed[j] == 1) {
continue;
}
auto xx1 = std::max(ix1, x1[j]);
auto yy1 = std::max(iy1, y1[j]);
auto xx2 = std::min(ix2, x2[j]);
auto yy2 = std::min(iy2, y2[j]);

auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
auto inter = w * h;
auto ovr = inter / (iarea + areas[j] - inter);

auto ovr = iou_func.compute(j);
if (ovr > iou_threshold) {
suppressed[j] = 1;
}
Expand All @@ -80,6 +67,70 @@ at::Tensor nms_kernel_impl(
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}

// IoU functor for axis-aligned boxes in (x1, y1, x2, y2) corner form, used by
// nms_kernel_impl.
//
// The constructor extracts each coordinate column into a contiguous tensor
// once (the tensors are kept alive as members so the raw pointers stay
// valid), so the O(n^2) suppression loop performs only raw pointer reads.
template <typename scalar_t>
struct NonRotatedIoU {
  // Owning tensors; declared before the pointers so the member-initializer
  // list can populate both in declaration order.
  at::Tensor x1_t, y1_t, x2_t, y2_t, areas_t;
  const scalar_t* x1;
  const scalar_t* y1;
  const scalar_t* x2;
  const scalar_t* y2;
  const scalar_t* areas;

  // Coordinates and area of the box selected by the latest set_box() call.
  scalar_t ix1, iy1, ix2, iy2, iarea;

  // explicit: a single-argument constructor must not act as an implicit
  // Tensor -> NonRotatedIoU conversion.
  explicit NonRotatedIoU(const at::Tensor& dets)
      : x1_t(dets.select(1, 0).contiguous()),
        y1_t(dets.select(1, 1).contiguous()),
        x2_t(dets.select(1, 2).contiguous()),
        y2_t(dets.select(1, 3).contiguous()),
        areas_t((x2_t - x1_t) * (y2_t - y1_t)),
        x1(x1_t.data_ptr<scalar_t>()),
        y1(y1_t.data_ptr<scalar_t>()),
        x2(x2_t.data_ptr<scalar_t>()),
        y2(y2_t.data_ptr<scalar_t>()),
        areas(areas_t.data_ptr<scalar_t>()),
        ix1(0),
        iy1(0),
        ix2(0),
        iy2(0),
        iarea(0) {}

  // Cache box i as the reference box for subsequent compute() calls.
  void set_box(int64_t i) {
    ix1 = x1[i];
    iy1 = y1[i];
    ix2 = x2[i];
    iy2 = y2[i];
    iarea = areas[i];
  }

  // IoU between the cached reference box and box j; clamping the overlap
  // extent at zero handles disjoint boxes.
  scalar_t compute(int64_t j) const {
    auto xx1 = std::max(ix1, x1[j]);
    auto yy1 = std::max(iy1, y1[j]);
    auto xx2 = std::min(ix2, x2[j]);
    auto yy2 = std::min(iy2, y2[j]);

    auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
    auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
    auto inter = w * h;
    return inter / (iarea + areas[j] - inter);
  }
};

// IoU functor for rotated boxes in (cx, cy, w, h, angle) form, used by
// nms_kernel_impl.
//
// The detections are flattened to one contiguous buffer up front. Indexing
// the tensor per comparison ((*dets_ptr)[i].data_ptr<scalar_t>()) would
// allocate two temporary Tensors inside the O(n^2) suppression loop and
// would also assume each row is densely packed, which is wrong for a
// non-contiguous `dets`; a single contiguous() call avoids both problems.
template <typename scalar_t>
struct RotatedIoU {
  // Owning contiguous copy; keeps `data` valid for the functor's lifetime.
  at::Tensor dets_t;
  const scalar_t* data;
  // Row (5 scalars) of the box selected by the latest set_box() call.
  const scalar_t* box_i;

  // explicit: a single-argument constructor must not act as an implicit
  // Tensor -> RotatedIoU conversion.
  explicit RotatedIoU(const at::Tensor& dets)
      : dets_t(dets.contiguous()),
        data(dets_t.data_ptr<scalar_t>()),
        box_i(nullptr) {}

  // Cache box i as the reference box for subsequent compute() calls.
  void set_box(int64_t i) {
    box_i = data + i * 5;
  }

  // Rotated IoU between the cached reference box and box j.
  scalar_t compute(int64_t j) const {
    return single_box_iou_rotated<scalar_t>(box_i, data + j * 5);
  }
};

at::Tensor nms_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
Expand All @@ -106,7 +157,40 @@ at::Tensor nms_kernel(
auto result = at::empty({0}, dets.options());

AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] {
result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
result = nms_kernel_impl<scalar_t>(
dets, scores, iou_threshold, NonRotatedIoU<scalar_t>(dets));
});
return result;
}

// CPU dispatch entry for torchvision::nms_rotated.
//
// dets: (N, 5) floating-point tensor of rotated boxes (cx, cy, w, h, angle).
// scores: (N,) floating-point tensor, one confidence per box.
// Returns the int64 indices of the boxes kept after rotated non-maximum
// suppression.
at::Tensor nms_rotated_kernel(
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold) {
  // Shape validation; the messages mirror the Python meta kernel so eager
  // and compiled paths report comparable errors.
  TORCH_CHECK(
      dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
  TORCH_CHECK(
      dets.size(1) == 5,
      "boxes should have 5 elements in dimension 1, got ",
      dets.size(1));
  TORCH_CHECK(
      scores.dim() == 1,
      "scores should be a 1d tensor, got ",
      scores.dim(),
      "D");
  TORCH_CHECK(
      dets.size(0) == scores.size(0),
      "boxes and scores should have same number of elements in ",
      "dimension 0, got ",
      dets.size(0),
      " and ",
      scores.size(0));

  auto result = at::empty({0}, dets.options());

  // Dispatch over float/double; the shared nms_kernel_impl is parameterized
  // by the rotated IoU functor.
  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated_kernel", [&] {
    result = nms_kernel_impl<scalar_t>(
        dets, scores, iou_threshold, RotatedIoU<scalar_t>(dets));
  });
  return result;
}
Expand All @@ -115,6 +199,9 @@ at::Tensor nms_kernel(

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::nms_rotated"),
TORCH_FN(nms_rotated_kernel));
}

} // namespace ops
Expand Down
Loading
Loading