From fc7d027f02ae83477e1141b44d27d57432f1fc3a Mon Sep 17 00:00:00 2001 From: mcarilli Date: Thu, 9 Jul 2020 04:05:48 -0600 Subject: [PATCH] [WIP] Allow autocast for 1.6 (#2384) * Fixes Xiao's repro * Ports nms to use full dispatcher * Move HIPGuard to nms_cuda * clang-format * run models in test_models.py on GPU if available * Francisco's comment, also disable cuda model tests to see if CPU alone still passes * cuda tests now pass locally, although still not comparing to saved numerics * add note for thing to ask francisco * Allow cuda and cpu tests to share a data file * ignore suffix if unneeded * Skip autocast numerics checks for a few models * Add roi_align test Co-authored-by: Michael Carilli --- test/common_utils.py | 16 ++- test/test_models.py | 206 +++++++++++++++++++--------- test/test_ops.py | 20 ++- torchvision/csrc/ROIAlign.h | 24 ++++ torchvision/csrc/autocast.h | 28 ++++ torchvision/csrc/cpu/nms_cpu.cpp | 23 +++- torchvision/csrc/cpu/vision_cpu.h | 2 +- torchvision/csrc/cuda/nms_cuda.cu | 37 ++++- torchvision/csrc/cuda/vision_cuda.h | 7 +- torchvision/csrc/nms.h | 55 +++----- torchvision/csrc/vision.cpp | 12 +- torchvision/ops/poolers.py | 9 +- 12 files changed, 314 insertions(+), 125 deletions(-) create mode 100644 torchvision/csrc/autocast.h diff --git a/test/common_utils.py b/test/common_utils.py index 9dbd04f4217..d3b6e97a6dc 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -85,7 +85,7 @@ def is_iterable(obj): class TestCase(unittest.TestCase): precision = 1e-5 - def assertExpected(self, output, subname=None, prec=None): + def assertExpected(self, output, subname=None, prec=None, strip_suffix=None): r""" Test that a python value matches the recorded contents of a file derived from the name of this test and subname. The value must be @@ -96,16 +96,24 @@ def assertExpected(self, output, subname=None, prec=None): If you call this multiple times in a single function, you must give a unique subname each time. + + strip_suffix allows different tests that expect similar numerics, e.g. + "test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data. + test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass + strip_suffix="_cpu", and they would both use a data file name based on + "test_xyz". """ - def remove_prefix(text, prefix): + def remove_prefix_suffix(text, prefix, suffix): if text.startswith(prefix): - return text[len(prefix):] + text = text[len(prefix):] + if suffix is not None and text.endswith(suffix): + text = text[:len(text) - len(suffix)] return text # NB: we take __file__ from the module that defined the test # class, so we place the expect directory where the test script # lives, NOT where test/common_utils.py lives. module_id = self.__class__.__module__ - munged_id = remove_prefix(self.id(), module_id + ".") + munged_id = remove_prefix_suffix(self.id(), module_id + ".", strip_suffix) test_file = os.path.realpath(sys.modules[module_id].__file__) expected_file = os.path.join(os.path.dirname(test_file), "expect", diff --git a/test/test_models.py b/test/test_models.py index 1cee7a90003..faa14f8250e 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -74,6 +74,26 @@ def get_available_video_models(): } +# The following models exhibit flaky numerics under autocast in _test_*_model harnesses. +# This may be caused by the harness environment (e.g. 
num classes, input initialization
+# via torch.rand), and does not prove autocast is unsuitable when training with real data
+# (autocast has been used successfully with real data for some of these models).
+# TODO: investigate why autocast numerics are flaky in the harnesses.
+#
+# For the following models, _test_*_model harnesses skip numerical checks on outputs when
+# trying autocast. However, they still run an autocasted forward pass, which ensures that
+# autocast coverage suffices to prevent dtype errors in each model.
+autocast_flaky_numerics = (
+    "fasterrcnn_resnet50_fpn",
+    "inception_v3",
+    "keypointrcnn_resnet50_fpn",
+    "maskrcnn_resnet50_fpn",
+    "resnet101",
+    "resnet152",
+    "wide_resnet101_2",
+)
+
+
 class ModelTester(TestCase):
     def checkModule(self, model, name, args):
         if name not in script_test_models:
@@ -81,65 +101,87 @@ def checkModule(self, model, name, args):
         unwrapper = script_test_models[name].get('unwrapper', None)
         return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False)
 
-    def _test_classification_model(self, name, input_shape):
+    def _test_classification_model(self, name, input_shape, dev):
         set_rng_seed(0)
         # passing num_class equal to a number other than 1000 helps in making the test
         # more enforcing in nature
         model = models.__dict__[name](num_classes=50)
-        model.eval()
-        x = torch.rand(input_shape)
+        model.eval().to(device=dev)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         out = model(x)
-        self.assertExpected(out, prec=0.1)
+        self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
         self.assertEqual(out.shape[-1], 50)
         self.checkModule(model, name, (x,))
 
-    def _test_segmentation_model(self, name):
+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(x)
+                # See autocast_flaky_numerics comment at top of file.
+                if name not in autocast_flaky_numerics:
+                    self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
+                self.assertEqual(out.shape[-1], 50)
+
+    def _test_segmentation_model(self, name, dev):
         # passing num_class equal to a number other than 1000 helps in making the test
         # more enforcing in nature
         model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
-        model.eval()
+        model.eval().to(device=dev)
         input_shape = (1, 3, 300, 300)
-        x = torch.rand(input_shape)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         out = model(x)
         self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
         self.checkModule(model, name, (x,))
 
-    def _test_detection_model(self, name):
+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(x)
+                self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
+
+    def _test_detection_model(self, name, dev):
         set_rng_seed(0)
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
-        model.eval()
+        model.eval().to(device=dev)
         input_shape = (3, 300, 300)
-        x = torch.rand(input_shape)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         model_input = [x]
         out = model(model_input)
         self.assertIs(model_input[0], x)
-        self.assertEqual(len(out), 1)
 
-        def subsample_tensor(tensor):
-            num_elems = tensor.numel()
-            num_samples = 20
-            if num_elems <= num_samples:
-                return tensor
-
-            flat_tensor = tensor.flatten()
-            ith_index = num_elems // num_samples
-            return flat_tensor[ith_index - 1::ith_index]
-
-        def compute_mean_std(tensor):
-            # can't compute mean of integral tensor
-            tensor = tensor.to(torch.double)
-            mean = torch.mean(tensor)
-            std = torch.std(tensor)
-            return {"mean": mean, "std": std}
-
-        # maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
-        # compare results with mean and std
-        if name == "maskrcnn_resnet50_fpn":
-            test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
-            # mean values are small, use large prec
-            self.assertExpected(test_value, prec=.01)
-        else:
-            self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor), prec=0.01)
+        def check_out(out):
+            self.assertEqual(len(out), 1)
+
+            def subsample_tensor(tensor):
+                num_elems = tensor.numel()
+                num_samples = 20
+                if num_elems <= num_samples:
+                    return tensor
+
+                flat_tensor = tensor.flatten()
+                ith_index = num_elems // num_samples
+                return flat_tensor[ith_index - 1::ith_index]
+
+            def compute_mean_std(tensor):
+                # can't compute mean of integral tensor
+                tensor = tensor.to(torch.double)
+                mean = torch.mean(tensor)
+                std = torch.std(tensor)
+                return {"mean": mean, "std": std}
+
+            # maskrcnn_resnet50_fpn is numerically unstable across platforms, so for now
+            # compare results with mean and std
+            if name == "maskrcnn_resnet50_fpn":
+                test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
+                # mean values are small, use large prec
+                self.assertExpected(test_value, prec=.01, strip_suffix="_" + dev)
+            else:
+                self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor),
+                                    prec=0.01,
+                                    strip_suffix="_" + dev)
+
+        check_out(out)
 
         scripted_model = torch.jit.script(model)
         scripted_model.eval()
@@ -156,6 +198,13 @@ def compute_mean_std(tensor):
         #     self.check_script(model, name)
         self.checkModule(model, name, ([x],))
 
+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(model_input)
+                # See 
autocast_flaky_numerics comment at top of file. + if name not in autocast_flaky_numerics: + check_out(out) + def _test_detection_model_validation(self, name): set_rng_seed(0) model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False) @@ -179,18 +228,24 @@ def _test_detection_model_validation(self, name): targets = [{'boxes': boxes}] self.assertRaises(ValueError, model, x, targets=targets) - def _test_video_model(self, name): + def _test_video_model(self, name, dev): # the default input shape is # bs * num_channels * clip_len * h *w input_shape = (1, 3, 4, 112, 112) # test both basicblock and Bottleneck model = models.video.__dict__[name](num_classes=50) - model.eval() - x = torch.rand(input_shape) + model.eval().to(device=dev) + # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests + x = torch.rand(input_shape).to(device=dev) out = model(x) self.checkModule(model, name, (x,)) self.assertEqual(out.shape[-1], 50) + if dev == "cuda": + with torch.cuda.amp.autocast(): + out = model(x) + self.assertEqual(out.shape[-1], 50) + def _make_sliced_model(self, model, stop_layer): layers = OrderedDict() for name, layer in model.named_children(): @@ -272,6 +327,12 @@ def test_googlenet_eval(self): @unittest.skipIf(not torch.cuda.is_available(), 'needs GPU') def test_fasterrcnn_switch_devices(self): + def checkOut(out): + self.assertEqual(len(out), 1) + self.assertTrue("boxes" in out[0]) + self.assertTrue("scores" in out[0]) + self.assertTrue("labels" in out[0]) + model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False) model.cuda() model.eval() @@ -280,17 +341,20 @@ def test_fasterrcnn_switch_devices(self): model_input = [x] out = model(model_input) self.assertIs(model_input[0], x) - self.assertEqual(len(out), 1) - self.assertTrue("boxes" in out[0]) - self.assertTrue("scores" in out[0]) - self.assertTrue("labels" in out[0]) + + checkOut(out) + + with torch.cuda.amp.autocast(): + out = model(model_input) + + checkOut(out) + # now switch to cpu and make sure it works model.cpu() x = x.cpu() out_cpu = model([x]) - self.assertTrue("boxes" in out_cpu[0]) - self.assertTrue("scores" in out_cpu[0]) - self.assertTrue("labels" in out_cpu[0]) + + checkOut(out_cpu) def test_generalizedrcnn_transform_repr(self): @@ -312,34 +376,40 @@ def test_generalizedrcnn_transform_repr(self): self.assertEqual(t.__repr__(), expected_string) +_devs = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] + + for model_name in get_available_classification_models(): - # for-loop bodies don't define scopes, so we have to save the variables - # we want to close over in some way - def do_test(self, model_name=model_name): - input_shape = (1, 3, 224, 224) - if model_name in ['inception_v3']: - input_shape = (1, 3, 299, 299) - self._test_classification_model(model_name, input_shape) + for dev in _devs: + # for-loop bodies don't define scopes, so we have to save the variables + # we want to close over in some way + def do_test(self, model_name=model_name, dev=dev): + input_shape = (1, 3, 224, 224) + if model_name in ['inception_v3']: + input_shape = (1, 3, 299, 299) + self._test_classification_model(model_name, input_shape, dev) - setattr(ModelTester, "test_" + model_name, do_test) + setattr(ModelTester, "test_" + model_name + "_" + dev, do_test) for model_name in get_available_segmentation_models(): - # for-loop bodies don't define scopes, so we have to save the variables - # we want to close over in some way - def do_test(self, 
model_name=model_name):
-        self._test_segmentation_model(model_name)
+    for dev in _devs:
+        # for-loop bodies don't define scopes, so we have to save the variables
+        # we want to close over in some way
+        def do_test(self, model_name=model_name, dev=dev):
+            self._test_segmentation_model(model_name, dev)
 
-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
 
 
 for model_name in get_available_detection_models():
-    # for-loop bodies don't define scopes, so we have to save the variables
-    # we want to close over in some way
-    def do_test(self, model_name=model_name):
-        self._test_detection_model(model_name)
+    for dev in _devs:
+        # for-loop bodies don't define scopes, so we have to save the variables
+        # we want to close over in some way
+        def do_test(self, model_name=model_name, dev=dev):
+            self._test_detection_model(model_name, dev)
 
-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
 
     def do_validation_test(self, model_name=model_name):
         self._test_detection_model_validation(model_name)
@@ -348,11 +418,11 @@ def do_validation_test(self, model_name=model_name):
 
 
 for model_name in get_available_video_models():
+    for dev in _devs:
+        def do_test(self, model_name=model_name, dev=dev):
+            self._test_video_model(model_name, dev)
 
-    def do_test(self, model_name=model_name):
-        self._test_video_model(model_name)
-
-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
 
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_ops.py b/test/test_ops.py
index 2e3107f8d7e..564d5d54559 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -52,25 +52,30 @@ def _test_backward(self, device, contiguous):
 
 
 class RoIOpTester(OpTester):
-    def _test_forward(self, device, contiguous):
+    def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None):
+        x_dtype = self.dtype if x_dtype is None else x_dtype
+        rois_dtype = self.dtype if rois_dtype is None else rois_dtype
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
         n_channels = 2 * (pool_size ** 2)
-        x = torch.rand(2, n_channels, 10, 10, dtype=self.dtype, device=device)
+        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
         if not contiguous:
             x = x.permute(0, 1, 3, 2)
         rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (xyxy)
                              [0, 0, 5, 4, 9],
                              [0, 5, 5, 9, 9],
                              [1, 0, 0, 9, 9]],
-                            dtype=self.dtype, device=device)
+                            dtype=rois_dtype, device=device)
 
         pool_h, pool_w = pool_size, pool_size
         y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
+        # the following should be true whether we're running an autocast test or not.
+        self.assertTrue(y.dtype == x.dtype)
         gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1,
                                 sampling_ratio=-1, device=device, dtype=self.dtype)
 
-        self.assertTrue(torch.allclose(gt_y, y))
+        tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
+        self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol))
 
     def _test_backward(self, device, contiguous):
         pool_size = 2
@@ -290,6 +295,13 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
     def _test_boxes_shape(self):
         self._helper_boxes_shape(ops.roi_align)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
+    def test_roi_align_autocast(self):
+        for x_dtype in (torch.float, torch.half):
+            for rois_dtype in (torch.float, torch.half):
+                with torch.cuda.amp.autocast():
+                    self._test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
+
 
 class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
diff --git a/torchvision/csrc/ROIAlign.h b/torchvision/csrc/ROIAlign.h
index 7a856f34d63..a7cbe954a4d 100644
--- a/torchvision/csrc/ROIAlign.h
+++ b/torchvision/csrc/ROIAlign.h
@@ -3,6 +3,7 @@
 #include "cpu/vision_cpu.h"
 
 #ifdef WITH_CUDA
+#include "autocast.h"
 #include "cuda/vision_cuda.h"
 #endif
 #ifdef WITH_HIP
@@ -11,6 +12,7 @@
 
 // TODO: put this stuff in torchvision namespace
 
+// roi_align dispatch nexus
 at::Tensor roi_align(
     const at::Tensor& input, // Input feature map.
     const at::Tensor& rois, // List of ROIs to pool over.
@@ -35,6 +37,28 @@ at::Tensor roi_align(
       aligned);
 }
 
+#ifdef WITH_CUDA
+at::Tensor ROIAlign_autocast(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t sampling_ratio,
+    const bool aligned) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  return roi_align(
+             autocast::_cast(at::kFloat, input),
+             autocast::_cast(at::kFloat, rois),
+             spatial_scale,
+             pooled_height,
+             pooled_width,
+             sampling_ratio,
+             aligned)
+      .to(input.scalar_type());
+}
+#endif
+
 at::Tensor _roi_align_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
diff --git a/torchvision/csrc/autocast.h b/torchvision/csrc/autocast.h
new file mode 100644
index 00000000000..93c079fb1c5
--- /dev/null
+++ b/torchvision/csrc/autocast.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#ifdef WITH_CUDA
+namespace autocast {
+
+inline bool is_eligible(const at::Tensor& arg) {
+  return (
+      arg.is_cuda() && arg.is_floating_point() &&
+      (arg.scalar_type() != at::kDouble));
+}
+
+// Overload to catch Tensor args
+inline at::Tensor _cast(at::ScalarType to_type, const at::Tensor& arg) {
+  if (is_eligible(arg) && (arg.scalar_type() != to_type)) {
+    return arg.to(to_type);
+  } else {
+    return arg;
+  }
+}
+
+// Template to catch non-Tensor args
+template <class T>
+inline T _cast(at::ScalarType to_type, T arg) {
+  return arg;
+}
+
+} // namespace autocast
+#endif
diff --git a/torchvision/csrc/cpu/nms_cpu.cpp b/torchvision/csrc/cpu/nms_cpu.cpp
index 14c3b8b4f16..753a9c9e362 100644
--- a/torchvision/csrc/cpu/nms_cpu.cpp
+++ b/torchvision/csrc/cpu/nms_cpu.cpp
@@ -4,7 +4,7 @@ template <typename scalar_t>
 at::Tensor nms_cpu_kernel(
     const at::Tensor& dets,
     const at::Tensor& scores,
-    const float iou_threshold) {
+    const double iou_threshold) {
   AT_ASSERTM(!dets.is_cuda(), "dets must be a CPU tensor");
   AT_ASSERTM(!scores.is_cuda(), "scores must be a CPU tensor");
   AT_ASSERTM(
@@ -72,7 +72,26 @@ 
at::Tensor nms_cpu_kernel(
 at::Tensor nms_cpu(
     const at::Tensor& dets,
     const at::Tensor& scores,
-    const float iou_threshold) {
+    const double iou_threshold) {
+  TORCH_CHECK(
+      dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
+  TORCH_CHECK(
+      dets.size(1) == 4,
+      "boxes should have 4 elements in dimension 1, got ",
+      dets.size(1));
+  TORCH_CHECK(
+      scores.dim() == 1,
+      "scores should be a 1d tensor, got ",
+      scores.dim(),
+      "D");
+  TORCH_CHECK(
+      dets.size(0) == scores.size(0),
+      "boxes and scores should have same number of elements in ",
+      "dimension 0, got ",
+      dets.size(0),
+      " and ",
+      scores.size(0));
+
   auto result = at::empty({0}, dets.options());
 
   AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h
index 64aa1ae2119..6b68b356225 100644
--- a/torchvision/csrc/cpu/vision_cpu.h
+++ b/torchvision/csrc/cpu/vision_cpu.h
@@ -85,7 +85,7 @@ at::Tensor PSROIAlign_backward_cpu(
 at::Tensor nms_cpu(
     const at::Tensor& dets,
     const at::Tensor& scores,
-    const float iou_threshold);
+    const double iou_threshold);
 
 at::Tensor DeformConv2d_forward_cpu(
     const at::Tensor& input,
diff --git a/torchvision/csrc/cuda/nms_cuda.cu b/torchvision/csrc/cuda/nms_cuda.cu
index 2c519c4499d..f9c39541174 100644
--- a/torchvision/csrc/cuda/nms_cuda.cu
+++ b/torchvision/csrc/cuda/nms_cuda.cu
@@ -1,6 +1,11 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+
+#if defined(WITH_CUDA)
 #include <c10/cuda/CUDAGuard.h>
+#elif defined(WITH_HIP)
+#include <c10/hip/HIPGuard.h>
+#endif
 
 #include "cuda_helpers.h"
 
@@ -70,10 +75,40 @@ __global__ void nms_kernel(
 
 at::Tensor nms_cuda(const at::Tensor& dets,
     const at::Tensor& scores,
-    float iou_threshold) {
+    const double iou_threshold) {
   AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
   AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");
+
+  TORCH_CHECK(
+      dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
+  TORCH_CHECK(
+      dets.size(1) == 4,
+      "boxes should have 4 elements in dimension 1, got ",
+      dets.size(1));
+  TORCH_CHECK(
+      scores.dim() == 1,
+      "scores should be a 1d tensor, got ",
+      scores.dim(),
+      "D");
+  TORCH_CHECK(
+      dets.size(0) == scores.size(0),
+      "boxes and scores should have same number of elements in ",
+      "dimension 0, got ",
+      dets.size(0),
+      " and ",
+      scores.size(0));
+
+#if defined(WITH_CUDA)
   at::cuda::CUDAGuard device_guard(dets.device());
+#elif defined(WITH_HIP)
+  at::cuda::HIPGuard device_guard(dets.device());
+#else
+  AT_ERROR("Not compiled with GPU support");
+#endif
+
+  if (dets.numel() == 0) {
+    return at::empty({0}, dets.options().dtype(at::kLong));
+  }
 
   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
   auto dets_sorted = dets.index_select(0, order_t).contiguous();
diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h
index 834ba51a4cf..ef53d0c08b4 100644
--- a/torchvision/csrc/cuda/vision_cuda.h
+++ b/torchvision/csrc/cuda/vision_cuda.h
@@ -1,9 +1,4 @@
 #pragma once
-#if defined(WITH_CUDA)
-#include <c10/cuda/CUDAGuard.h>
-#elif defined(WITH_HIP)
-#include <c10/hip/HIPGuard.h>
-#endif
 #include <torch/extension.h>
 
 at::Tensor ROIAlign_forward_cuda(
@@ -90,7 +85,7 @@ at::Tensor PSROIAlign_backward_cuda(
 at::Tensor nms_cuda(
     const at::Tensor& dets,
     const at::Tensor& scores,
-    const float iou_threshold);
+    const double iou_threshold);
 
 at::Tensor DeformConv2d_forward_cuda(
     const at::Tensor& input,
diff --git a/torchvision/csrc/nms.h b/torchvision/csrc/nms.h
index 3c2faba8353..6bbd3e0bc65 100644
--- a/torchvision/csrc/nms.h
+++ b/torchvision/csrc/nms.h
@@ -2,52 +2,33 @@
 #include "cpu/vision_cpu.h"
 
 #ifdef WITH_CUDA
+#include "autocast.h"
 #include "cuda/vision_cuda.h"
 #endif
 #ifdef WITH_HIP
 #include "hip/vision_cuda.h"
 #endif
 
+// nms dispatch nexus
 at::Tensor nms(
     const at::Tensor& dets,
     const at::Tensor& scores,
     const double iou_threshold) {
-  TORCH_CHECK(
-      dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
-  TORCH_CHECK(
-      dets.size(1) == 4,
-      "boxes should have 4 elements in dimension 1, got ",
-      dets.size(1));
-  TORCH_CHECK(
-      scores.dim() == 1,
-      "scores should be a 1d tensor, got ",
-      scores.dim(),
-      "D");
-  TORCH_CHECK(
-      dets.size(0) == scores.size(0),
-      "boxes and scores should have same number of elements in ",
-      "dimension 0, got ",
-      dets.size(0),
-      " and ",
-      scores.size(0));
-  if (dets.is_cuda()) {
-#if defined(WITH_CUDA)
-    if (dets.numel() == 0) {
-      at::cuda::CUDAGuard device_guard(dets.device());
-      return at::empty({0}, dets.options().dtype(at::kLong));
-    }
-    return nms_cuda(dets, scores, iou_threshold);
-#elif defined(WITH_HIP)
-    if (dets.numel() == 0) {
-      at::cuda::HIPGuard device_guard(dets.device());
-      return at::empty({0}, dets.options().dtype(at::kLong));
-    }
-    return nms_cuda(dets, scores, iou_threshold);
-#else
-    AT_ERROR("Not compiled with GPU support");
-#endif
-  }
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::nms", "")
+                       .typed<decltype(nms)>();
+  return op.call(dets, scores, iou_threshold);
+}
 
-  at::Tensor result = nms_cpu(dets, scores, iou_threshold);
-  return result;
+#ifdef WITH_CUDA
+at::Tensor nms_autocast(
+    const at::Tensor& dets,
+    const at::Tensor& scores,
+    const double iou_threshold) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  return nms(
+      autocast::_cast(at::kFloat, dets),
+      autocast::_cast(at::kFloat, scores),
+      iou_threshold);
 }
+#endif
diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp
index 7f56bdb51a1..aa2ec26bfef 100644
--- a/torchvision/csrc/vision.cpp
+++ b/torchvision/csrc/vision.cpp
@@ -43,7 +43,7 @@ int64_t _cuda_version() {
 }
 
 TORCH_LIBRARY(torchvision, m) {
-  m.def("nms", &nms);
+  m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
   m.def(
       "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
   m.def(
@@ -59,6 +59,7 @@ TORCH_LIBRARY(torchvision, m) {
 TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
   m.impl("roi_align", ROIAlign_forward_cpu);
   m.impl("_roi_align_backward", ROIAlign_backward_cpu);
+  m.impl("nms", nms_cpu);
 }
 
 // TODO: Place this in a hypothetical separate torchvision_cuda library
@@ -66,6 +67,15 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
 TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
   m.impl("roi_align", ROIAlign_forward_cuda);
   m.impl("_roi_align_backward", ROIAlign_backward_cuda);
+  m.impl("nms", nms_cuda);
+}
+#endif
+
+// Autocast only needs to wrap forward pass ops.
+#if defined(WITH_CUDA)
+TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
+  m.impl("roi_align", ROIAlign_autocast);
+  m.impl("nms", nms_autocast);
 }
 #endif
 
diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py
index 06bbc86a93c..714acfa1def 100644
--- a/torchvision/ops/poolers.py
+++ b/torchvision/ops/poolers.py
@@ -224,7 +224,14 @@ def forward(self, x, boxes, image_shapes):
             if torchvision._is_tracing():
                 tracing_results.append(result_idx_in_level.to(dtype))
             else:
-                result[idx_in_level] = result_idx_in_level
+                # result and result_idx_in_level's dtypes are based on dtypes of different
+                # elements in x_filtered. 
x_filtered contains tensors output by different + # layers. When autocast is active, it may choose different dtypes for + # different layers' outputs. Therefore, we defensively match result's dtype + # before copying elements from result_idx_in_level in the following op. + # We need to cast manually (can't rely on autocast to cast for us) because + # the op acts on result in-place, and autocast only affects out-of-place ops. + result[idx_in_level] = result_idx_in_level.to(result.dtype) if torchvision._is_tracing(): result = _onnx_merge_levels(levels, tracing_results)
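
A note on the strip_suffix change in test/common_utils.py: the sketch below mirrors the remove_prefix_suffix logic so the naming scheme is easy to verify. The helper name expect_stem is illustrative only, not part of the patch.

def expect_stem(test_id, module_prefix, strip_suffix=None):
    # Mirrors remove_prefix_suffix in test/common_utils.py: drop the module
    # prefix, then drop the device suffix if one was requested.
    if test_id.startswith(module_prefix):
        test_id = test_id[len(module_prefix):]
    if strip_suffix is not None and test_id.endswith(strip_suffix):
        test_id = test_id[:len(test_id) - len(strip_suffix)]
    return test_id

# Both device variants of a test resolve to the same expect-file stem:
assert expect_stem("test_models.ModelTester.test_resnet18_cpu", "test_models.", "_cpu") \
    == expect_stem("test_models.ModelTester.test_resnet18_cuda", "test_models.", "_cuda") \
    == "ModelTester.test_resnet18"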
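
For reviewers trying the branch: a minimal sketch of what the new Autocast registrations buy at the Python level, assuming a CUDA build with this patch applied. The tensor values are arbitrary.

import torch
import torchvision.ops as ops

boxes = torch.tensor([[0., 0., 9., 9.], [0., 5., 9., 9.]], device="cuda", dtype=torch.half)
scores = torch.tensor([0.9, 0.8], device="cuda", dtype=torch.half)
x = torch.rand(1, 8, 16, 16, device="cuda", dtype=torch.half)
rois = torch.tensor([[0., 0., 0., 9., 9.]], device="cuda")  # float32 on purpose, mixed with half x

with torch.cuda.amp.autocast():
    # nms_autocast casts dets and scores to float32 before redispatching, so
    # half or mixed-precision inputs no longer trip dtype errors in the kernels.
    keep = ops.nms(boxes, scores, iou_threshold=0.5)
    # ROIAlign_autocast runs the kernel in float32, then casts the result back
    # to input.scalar_type(); test_roi_align_autocast checks exactly this.
    y = ops.roi_align(x, rois, output_size=(5, 5), spatial_scale=1.0, sampling_ratio=-1)

assert keep.dtype == torch.long
assert y.dtype == x.dtype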
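
Since nms is now registered through the full dispatcher (m.def with a schema string plus per-backend m.impl entries), the op is also reachable through torch.ops, which is what torchvision.ops.nms calls under the hood. A quick CPU smoke test:

import torch
import torchvision  # noqa: F401  (loads the compiled torchvision op library)

boxes = torch.tensor([[0., 0., 1., 1.]])
scores = torch.tensor([1.0])
keep = torch.ops.torchvision.nms(boxes, scores, 0.5)  # routed to nms_cpu via m.impl("nms", nms_cpu)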
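
Finally, a condensed illustration of the poolers.py failure mode the last hunk defends against. The dtypes here are hypothetical: under autocast, different FPN levels can come back in different dtypes, while the accumulator keeps whichever dtype it was created with.

import torch

result = torch.zeros(4, 3, device="cuda")  # float32 accumulator
result_idx_in_level = torch.ones(2, 3, device="cuda", dtype=torch.half)  # one level's autocast output
idx_in_level = torch.tensor([0, 2], device="cuda")

# Autocast only rewrites out-of-place ops; this in-place index copy is on its
# own, so match result's dtype manually before copying, as the patch does.
result[idx_in_level] = result_idx_in_level.to(result.dtype)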