diff --git a/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu b/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu
index cd62641a80d72..f259776c2fc3b 100644
--- a/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu
@@ -281,18 +281,19 @@ void erfc_kernel_cuda(TensorIteratorBase& iter) {
 CONSTEXPR_EXCEPT_WIN_CUDA char erfinv_name[] = "erfinv_kernel";
 void erfinv_kernel_cuda(TensorIteratorBase& iter) {
 #if AT_USE_JITERATOR()
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfinv_cuda", [&]() {
     jitted_gpu_kernel(iter, erfinv_string);
   });
 #else
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
-    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-      return ::erfinv(a);
-    });
-  });
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
+      iter.common_dtype(), "erfinv_cuda", [&]() {
+    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return ::erfinv(a);
+    });
+  });
 #endif
 }
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index c4ec0d454f2c0..ab722302ee255 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -17259,7 +17259,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                                   torch.bfloat16: 1e-2,
                                   torch.float32: 1e-4}),),
            dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
            supports_sparse_csr=True,
            supports_sparse_csc=True,
            supports_sparse_bsr=True,
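With this patch applied, torch.erfinv should dispatch for bfloat16 inputs on CUDA instead of raising a dtype-dispatch RuntimeError. A minimal smoke-test sketch, not part of the patch itself; the tensor size, input range, and tolerances below are assumptions, with the tolerance chosen loosely to absorb bfloat16 rounding error (bfloat16 carries roughly 3 decimal digits of precision):

    import torch

    # erfinv is defined on (-1, 1); sample strictly inside the open interval.
    x = torch.empty(1024, device="cuda", dtype=torch.bfloat16).uniform_(-0.9, 0.9)

    out = torch.erfinv(x)
    assert out.dtype == torch.bfloat16

    # Compare against the pre-existing float32 path; loose tolerance for
    # bfloat16 rounding (assumed values, mirroring the 1e-2 precisionOverride
    # already used for bfloat16 in the OpInfo entry).
    ref = torch.erfinv(x.float())
    torch.testing.assert_close(out.float(), ref, atol=1e-2, rtol=1e-2)

The OpInfo change in common_methods_invocations.py should make the standard test suite cover this path automatically once dtypesIfCUDA includes torch.bfloat16.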