[numpy] torch.exp: promote integer inputs to float (#50093)
Summary:
Reference: #42515

Pull Request resolved: #50093

Reviewed By: H-Huang

Differential Revision: D25803549

Pulled By: mruberry

fbshipit-source-id: e6f245b5e728f2dca6072f8c359f03dff63aa14d
kshitij12345 authored and facebook-github-bot committed Jan 8, 2021
1 parent fc2ead0 commit 9f832c8
Showing 6 changed files with 20 additions and 11 deletions.
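
For context, a quick eager-mode illustration of the behavior this commit enables (a rough sketch against a post-change build; before this change, calling torch.exp on an integer or bool tensor raised a "not implemented for" dispatch error):

import torch

x = torch.arange(5)                  # dtype=torch.int64
y = torch.exp(x)                     # integer input is now promoted to the default float dtype
print(y.dtype)                       # torch.float32

b = torch.tensor([True, False])
print(torch.exp(b).dtype)            # torch.float32 -- bool inputs promote the same way
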
4 changes: 2 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -256,8 +256,8 @@ Tensor& ceil_out(Tensor& result, const Tensor& self) {
Tensor ceil(const Tensor& self) { return unary_op_impl(self, at::ceil_out); }
Tensor& ceil_(Tensor& self) { return unary_op_impl_(self, at::ceil_out); }

-Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, exp_stub); }
-Tensor exp(const Tensor& self) { return unary_op_impl(self, at::exp_out); }
+Tensor& exp_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp_stub); }
+Tensor exp(const Tensor& self) { return unary_op_impl_float(self, exp_stub); }
Tensor& exp_(Tensor& self) { return unary_op_impl_(self, at::exp_out); }

Tensor& exp2_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, exp2_stub); }
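
Switching exp_out/exp to the unary_op_impl_float helpers means the out= variant also accepts integer inputs, while the in-place variant still cannot, since a floating result cannot be written back into an integer tensor. A hedged sketch of the expected eager-mode behavior:

import torch

x = torch.arange(5)                            # int64
out = torch.empty(5, dtype=torch.float32)
torch.exp(x, out=out)                          # ok: computed in float, written to the float out tensor

try:
    x.exp_()                                   # in-place on an integer tensor
except RuntimeError as e:
    print(e)                                   # result-type error: Float can't be cast back to the int64 self (wording may vary)
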
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -33,7 +33,7 @@ void bitwise_not_kernel_cuda(TensorIterator& iter) {
}

void exp_kernel_cuda(TensorIterator& iter) {
-AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exp_cuda", [&]() {
+AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "exp_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::exp(a);
});
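
iter.dtype() is the dtype of the output, while iter.common_dtype() is the promoted computation dtype; with integer inputs now promoted the two can differ (for instance when an explicit out= tensor has a different floating dtype), so the kernel has to dispatch on the computation dtype. A rough illustration of such a case, assuming a CUDA device is available:

import torch

x = torch.arange(4, device="cuda")                        # int64 input
out = torch.empty(4, dtype=torch.float64, device="cuda")
torch.exp(x, out=out)                                     # computed in the promoted float dtype, then cast into the float64 out
print(out.dtype)                                          # torch.float64
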
3 changes: 0 additions & 3 deletions test/test_torch.py
@@ -6909,9 +6909,6 @@ def inner(self, device, dtype):
('atanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
('erfc', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
-('exp', '', _small_3d, lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
-('exp', 'small', lambda t, d: _small_3d(t, d).clamp(-1, 1),
-    lambda t, d: [], 1e-2, 5e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
('rad2deg', '', _small_3d, lambda t, d: [], 1e-1, 1e-0, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
('deg2rad', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
('reciprocal', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
1 change: 0 additions & 1 deletion test/test_unary_ufuncs.py
@@ -1702,7 +1702,6 @@ def _medium_2d(dtype, device):

# TODO: all these should be replaced with OpInfos
torch_op_tests = [
-_TorchMathTestMeta('exp'),
_TorchMathTestMeta('floor'),
_TorchMathTestMeta('ceil'),
_TorchMathTestMeta('rad2deg'),
5 changes: 3 additions & 2 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1081,8 +1081,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
} break;

case aten::exp: {
-return computeOneOperand(
-    "aten_exp", v, [](const ExprHandle& a) { return exp(a); });
+return computeOneOperand("aten_exp", v, [](const ExprHandle& a) {
+  return exp(promoteIntegerToDefaultType(a));
+});
} break;

case aten::expm1: {
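
promoteIntegerToDefaultType keeps the tensorexpr (NNC) lowering of aten::exp consistent with the eager-mode promotion above. A rough consistency check with a scripted function (whether the fuser actually kicks in here depends on the backend and fusion settings):

import torch

@torch.jit.script
def f(t):
    return torch.exp(t)

x = torch.arange(6)                  # int64
assert f(x).dtype == torch.exp(x).dtype == torch.float32
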
16 changes: 14 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -806,6 +806,20 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
SkipInfo('TestCommon', 'test_variant_consistency_jit',
device_type='cuda', dtypes=[torch.float16]),
)),
+UnaryUfuncInfo('exp',
+    ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
+    dtypes=all_types_and_complex_and(torch.bool, torch.half),
+    dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+    skips=(
+        # Reference: https://github.com/pytorch/pytorch/pull/50093#pullrequestreview-561791547
+        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.bfloat16]),
+        # Reference: https://github.com/pytorch/pytorch/issues/48010
+        SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+    ),
+    assert_autodiffed=True,
+    promotes_integers_to_float=True),
SpectralFuncInfo('fft.fft',
aten_name='fft_fft',
ref=np.fft.fft,
@@ -1602,8 +1616,6 @@ def method_tests():
('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
('expand', (), (1, 3, 2), 'scalar_to_dims', (False,)),
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)),
-('exp', (S, S, S), NO_ARGS, '', (True,)),
-('exp', (), NO_ARGS, 'scalar', (True,)),
('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, ''),
('logit', torch.randn(S, S, S).clamp(0.1, 0.9).requires_grad_(True), (0.2,), 'eps'),
('logit', uniform_scalar().clamp(0.1, 0.9).requires_grad_(True), NO_ARGS, 'scalar'),
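
The new OpInfo entry compares against NumPy through np_unary_ufunc_integer_promotion_wrapper: NumPy promotes integer inputs to float64, while torch promotes them to the default float dtype (normally float32), so the reference has to be adjusted before comparing. A minimal sketch of what such a wrapper plausibly does; the real helper lives in the PyTorch testing utilities and may differ in detail:

from functools import wraps

import numpy as np
import torch


def integer_promotion_wrapper(np_fn):
    """Cast integer/bool inputs to torch's default float dtype before calling the
    NumPy reference, so it promotes the same way torch.exp now does."""
    np_default_float = torch.empty(0).numpy().dtype    # numpy dtype matching torch.get_default_dtype()

    @wraps(np_fn)
    def wrapped(x, *args, **kwargs):
        if np.issubdtype(x.dtype, np.integer) or x.dtype == np.bool_:
            x = x.astype(np_default_float)
        return np_fn(x, *args, **kwargs)

    return wrapped


exp_ref = integer_promotion_wrapper(np.exp)
print(exp_ref(np.arange(5)).dtype)    # float32, matching torch.exp on an int64 tensor
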
