Add autocast for nms, roi_align on CPU (#8049)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
3 people committed Oct 27, 2023
1 parent b80bdb7 commit 8a0b491
Showing 4 changed files with 61 additions and 12 deletions.
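In short: torchvision::nms and torchvision::roi_align now have AutocastCPU kernels, so they can be called inside torch.cpu.amp.autocast() with reduced-precision inputs, which the wrappers cast up to float32. A minimal sketch of the user-facing effect (the box and score values here are made up for illustration):

    import torch
    from torchvision import ops

    # Two overlapping boxes in (x1, y1, x2, y2) form, in bfloat16.
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]], dtype=torch.bfloat16)
    scores = torch.tensor([0.9, 0.8], dtype=torch.bfloat16)

    with torch.cpu.amp.autocast():
        # With this commit, the AutocastCPU kernel casts both inputs to
        # float32 before dispatching, so no manual .float() is needed.
        keep = ops.nms(boxes, scores, iou_threshold=0.5)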
setup.py (8 changes: 4 additions & 4 deletions)
@@ -129,8 +129,10 @@ def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
     extensions_dir = os.path.join(this_dir, "torchvision", "csrc")
 
-    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
-        os.path.join(extensions_dir, "ops", "*.cpp")
+    main_file = (
+        glob.glob(os.path.join(extensions_dir, "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
     )
     source_cpu = (
         glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
@@ -184,8 +186,6 @@ def get_extensions():
     else:
         source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))
 
-    source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
-
     sources = main_file + source_cpu
     extension = CppExtension
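The build change above is what makes the CPU registrations below reachable: the ops/autocast/*.cpp wrappers were previously appended to source_cuda, so CPU-only builds never compiled them; they are now collected into main_file unconditionally. A hypothetical, simplified re-creation of the gating (not the real get_extensions()):

    import glob
    import os

    def collect_sources(extensions_dir, with_cuda):
        # Hypothetical stand-in for setup.py's get_extensions().
        main_file = (
            glob.glob(os.path.join(extensions_dir, "*.cpp"))
            + glob.glob(os.path.join(extensions_dir, "ops", "*.cpp"))
            # Now unconditional: CPU-only builds need the autocast wrappers too.
            + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
        )
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu")) if with_cuda else []
        return main_file + source_cuda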
test/test_ops.py (26 changes: 26 additions & 0 deletions)
@@ -131,6 +131,8 @@ def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, determinist
                 tol = 5e-3
             else:
                 tol = 4e-3
+        elif x_dtype == torch.bfloat16:
+            tol = 5e-3
 
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -504,6 +506,21 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
                 rois_dtype=rois_dtype,
             )
 
+    @pytest.mark.parametrize("aligned", (True, False))
+    @pytest.mark.parametrize("deterministic", (True, False))
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
+    def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
+        with torch.cpu.amp.autocast():
+            self.test_forward(
+                torch.device("cpu"),
+                contiguous=False,
+                deterministic=deterministic,
+                aligned=aligned,
+                x_dtype=x_dtype,
+                rois_dtype=rois_dtype,
+            )
+
     @pytest.mark.parametrize("seed", range(10))
     @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
@@ -808,6 +825,15 @@ def test_autocast(self, iou, dtype):
         with torch.cuda.amp.autocast():
             self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
 
+    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
+    @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
+    def test_autocast_cpu(self, iou, dtype):
+        boxes, scores = self._create_tensors_with_iou(1000, iou)
+        with torch.cpu.amp.autocast():
+            keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
+            keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
+        torch.testing.assert_close(keep_ref_float, keep_dtype)
+
     @pytest.mark.parametrize(
         "device",
         (
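Both new tests exercise the ops under torch.cpu.amp.autocast() with bfloat16 inputs and compare against a float32 reference. A standalone sketch of what the RoIAlign path verifies, using illustrative shapes rather than the test's actual data:

    import torch
    from torchvision import ops

    x = torch.rand(1, 25, 32, 32, dtype=torch.bfloat16)
    # One RoI per row: (batch_index, x1, y1, x2, y2).
    rois = torch.tensor([[0, 0.0, 0.0, 16.0, 16.0]], dtype=torch.bfloat16)

    with torch.cpu.amp.autocast():
        out = ops.roi_align(x, rois, output_size=(5, 5), spatial_scale=1.0)

    # The wrapper casts both inputs to float32, so the result should match an
    # explicit float32 run on the same (already bfloat16-rounded) values.
    ref = ops.roi_align(x.float(), rois.float(), output_size=(5, 5), spatial_scale=1.0)
    torch.testing.assert_close(out, ref)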
torchvision/csrc/ops/autocast/nms_kernel.cpp (20 changes: 16 additions & 4 deletions)
@@ -9,21 +9,33 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 at::Tensor nms_autocast(
     const at::Tensor& dets,
     const at::Tensor& scores,
     double iou_threshold) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
+
   return nms(
-      at::autocast::cached_cast(at::kFloat, dets),
-      at::autocast::cached_cast(at::kFloat, scores),
+      at::autocast::cached_cast(at::kFloat, dets, device_type),
+      at::autocast::cached_cast(at::kFloat, scores, device_type),
       iou_threshold);
 }
 
 } // namespace
 
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
-  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::nms"),
+      TORCH_FN(
+          (nms_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::nms"),
+      TORCH_FN(
+          (nms_autocast<c10::DispatchKey::AutocastCPU, c10::DeviceType::CPU>)));
+}
 
 } // namespace ops
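The refactor collapses the CUDA-only wrapper into a single template: autocast_key names the dispatch key to exclude while the wrapper runs (so the re-dispatched nms call is not intercepted by autocast again), and device_type tells cached_cast which device's autocast state to consult. The two TORCH_LIBRARY_IMPL blocks then instantiate it for CUDA and CPU. One way to inspect the result from Python; torch._C._dispatch_dump is an internal, unstable API, so treat this as a debugging aid only:

    import torch
    import torchvision  # noqa: F401 -- importing torchvision registers its C++ ops

    # With this commit, the dump should list both an Autocast (CUDA) and an
    # AutocastCPU kernel for torchvision::nms.
    print(torch._C._dispatch_dump("torchvision::nms"))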
torchvision/csrc/ops/autocast/roi_align_kernel.cpp (19 changes: 15 additions & 4 deletions)
@@ -9,6 +9,7 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 at::Tensor roi_align_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
@@ -17,10 +18,10 @@ at::Tensor roi_align_autocast(
     int64_t pooled_width,
     int64_t sampling_ratio,
     bool aligned) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   return roi_align(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
       spatial_scale,
       pooled_height,
       pooled_width,
@@ -34,7 +35,17 @@ at::Tensor roi_align_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::roi_align"),
-      TORCH_FN(roi_align_autocast));
+      TORCH_FN((roi_align_autocast<
+                c10::DispatchKey::Autocast,
+                c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_align"),
+      TORCH_FN((roi_align_autocast<
+                c10::DispatchKey::AutocastCPU,
+                c10::DeviceType::CPU>)));
+}
 
 } // namespace ops
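roi_align gets the same treatment as nms. Conceptually, each wrapper behaves like the following Python-level sketch (the real logic lives in the C++ dispatcher, not in Python):

    import torch
    from torchvision import ops

    def roi_align_autocast_sketch(input, rois, output_size, **kwargs):
        # Conceptual analogue of the C++ wrapper: disable autocast so the
        # inner call is not intercepted again, then cast both tensors to the
        # float32 the kernels expect.
        with torch.autocast(device_type="cpu", enabled=False):
            return ops.roi_align(input.float(), rois.float(), output_size, **kwargs)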
