Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[numpy] torch.erfinv: promote integer inputs to float #49155

Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion aten/src/ATen/native/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ Tensor& erfc_out(Tensor& result, const Tensor& self) { return unary_op_impl_floa
Tensor erfc(const Tensor& self) { return unary_op_impl_float(self, erfc_stub); }
Tensor& erfc_(Tensor& self) { return unary_op_impl_(self, at::erfc_out); }

Tensor& erfinv_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erfinv_stub); }
Tensor erfinv(const Tensor& self) { return unary_op_impl_float(self, erfinv_stub); }
Tensor& erfinv_(Tensor& self) { return unary_op_impl_(self, at::erfinv_out); }

Tensor& frac_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, frac_stub); }
Tensor frac(const Tensor& self) { return unary_op_impl(self, at::frac_out); }
Tensor& frac_(Tensor& self) { return unary_op_impl_(self, at::frac_out); }
Expand Down Expand Up @@ -669,7 +673,6 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU) \
IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cuda, CUDA)

IMPLEMENT_UNARY_OP_VEC_CUDA(erfinv)
IMPLEMENT_UNARY_OP_VEC_CUDA(lgamma)

DEFINE_DISPATCH(abs_stub);
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryOpsKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void erfc_kernel_cuda(TensorIterator& iter) {
}

void erfinv_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "erfinv_cuda", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "erfinv_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::erfinv(a);
});
Expand Down
6 changes: 2 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6494,13 +6494,11 @@
use_c10_dispatcher: full
variants: method
dispatch:
CPU: _erfinv__cpu
CUDA: _erfinv__cuda
CPU, CUDA: erfinv_

- func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: _erfinv_out_cpu
CUDA: _erfinv_out_cuda
CPU, CUDA: erfinv_out

- func: i0(Tensor self) -> Tensor
use_c10_dispatcher: full
Expand Down
1 change: 0 additions & 1 deletion test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6914,7 +6914,6 @@ def inner(self, device, dtype):
('atanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
('erfc', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
('erfinv', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
('exp', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
('exp', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1),
lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
Expand Down
20 changes: 0 additions & 20 deletions test/test_unary_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,26 +764,6 @@ def test_ceil_out_mismatch(self, device):
b = torch.randn(1, device=device)
self.assertRaises(RuntimeError, lambda: torch.ceil(a, out=b))

# TODO: review with erfinv opinfo
@dtypesIfCUDA(torch.half, torch.float, torch.double)
@dtypes(torch.float, torch.double)
def test_erfinv(self, device, dtype):
# general testing. Narrow the range to avoid accuracy issues
input_values = torch.randn(4, 4, dtype=dtype, device=device).clamp(-0.3, 0.3)
self.assertEqual(input_values.erf().erfinv(), input_values)
# test inf
self.assertTrue(torch.equal(torch.tensor([-1, 1], dtype=dtype, device=device).erfinv(),
torch.tensor([-inf, inf], dtype=dtype, device=device)))
# test nan
self.assertEqual(torch.tensor([-2, 2], dtype=dtype, device=device).erfinv(),
torch.tensor([nan, nan], dtype=dtype, device=device))

if dtype == torch.double:
# double precision
a = torch.tensor([0.5, 0.8], dtype=torch.double, device=device).erfinv()
self.assertEqual(a[0].item(), 0.47693627620447, atol=1e-13, rtol=0)
self.assertEqual(a[1].item(), 0.90619380243682, atol=1e-13, rtol=0)

# TODO: opinfo hardshrink
@onlyCPU
@dtypes(torch.float, torch.double)
Expand Down
21 changes: 18 additions & 3 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
random_symmetric_pd_matrix, make_nonzero_det,
random_fullrank_matrix_distinct_singular_value, set_rng_seed,
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY,
torch_to_numpy_dtype_dict, TEST_WITH_SLOW)
torch_to_numpy_dtype_dict, TEST_WITH_SLOW, version_atleast)

if TEST_SCIPY:
import scipy.special
Expand Down Expand Up @@ -863,6 +863,23 @@ def reference_sigmoid(x):
dtypes=[torch.bfloat16]),),
assert_autodiffed=True,
promotes_integers_to_float=True),
UnaryUfuncInfo('erfinv',
ref=scipy.special.erfinv,
decorators=(precisionOverride({torch.float16: 1e-2,
torch.bfloat16: 1e-2,
torch.float32: 1e-4}),),
dtypes=all_types_and(torch.bool),
dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
promotes_integers_to_float=True,
domain=(-1, 1),
skips=(
# Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
active_if=not version_atleast(scipy.__version__, required_version="1.4.0")),
# RuntimeError: "pow" not implemented for 'BFloat16'
SkipInfo('TestCommon', 'test_variant_consistency_jit',
kshitij12345 marked this conversation as resolved.
Show resolved Hide resolved
dtypes=[torch.bfloat16]),),)
]
op_db = op_db + op_db_scipy_reference

Expand Down Expand Up @@ -1128,8 +1145,6 @@ def method_tests():
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)),
('exp', (S, S, S), NO_ARGS, '', (True,)),
('exp', (), NO_ARGS, 'scalar', (True,)),
('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
('erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=True), NO_ARGS, 'scalar'),
('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, ''),
('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), (0.2,), 'eps'),
('logit', uniform_scalar().clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, 'scalar'),
Expand Down
11 changes: 11 additions & 0 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,3 +1883,14 @@ def _assertGradAndGradgradChecks(test_case, apply_fn, inputs):
torch.double: 1e-5,
torch.half: 1e-2,
torch.bfloat16: 1e-1}

# Returns True if version is at least required version, False otherwise
def version_atleast(version, required_version):
kshitij12345 marked this conversation as resolved.
Show resolved Hide resolved
def versiontuple(v):
"""
Helper function to compare versions.
Reference: https://stackoverflow.com/a/11887825/5602957
"""
return tuple(map(int, (v.split("."))))

return versiontuple(version) >= versiontuple(required_version)