[pytorch] bfloat16 support in erfinv (#111257)
Summary:
Pull Request resolved: #111257

As title: add bfloat16 support to the erfinv CUDA kernel, and enable the bfloat16 dtype for erfinv in the CUDA OpInfo tests.

Test Plan: CI

Differential Revision: D50280766

fbshipit-source-id: 80889994d9fb503bab0417463cae82e8ba804705
jspark1105 authored and facebook-github-bot committed Oct 14, 2023
1 parent cff8bf4 commit a601fdb
Showing 2 changed files with 8 additions and 7 deletions.
13 changes: 7 additions & 6 deletions aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu
@@ -281,18 +281,19 @@ void erfc_kernel_cuda(TensorIteratorBase& iter) {
 CONSTEXPR_EXCEPT_WIN_CUDA char erfinv_name[] = "erfinv_kernel";
 void erfinv_kernel_cuda(TensorIteratorBase& iter) {
 #if AT_USE_JITERATOR()
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfinv_cuda", [&]() {
     jitted_gpu_kernel</*name=*/erfinv_name,
                       /*return_dtype=*/ scalar_t,
                       /*common_dtype=*/ scalar_t,
                       /*arity=*/ 1>(iter, erfinv_string);
   });
 #else
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
-    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-      return ::erfinv(a);
-    });
-  });
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
+      iter.common_dtype(), "erfinv_cuda", [&]() {
+    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return ::erfinv(a);
+    });
+  });
 #endif
 }

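The substantive change is the dispatch macro: AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the kernel body for double, float, and half, while AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, ...) additionally instantiates it for bfloat16, on both the jiterator and plain gpu_kernel paths. A minimal sketch of what this enables from Python, assuming a CUDA device and a PyTorch build that includes this commit (tolerances are illustrative, not from the commit):

# A sketch, not part of the commit: exercise the new bfloat16 CUDA path.
import torch

x = torch.linspace(-0.9, 0.9, 8, device="cuda", dtype=torch.bfloat16)

# Before this commit, the CUDA dispatch did not cover bfloat16, so this
# call raised a dtype-not-implemented RuntimeError; now it runs on the GPU.
y = torch.erfinv(x)

# Compare against a float32 reference; bfloat16 carries only ~3 decimal
# digits of precision, so a loose tolerance is appropriate.
ref = torch.erfinv(x.float())
torch.testing.assert_close(y.float(), ref, atol=1e-2, rtol=1e-2)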
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -17259,7 +17259,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                     torch.bfloat16: 1e-2,
                     torch.float32: 1e-4}),),
             dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
-            dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+            dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
             supports_sparse_csr=True,
             supports_sparse_csc=True,
             supports_sparse_bsr=True,
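The test-side edit adds torch.bfloat16 to the erfinv OpInfo's dtypesIfCUDA list, so the generic OpInfo test suite now exercises the new kernel on CUDA; note the bfloat16 reference tolerance (1e-2) was already present in the entry. A hedged sketch, not from the test suite, of the dtype coverage this line declares, under the same build assumption:

# Probe which low-precision float dtypes erfinv accepts on CUDA,
# mirroring what the updated dtypesIfCUDA entry declares.
import torch

supported = []
for dtype in (torch.float16, torch.bfloat16):
    try:
        torch.erfinv(torch.zeros(8, device="cuda", dtype=dtype))
        supported.append(dtype)
    except RuntimeError:
        pass

# After this commit, both float16 and bfloat16 should appear.
print(supported)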
