CUDA BFloat16 Dropout (pytorch#45005)
Summary: Pull Request resolved: pytorch#45005

Reviewed By: mruberry

Differential Revision: D24934761

Pulled By: ngimel

fbshipit-source-id: 8f615b97fb93dcd04a46e1d8eeb817ade5082990
zasdfgbnm authored and tugsbayasgalan committed Nov 16, 2020
1 parent dfdffb0 commit 204675a
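
For context, dropping the AT_SKIP_BFLOAT16_IF_NOT_ROCM guards below lets the fused dropout path dispatch for torch.bfloat16 on CUDA builds, not just ROCm. A minimal forward-pass sketch of what this enables (assumes a CUDA device is available; the size 1000 simply mirrors the updated test):

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
drop.train()  # dropout only masks in training mode

# bfloat16 input on a CUDA device is now handled by the fused dropout kernel
x = torch.randn(1000, device="cuda").bfloat16()
y = drop(x)

assert y.dtype == torch.bfloat16
assert y.device.type == "cuda"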
Showing 2 changed files with 84 additions and 88 deletions.
164 changes: 80 additions & 84 deletions aten/src/ATen/native/cuda/Dropout.cu
@@ -210,102 +210,100 @@ inline void launcher(
       self.scalar_type(),
       "fused_dropout",
       [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "fused_dropout", [&] {
         using accscalar_t = acc_type<scalar_t, true>;
         accscalar_t pa = (accscalar_t)(p);
         auto self_info =
             cuda::detail::getTensorInfo<scalar_t, index_type>(self);
         auto ret_info =
             cuda::detail::getTensorInfo<scalar_t, index_type>(ret);
         auto mask_info =
             cuda::detail::getTensorInfo<uint8_t, index_type>(mask);
         self_info.collapseDims();
         ret_info.collapseDims();
         mask_info.collapseDims(); // ret and mask are collapsed to 1d
                                   // contiguous tensor

         int vec_size = get_vector_size<scalar_t>(self, ret, mask);

         if (vec_size > 1) {
           switch (vec_size) {
             case 4:
               fused_dropout_kernel_vec<
                   scalar_t,
                   accscalar_t,
                   index_type,
                   1,
                   4>
                   <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                       self_info,
                       ret_info,
                       mask_info,
                       nelem,
                       pa,
                       rng_engine_inputs);
               TORCH_CUDA_KERNEL_LAUNCH_CHECK();
               break;
             case 2:
               fused_dropout_kernel_vec<
                   scalar_t,
                   accscalar_t,
                   index_type,
                   1,
                   2>
                   <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                       self_info,
                       ret_info,
                       mask_info,
                       nelem,
                       pa,
                       rng_engine_inputs);
               TORCH_CUDA_KERNEL_LAUNCH_CHECK();
               break;
           }
         } else {
           switch (self_info.dims) {
             case 1:
               fused_dropout_kernel<scalar_t, accscalar_t, index_type, 1>
                   <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                       self_info,
                       ret_info,
                       mask_info,
                       nelem,
                       pa,
                       rng_engine_inputs);
               TORCH_CUDA_KERNEL_LAUNCH_CHECK();
               break;
             default:
               if (!self.is_contiguous() && ret.is_contiguous() &&
                   mask.is_contiguous()) {
                 fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1, 1>
                     <<<grid,
                        dim_block,
                        0,
                        at::cuda::getCurrentCUDAStream()>>>(
                         self_info,
                         ret_info,
                         mask_info,
                         nelem,
                         pa,
                         rng_engine_inputs);
                 TORCH_CUDA_KERNEL_LAUNCH_CHECK();
               } else {
                 fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1>
                     <<<grid,
                        dim_block,
                        0,
                        at::cuda::getCurrentCUDAStream()>>>(
                         self_info,
                         ret_info,
                         mask_info,
                         nelem,
                         pa,
                         rng_engine_inputs);
                 TORCH_CUDA_KERNEL_LAUNCH_CHECK();
               }
           }
         }
-        });
       });
 }
@@ -346,11 +344,9 @@ Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
   Tensor ret = at::empty_like(self, self.suggest_memory_format());
   TORCH_CHECK(mask.scalar_type() == at::ScalarType::Byte, "mask should be torch.uint8 dtype");
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "masked_scale", [&] {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "masked_scale", [&] {
-      using accscalar_t = acc_type<scalar_t, true>;
-      accscalar_t pa = (accscalar_t)(scale);
-      masked_scale_kernel<scalar_t>(ret, self, mask, pa);
-    });
+    using accscalar_t = acc_type<scalar_t, true>;
+    accscalar_t pa = (accscalar_t)(scale);
+    masked_scale_kernel<scalar_t>(ret, self, mask, pa);
   });
   return ret;
 }
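
masked_scale_cuda above is used on the backward path of fused dropout (grad * mask * 1/(1-p)), so the same dtype coverage carries over to gradients. A rough sketch exercising that path in bfloat16, under the same CUDA-device assumption:

import torch
import torch.nn.functional as F

x = torch.randn(1000, device="cuda").bfloat16().requires_grad_()
y = F.dropout(x, p=0.5, training=True)
y.sum().backward()  # gradient is mask * 1/(1-p), applied by the masked-scale kernel

assert x.grad is not None and x.grad.dtype == torch.bfloat16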
8 changes: 4 additions & 4 deletions test/test_nn.py
@@ -10196,15 +10196,15 @@ def test_affine_3d_rotateRandom(self, device):
         self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
 
     def test_Dropout(self, device):
-        input = torch.Tensor(1000)
+        input = torch.empty(1000)
         self._test_dropout(nn.Dropout, device, input)
 
         self._test_dropout_discontiguous(nn.Dropout, device)
         self._test_dropout_discontiguous(nn.Dropout, device, memory_format=torch.channels_last)
 
         self._test_dropout_stride_mean_preserve(nn.Dropout, device)
 
-        if self.device_type == 'cuda' and TEST_WITH_ROCM:
+        if self.device_type == 'cuda':
             input = input.bfloat16()
             self._test_dropout(nn.Dropout, device, input)
 
@@ -10213,7 +10213,7 @@ def test_Dropout2d(self, device):
         w = random.randint(1, 5)
         h = random.randint(1, 5)
         num_features = 1000
-        input = torch.Tensor(num_features, b, w, h)
+        input = torch.empty(num_features, b, w, h)
         self._test_dropout(nn.Dropout2d, device, input)
         self._test_dropout(nn.Dropout2d, device, input, memory_format=torch.channels_last)
 
@@ -10226,7 +10226,7 @@ def test_Dropout3d(self, device):
         h = random.randint(1, 5)
         d = random.randint(1, 2)
         num_features = 1000
-        input = torch.Tensor(num_features, b, d, w, h)
+        input = torch.empty(num_features, b, d, w, h)
         self._test_dropout(nn.Dropout3d, device, input)
 
         self._test_dropout_discontiguous(nn.Dropout3d, device)
