[numpy] torch.erfinv: promote integer inputs to float (#49155)
Summary:
Reference: #42515

Pull Request resolved: #49155

Reviewed By: ngimel

Differential Revision: D25664234

Pulled By: mruberry

fbshipit-source-id: 630fd1d334567d78c8130236a67dda0f5ec02560
kshitij12345 authored and facebook-github-bot committed Dec 23, 2020
1 parent 4d61109 commit 3f4b98d
Showing 6 changed files with 28 additions and 29 deletions.
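For orientation before the per-file diffs, a minimal sketch of the user-visible behavior this change targets: integer and bool inputs to torch.erfinv are promoted to a floating-point dtype rather than being unsupported. The tensor values and the assumption that the default dtype is float32 below are illustrative, not taken from the PR.

import torch

# Sketch of the promoted behavior; assumes the default dtype is float32.
x_int = torch.tensor([-1, 0, 1], dtype=torch.int64)
y = torch.erfinv(x_int)            # integer input is computed in floating point
print(y.dtype)                     # expected: torch.float32

# The result should match erfinv applied to the explicitly converted input.
assert torch.equal(y, torch.erfinv(x_int.to(torch.get_default_dtype())))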
aten/src/ATen/native/UnaryOps.cpp: 5 changes (4 additions, 1 deletion)
@@ -276,6 +276,10 @@ Tensor& erfc_out(Tensor& result, const Tensor& self) { return unary_op_impl_floa
Tensor erfc(const Tensor& self) { return unary_op_impl_float(self, erfc_stub); }
Tensor& erfc_(Tensor& self) { return unary_op_impl_(self, at::erfc_out); }

+Tensor& erfinv_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erfinv_stub); }
+Tensor erfinv(const Tensor& self) { return unary_op_impl_float(self, erfinv_stub); }
+Tensor& erfinv_(Tensor& self) { return unary_op_impl_(self, at::erfinv_out); }
+
Tensor& frac_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, frac_stub); }
Tensor frac(const Tensor& self) { return unary_op_impl(self, at::frac_out); }
Tensor& frac_(Tensor& self) { return unary_op_impl_(self, at::frac_out); }
@@ -683,7 +687,6 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU) \
IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cuda, CUDA)

-IMPLEMENT_UNARY_OP_VEC_CUDA(erfinv)
IMPLEMENT_UNARY_OP_VEC_CUDA(lgamma)

DEFINE_DISPATCH(abs_stub);
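The UnaryOps.cpp hunk above routes erfinv through the generic unary_op_impl_float helpers, so the functional, out= and in-place variants share one promotion path. A hedged Python-level sketch of the expected consequences follows; the error behavior for the in-place case is an assumption based on the usual promotion rules, not something spelled out in the PR.

import torch

x_int = torch.zeros(3, dtype=torch.int32)

# Functional variant: integer input, floating-point result.
y = torch.erfinv(x_int)
assert y.dtype.is_floating_point

# out= variant: a floating-point out tensor can receive the promoted result.
out = torch.empty(3)
torch.erfinv(x_int, out=out)

# In-place variant: expected to raise, since the floating-point result
# cannot be written back into an integer tensor.
try:
    x_int.erfinv_()
except RuntimeError as err:
    print("erfinv_ on an integer tensor raised:", err)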
aten/src/ATen/native/cuda/UnaryOpsKernel.cu: 2 changes (1 addition, 1 deletion)
@@ -160,7 +160,7 @@ void erfc_kernel_cuda(TensorIterator& iter) {
}

void erfinv_kernel_cuda(TensorIterator& iter) {
-AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "erfinv_cuda", [&]() {
+AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::erfinv(a);
});
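Dispatching the CUDA kernel on iter.common_dtype() instead of iter.dtype() instantiates the kernel for the promoted computation type when the input is integral or bool. A small sketch of what that looks like from Python, guarded on CUDA availability; the dtype coverage assumed here follows the OpInfo added later in this diff.

import torch

if torch.cuda.is_available():
    b = torch.tensor([False, True], device="cuda")   # bool input
    print(b.erfinv().dtype)                          # expected: a floating dtype

    # Half precision remains covered by the same dispatch macro.
    h = torch.zeros(4, dtype=torch.half, device="cuda")
    print(h.erfinv().dtype)                          # expected: torch.float16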
aten/src/ATen/native/native_functions.yaml: 6 changes (2 additions, 4 deletions)
@@ -6937,14 +6937,12 @@
use_c10_dispatcher: full
variants: method
dispatch:
-CPU: _erfinv__cpu
-CUDA: _erfinv__cuda
+CPU, CUDA: erfinv_

- func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
dispatch:
-CPU: _erfinv_out_cpu
-CUDA: _erfinv_out_cuda
+CPU, CUDA: erfinv_out

- func: i0(Tensor self) -> Tensor
use_c10_dispatcher: full
test/test_torch.py: 1 change (0 additions, 1 deletion)
@@ -6913,7 +6913,6 @@ def inner(self, device, dtype):
('atanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
('erfc', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
-('erfinv', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
('exp', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
('exp', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1),
lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
test/test_unary_ufuncs.py: 20 changes (0 additions, 20 deletions)
@@ -766,26 +766,6 @@ def test_ceil_out_mismatch(self, device):
b = torch.randn(1, device=device)
self.assertRaises(RuntimeError, lambda: torch.ceil(a, out=b))

-# TODO: review with erfinv opinfo
-@dtypesIfCUDA(torch.half, torch.float, torch.double)
-@dtypes(torch.float, torch.double)
-def test_erfinv(self, device, dtype):
-    # general testing. Narrow the range to avoid accuracy issues
-    input_values = torch.randn(4, 4, dtype=dtype, device=device).clamp(-0.3, 0.3)
-    self.assertEqual(input_values.erf().erfinv(), input_values)
-    # test inf
-    self.assertTrue(torch.equal(torch.tensor([-1, 1], dtype=dtype, device=device).erfinv(),
-                                torch.tensor([-inf, inf], dtype=dtype, device=device)))
-    # test nan
-    self.assertEqual(torch.tensor([-2, 2], dtype=dtype, device=device).erfinv(),
-                     torch.tensor([nan, nan], dtype=dtype, device=device))
-
-    if dtype == torch.double:
-        # double precision
-        a = torch.tensor([0.5, 0.8], dtype=torch.double, device=device).erfinv()
-        self.assertEqual(a[0].item(), 0.47693627620447, atol=1e-13, rtol=0)
-        self.assertEqual(a[1].item(), 0.90619380243682, atol=1e-13, rtol=0)
-
# TODO: opinfo hardshrink
@onlyCPU
@dtypes(torch.float, torch.double)
torch/testing/_internal/common_methods_invocations.py: 23 changes (21 additions, 2 deletions)
@@ -25,6 +25,8 @@
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY,
torch_to_numpy_dtype_dict, TEST_WITH_SLOW)

+from distutils.version import LooseVersion
+
if TEST_SCIPY:
import scipy.special

@@ -1139,6 +1141,25 @@ def reference_sigmoid(x):
dtypes=[torch.bfloat16]),),
assert_autodiffed=True,
promotes_integers_to_float=True),
+UnaryUfuncInfo('erfinv',
+               ref=scipy.special.erfinv,
+               decorators=(precisionOverride({torch.float16: 1e-2,
+                                              torch.bfloat16: 1e-2,
+                                              torch.float32: 1e-4}),),
+               dtypes=all_types_and(torch.bool),
+               dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
+               dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+               promotes_integers_to_float=True,
+               domain=(-1, 1),
+               skips=(
+                   # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
+                   SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                            active_if=LooseVersion(scipy.__version__) < "1.4.0"),
+                   # RuntimeError: "pow" not implemented for 'BFloat16'
+                   SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                            dtypes=[torch.bfloat16]),
+               )
+               ),
OpInfo('xlogy',
dtypes=all_types_and(torch.bool),
dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16),
@@ -1412,8 +1433,6 @@ def method_tests():
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)),
('exp', (S, S, S), NO_ARGS, '', (True,)),
('exp', (), NO_ARGS, 'scalar', (True,)),
-('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
-('erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=True), NO_ARGS, 'scalar'),
('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, ''),
('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), (0.2,), 'eps'),
('logit', uniform_scalar().clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, 'scalar'),
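The OpInfo entry added above lets the generic unary-ufunc tests compare torch.erfinv against scipy.special.erfinv, including for the newly promoted integer and bool dtypes. A rough standalone approximation of that reference check, much simplified; the real harness uses make_tensor, the domain=(-1, 1) setting, and the precision overrides shown in the entry.

import numpy as np
import scipy.special
import torch

# For integer inputs, only -1, 0 and 1 fall inside erfinv's domain.
x = torch.tensor([-1, 0, 0, 1], dtype=torch.int64)

torch_result = torch.erfinv(x)                  # promoted to a floating dtype
scipy_result = scipy.special.erfinv(x.numpy())  # NumPy/SciPy promote to float64

np.testing.assert_allclose(torch_result.to(torch.float64).numpy(), scipy_result)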
