From 204675ab97f9166e3746a50027f5fe1d1ff9d1d5 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Thu, 12 Nov 2020 22:18:52 -0800
Subject: [PATCH] CUDA BFloat16 Dropout (#45005)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45005

Reviewed By: mruberry

Differential Revision: D24934761

Pulled By: ngimel

fbshipit-source-id: 8f615b97fb93dcd04a46e1d8eeb817ade5082990
---
 aten/src/ATen/native/cuda/Dropout.cu | 164 +++++++++++++--------------
 test/test_nn.py                      |   8 +-
 2 files changed, 84 insertions(+), 88 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu
index 79736677debc..23f8834e73de 100644
--- a/aten/src/ATen/native/cuda/Dropout.cu
+++ b/aten/src/ATen/native/cuda/Dropout.cu
@@ -210,48 +210,77 @@ inline void launcher(
       self.scalar_type(),
       "fused_dropout",
       [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "fused_dropout", [&] {
-          using accscalar_t = acc_type<scalar_t, true>;
-          accscalar_t pa = (accscalar_t)(p);
-          auto self_info =
-              cuda::detail::getTensorInfo<scalar_t, index_type>(self);
-          auto ret_info =
-              cuda::detail::getTensorInfo<scalar_t, index_type>(ret);
-          auto mask_info =
-              cuda::detail::getTensorInfo<uint8_t, index_type>(mask);
-          self_info.collapseDims();
-          ret_info.collapseDims();
-          mask_info.collapseDims(); // ret and mask are collapsed to 1d
-                                    // contiguous tensor
+        using accscalar_t = acc_type<scalar_t, true>;
+        accscalar_t pa = (accscalar_t)(p);
+        auto self_info =
+            cuda::detail::getTensorInfo<scalar_t, index_type>(self);
+        auto ret_info =
+            cuda::detail::getTensorInfo<scalar_t, index_type>(ret);
+        auto mask_info =
+            cuda::detail::getTensorInfo<uint8_t, index_type>(mask);
+        self_info.collapseDims();
+        ret_info.collapseDims();
+        mask_info.collapseDims(); // ret and mask are collapsed to 1d
+                                  // contiguous tensor
 
-          int vec_size = get_vector_size<scalar_t>(self, ret, mask);
+        int vec_size = get_vector_size<scalar_t>(self, ret, mask);
 
-          if (vec_size > 1) {
-            switch (vec_size) {
-              case 4:
-                fused_dropout_kernel_vec<
-                    scalar_t,
-                    accscalar_t,
-                    index_type,
-                    1,
-                    4>
-                    <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
-                        self_info,
-                        ret_info,
-                        mask_info,
-                        nelem,
-                        pa,
-                        rng_engine_inputs);
-                TORCH_CUDA_KERNEL_LAUNCH_CHECK();
-                break;
-              case 2:
-                fused_dropout_kernel_vec<
-                    scalar_t,
-                    accscalar_t,
-                    index_type,
-                    1,
-                    2>
-                    <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+        if (vec_size > 1) {
+          switch (vec_size) {
+            case 4:
+              fused_dropout_kernel_vec<
+                  scalar_t,
+                  accscalar_t,
+                  index_type,
+                  1,
+                  4>
+                  <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                      self_info,
+                      ret_info,
+                      mask_info,
+                      nelem,
+                      pa,
+                      rng_engine_inputs);
+              TORCH_CUDA_KERNEL_LAUNCH_CHECK();
+              break;
+            case 2:
+              fused_dropout_kernel_vec<
+                  scalar_t,
+                  accscalar_t,
+                  index_type,
+                  1,
+                  2>
+                  <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                      self_info,
+                      ret_info,
+                      mask_info,
+                      nelem,
+                      pa,
+                      rng_engine_inputs);
+              TORCH_CUDA_KERNEL_LAUNCH_CHECK();
+              break;
+          }
+        } else {
+          switch (self_info.dims) {
+            case 1:
+              fused_dropout_kernel<scalar_t, accscalar_t, index_type, 1>
+                  <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                      self_info,
+                      ret_info,
+                      mask_info,
+                      nelem,
+                      pa,
+                      rng_engine_inputs);
+              TORCH_CUDA_KERNEL_LAUNCH_CHECK();
+              break;
+            default:
+              if (!self.is_contiguous() && ret.is_contiguous() &&
+                  mask.is_contiguous()) {
+                fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1, 1>
+                    <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                         self_info,
                         ret_info,
                         mask_info,
@@ -259,13 +288,12 @@ inline void launcher(
                         pa,
                         rng_engine_inputs);
                 TORCH_CUDA_KERNEL_LAUNCH_CHECK();
-                break;
-            }
-          } else {
-            switch (self_info.dims) {
-              case 1:
-                fused_dropout_kernel<scalar_t, accscalar_t, index_type, 1>
-                    <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+              } else {
+                fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1>
+                    <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                         self_info,
                         ret_info,
                         mask_info,
@@ -273,39 +301,9 @@ inline void launcher(
                         pa,
                         rng_engine_inputs);
                 TORCH_CUDA_KERNEL_LAUNCH_CHECK();
-                break;
-              default:
-                if (!self.is_contiguous() && ret.is_contiguous() &&
-                    mask.is_contiguous()) {
-                  fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1, 1>
-                      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
-                          self_info,
-                          ret_info,
-                          mask_info,
-                          nelem,
-                          pa,
-                          rng_engine_inputs);
-                  TORCH_CUDA_KERNEL_LAUNCH_CHECK();
-                } else {
-                  fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1>
-                      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
-                          self_info,
-                          ret_info,
-                          mask_info,
-                          nelem,
-                          pa,
-                          rng_engine_inputs);
-                  TORCH_CUDA_KERNEL_LAUNCH_CHECK();
-                }
-            }
+              }
           }
-        });
+        }
       });
 }
 
@@ -346,11 +344,9 @@ Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
   Tensor ret = at::empty_like(self, self.suggest_memory_format());
   TORCH_CHECK(mask.scalar_type() == at::ScalarType::Byte, "mask should be torch.uint8 dtype");
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "masked_scale", [&] {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "masked_scale", [&] {
-      using accscalar_t = acc_type<scalar_t, true>;
-      accscalar_t pa = (accscalar_t)(scale);
-      masked_scale_kernel<scalar_t>(ret, self, mask, pa);
-    });
+    using accscalar_t = acc_type<scalar_t, true>;
+    accscalar_t pa = (accscalar_t)(scale);
+    masked_scale_kernel<scalar_t>(ret, self, mask, pa);
   });
   return ret;
 }
diff --git a/test/test_nn.py b/test/test_nn.py
index 020b206905d9..b2c6fd575570 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -10196,7 +10196,7 @@ def test_affine_3d_rotateRandom(self, device):
         self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
 
     def test_Dropout(self, device):
-        input = torch.Tensor(1000)
+        input = torch.empty(1000)
         self._test_dropout(nn.Dropout, device, input)
 
         self._test_dropout_discontiguous(nn.Dropout, device)
@@ -10204,7 +10204,7 @@ def test_Dropout(self, device):
 
         self._test_dropout_stride_mean_preserve(nn.Dropout, device)
 
-        if self.device_type == 'cuda' and TEST_WITH_ROCM:
+        if self.device_type == 'cuda':
             input = input.bfloat16()
             self._test_dropout(nn.Dropout, device, input)
 
@@ -10213,7 +10213,7 @@ def test_Dropout2d(self, device):
         w = random.randint(1, 5)
         h = random.randint(1, 5)
         num_features = 1000
-        input = torch.Tensor(num_features, b, w, h)
+        input = torch.empty(num_features, b, w, h)
         self._test_dropout(nn.Dropout2d, device, input)
 
         self._test_dropout(nn.Dropout2d, device, input, memory_format=torch.channels_last)
@@ -10226,7 +10226,7 @@ def test_Dropout3d(self, device):
         h = random.randint(1, 5)
         d = random.randint(1, 2)
         num_features = 1000
-        input = torch.Tensor(num_features, b, d, w, h)
+        input = torch.empty(num_features, b, d, w, h)
         self._test_dropout(nn.Dropout3d, device, input)
 
         self._test_dropout_discontiguous(nn.Dropout3d, device)
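
Note: a quick smoke test for the path this patch enables, mirroring the test_nn.py change above. This is a minimal sketch, not part of the patch; it assumes a CUDA-enabled PyTorch build that includes this change and a visible GPU.

    import torch
    import torch.nn as nn

    # Fused dropout on a CUDA bfloat16 tensor; before this patch, non-ROCm
    # builds skipped this dtype via AT_SKIP_BFLOAT16_IF_NOT_ROCM.
    x = torch.randn(1000, device="cuda").bfloat16()
    out = nn.Dropout(p=0.5)(x)  # module defaults to training mode

    # Dropout zeroes each element with probability p and scales survivors
    # by 1 / (1 - p); with p = 0.5 the survivors are exactly doubled, and
    # doubling is exact in bfloat16 (an exponent increment), so the
    # comparison can be exact rather than approximate.
    kept = out != 0
    assert torch.equal(out[kept], x[kept] * 2)
    print(f"fraction kept: {kept.float().mean().item():.2f}")  # ~0.50

Choosing p = 0.5 makes the rescale factor a power of two, so surviving values can be compared with torch.equal instead of a tolerance; the fraction-kept print is the statistical property that _test_dropout in test_nn.py verifies more thoroughly.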