[numpy] torch.exp{2, m1}: promote integer inputs to float (#48926)
Summary:
Reference: #42515

Pull Request resolved: #48926

Reviewed By: zhangguanheng66

Differential Revision: D25392344

Pulled By: mruberry

fbshipit-source-id: ddbabcfd58cc4c944153b1a224cc232efa022104
kshitij12345 authored and facebook-github-bot committed Dec 10, 2020
1 parent 27f7d1c commit eb9516e
Showing 6 changed files with 31 additions and 32 deletions.
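In user-facing terms, torch.exp2 and torch.expm1 now accept integer and boolean inputs and promote them to the default floating-point dtype, matching NumPy; previously these ops raised a RuntimeError for non-floating inputs. A minimal sketch of the expected behavior after this commit (output dtypes below assume the default dtype is torch.float32):

import torch

# Integer and bool inputs are now promoted to the default float dtype.
i = torch.tensor([0, 1, 2, 3])                      # int64
print(torch.exp2(i))                                # tensor([1., 2., 4., 8.]), dtype=float32
print(torch.expm1(torch.tensor([0, 1])))            # tensor([0.0000, 1.7183]), dtype=float32
print(torch.exp2(torch.tensor([True, False])))      # bool is promoted as well

# Floating-point inputs keep their dtype, as before.
f = torch.tensor([1.0, 2.0], dtype=torch.float64)
print(torch.exp2(f).dtype)                          # torch.float64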
8 changes: 4 additions & 4 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -250,12 +250,12 @@ Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(r
 Tensor exp(const Tensor& self) { return unary_op_impl(self, at::exp_out); }
 Tensor& exp_(Tensor& self) { return unary_op_impl_(self, at::exp_out); }
 
-Tensor& exp2_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, exp2_stub); }
-Tensor exp2(const Tensor& self) { return unary_op_impl(self, at::exp2_out); }
+Tensor& exp2_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp2_stub); }
+Tensor exp2(const Tensor& self) { return unary_op_impl_float(self, exp2_stub); }
 Tensor& exp2_(Tensor& self) { return unary_op_impl_(self, at::exp2_out); }
 
-Tensor& expm1_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, expm1_stub); }
-Tensor expm1(const Tensor& self) { return unary_op_impl(self, at::expm1_out); }
+Tensor& expm1_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, expm1_stub); }
+Tensor expm1(const Tensor& self) { return unary_op_impl_float(self, expm1_stub); }
 Tensor& expm1_(Tensor& self) { return unary_op_impl_(self, at::expm1_out); }
 
 Tensor& erf_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erf_stub); }
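Switching from unary_op_impl/unary_op_impl_out to unary_op_impl_float/unary_op_impl_float_out routes these ops through the variant that computes in, and returns, a floating-point dtype when the input is integral or boolean. A rough sketch of the resulting dtype rule in Python; the helper name expected_result_dtype is illustrative and not part of the codebase:

import torch

def expected_result_dtype(x: torch.Tensor) -> torch.dtype:
    # Illustrative only: integral/bool inputs promote to the default float
    # dtype; floating-point (and complex) inputs keep their dtype.
    if x.dtype.is_floating_point or x.dtype.is_complex:
        return x.dtype
    return torch.get_default_dtype()

for dt in (torch.int32, torch.int64, torch.bool, torch.float32, torch.float64):
    x = torch.ones(3, dtype=dt)
    assert torch.exp2(x).dtype == expected_result_dtype(x)
    assert torch.expm1(x).dtype == expected_result_dtype(x)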
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -41,15 +41,15 @@ void exp_kernel_cuda(TensorIterator& iter) {
 }
 
 void exp2_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "exp2_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "exp2_cuda", [&]() {
     gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::exp2(a);
     });
   });
 }
 
 void expm1_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "expm1_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "expm1_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::expm1(a);
     });
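Dispatching on iter.common_dtype() rather than iter.dtype() matters once integer inputs are promoted: the kernel must be instantiated for the promoted (floating) computation dtype, not the original operand dtype. A small illustration of the observable effect on a CUDA device (guarded, since it assumes a GPU is available; output dtype assumes the default dtype is torch.float32):

import torch

if torch.cuda.is_available():
    x = torch.arange(4, device="cuda")          # int64 input on the GPU
    y = torch.exp2(x)                           # computed in the promoted float dtype
    print(y.dtype)                              # torch.float32
    # The CUDA result should match the CPU path for the same promoted dtype.
    print(torch.allclose(y.cpu(), torch.exp2(x.cpu())))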
3 changes: 0 additions & 3 deletions test/test_torch.py
@@ -6905,9 +6905,6 @@ def inner(self, device, dtype):
     ('exp', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('exp', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1),
         lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
-    ('expm1', '', _small_3d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types),
-    ('expm1', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1),
-        lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
     ('rad2deg', '', _small_3d, lambda t, d: [], 1e-1, 1e-0, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
     ('deg2rad', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
     ('reciprocal', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
1 change: 0 additions & 1 deletion test/test_unary_ufuncs.py
@@ -1770,7 +1770,6 @@ def _medium_2d(dtype, device):
 # TODO: all these should be replaced with OpInfos
 torch_op_tests = [
     _TorchMathTestMeta('exp'),
-    _TorchMathTestMeta('expm1'),
     _TorchMathTestMeta('floor'),
     _TorchMathTestMeta('ceil'),
     _TorchMathTestMeta('rad2deg'),
5 changes: 3 additions & 2 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -989,8 +989,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     } break;
 
     case aten::expm1: {
-      return computeOneOperand(
-          "aten_expm1", v, [](const ExprHandle& a) { return expm1(a); });
+      return computeOneOperand("aten_expm1", v, [](const ExprHandle& a) {
+        return expm1(promoteIntegerToFloat(a));
+      });
     } break;
 
     case aten::erf: {
42 changes: 22 additions & 20 deletions torch/testing/_internal/common_methods_invocations.py
@@ -759,10 +759,26 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                               active_if=(IS_MACOS or IS_WINDOWS)),
                   )),
     UnaryUfuncInfo('exp2',
-                   ref=np.exp2,
-                   dtypes=floating_types_and(torch.half),
-                   dtypesIfCPU=None,
-                   dtypesIfCUDA=None),
+                   ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2),
+                   dtypes=all_types_and(torch.bool, torch.half),
+                   dtypesIfCPU=all_types_and(torch.bool, torch.half),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+                   promotes_integers_to_float=True),
+    UnaryUfuncInfo('expm1',
+                   ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
+                   dtypes=all_types_and(torch.bool, torch.half),
+                   dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half),
+                   promotes_integers_to_float=True,
+                   assert_autodiffed=True,
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/pull/48926#issuecomment-739734774
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.bfloat16]),
+                       # RuntimeError: "isfinite" not implemented for 'BFloat16'
+                       SkipInfo('TestCommon', 'test_variant_consistency_jit',
+                                device_type='cpu', dtypes=[torch.bfloat16]),
+                   )),
     UnaryUfuncInfo('nan_to_num',
                    ref=np.nan_to_num,
                    dtypes=all_types_and(torch.half, torch.bool),
@@ -785,25 +801,13 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                        # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
                        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                 dtypes=[torch.bfloat16]),
-                       # RuntimeError: sqrt does not support automatic differentiation for outputs with complex dtype.
-                       SkipInfo('TestGradients', 'test_fn_grad',
-                                dtypes=[torch.cdouble]),
-                       SkipInfo('TestGradients', 'test_fn_gradgrad',
-                                dtypes=[torch.cdouble]),
-                       SkipInfo('TestGradients', 'test_method_grad',
-                                dtypes=[torch.cdouble]),
-                       SkipInfo('TestGradients', 'test_method_gradgrad',
-                                dtypes=[torch.cdouble]),
-                       SkipInfo('TestGradients', 'test_inplace_grad',
-                                dtypes=[torch.cdouble]),
-                       SkipInfo('TestGradients', 'test_inplace_gradgrad',
-                                dtypes=[torch.cdouble]),
                        SkipInfo('TestCommon', 'test_variant_consistency_eager',
                                 dtypes=[torch.cfloat, torch.cdouble]),
                        SkipInfo('TestCommon', 'test_variant_consistency_jit',
                                 dtypes=[torch.cfloat, torch.cdouble])),
                    promotes_integers_to_float=True,
-                   handles_complex_extremals=False),
+                   handles_complex_extremals=False,
+                   test_complex_grad=False),
 ]
 
 if TEST_SCIPY:
@@ -1124,8 +1128,6 @@ def method_tests():
     ('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)),
     ('exp', (S, S, S), NO_ARGS, '', (True,)),
     ('exp', (), NO_ARGS, 'scalar', (True,)),
-    ('expm1', (S, S, S), NO_ARGS, '', (True,)),
-    ('expm1', (), NO_ARGS, 'scalar', (True,)),
     ('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
     ('erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=True), NO_ARGS, 'scalar'),
     ('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, ''),
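The np_unary_ufunc_integer_promotion_wrapper helper referenced in the OpInfo entries above makes the NumPy reference report the same dtype that the promoted torch op produces. Its real definition lives in common_methods_invocations.py; the version below is only a hypothetical sketch of the idea, with an illustrative name to avoid confusion with the actual helper:

import numpy as np
import torch

def np_like_integer_promotion_wrapper(fn):
    # Hypothetical stand-in: cast integral/bool NumPy inputs to the NumPy
    # dtype matching torch.get_default_dtype() before calling the reference.
    np_default = torch.empty(0, dtype=torch.get_default_dtype()).numpy().dtype

    def wrapped(x, *args, **kwargs):
        a = np.asarray(x)
        if np.issubdtype(a.dtype, np.integer) or a.dtype == np.bool_:
            a = a.astype(np_default)
        return fn(a, *args, **kwargs)

    return wrapped

ref_exp2 = np_like_integer_promotion_wrapper(np.exp2)
print(ref_exp2(np.array([1, 2, 3])).dtype)   # float32 when torch's default dtype is float32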
