From 3649a2c170c45653d2aa1267d48beb867914b039 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 12 Nov 2020 16:14:11 -0800 Subject: [PATCH] [numpy] `torch.sqrt` : promote integer inputs to float (#47293) Summary: Reference https://github.com/pytorch/pytorch/issues/42515 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47293 Reviewed By: malfet Differential Revision: D24855994 Pulled By: mruberry fbshipit-source-id: 1e6752f2eeba6d638dea0bdea0c650cf722718c9 --- aten/src/ATen/native/UnaryOps.cpp | 4 +-- aten/src/ATen/native/cuda/UnaryOpsKernel.cu | 4 +-- test/test_torch.py | 3 +- torch/csrc/jit/tensorexpr/kernel.cpp | 5 +-- .../_internal/common_methods_invocations.py | 32 ++++++++++++++++++- 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 7d57651005d5..8e01aff472ff 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -378,8 +378,8 @@ Tensor& arctanh_out(Tensor& result, const Tensor& self) { return at::atanh_out(r Tensor arctanh(const Tensor& self) { return self.atanh(); } Tensor& arctanh_(Tensor& self) { return self.atanh_(); } -Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sqrt_stub); } -Tensor sqrt(const Tensor& self) { return unary_op_impl(self, at::sqrt_out); } +Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sqrt_stub); } +Tensor sqrt(const Tensor& self) { return unary_op_impl_float(self, sqrt_stub); } Tensor& sqrt_(Tensor& self) { return unary_op_impl_(self, at::sqrt_out); } Tensor square(const Tensor& self) { return at::pow(self, 2); } diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index 4b1f0c1a6aa3..25dbf5a1a6ef 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -78,7 +78,7 @@ __host__ __device__ static inline c10::complex rsqrt_wrapper(c10::complex } void rsqrt_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "rsqrt_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "rsqrt_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float. return rsqrt_wrapper(a); @@ -87,7 +87,7 @@ void rsqrt_kernel_cuda(TensorIterator& iter) { } void sqrt_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "sqrt_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "sqrt_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::sqrt(a); }); diff --git a/test/test_torch.py b/test/test_torch.py index 6bcbd5582dc8..ddde396c011b 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -20921,8 +20921,7 @@ def __init__(self, self.dtypes = dtypes self.replace_inf_with_nan = replace_inf_with_nan -torch_op_tests = [_TorchMathTestMeta('sqrt'), - _TorchMathTestMeta('erf', ref_backend='scipy'), +torch_op_tests = [_TorchMathTestMeta('erf', ref_backend='scipy'), _TorchMathTestMeta('erfc', ref_backend='scipy'), _TorchMathTestMeta('exp'), _TorchMathTestMeta('expm1'), diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 24bfedc92841..83ecb69774ae 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1146,8 +1146,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::sqrt: { - return computeOneOperand( - "aten_sqrt", v, [](const ExprHandle& a) { return sqrt(a); }); + return computeOneOperand("aten_sqrt", v, [](const ExprHandle& a) { + return sqrt(promoteIntegerToFloat(a)); + }); } break; case aten::rsqrt: { diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e62d62eb12bf..8918d3ea8c1f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -436,7 +436,37 @@ def sample_inputs(self, device, dtype, requires_grad=False): ref=np.nan_to_num, dtypes=all_types_and(torch.half, torch.bool), dtypesIfCPU=None, - dtypesIfCUDA=None) + dtypesIfCUDA=None), + UnaryUfuncInfo('sqrt', + ref=np.sqrt, + domain=(0, float('inf')), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 7e-2}),), + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/47358 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_MACOS), + # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.bfloat16]), + # RuntimeError: sqrt does not support automatic differentiation for outputs with complex dtype. + SkipInfo('TestGradients', 'test_fn_grad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_fn_gradgrad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_method_grad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_method_gradgrad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_inplace_grad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_inplace_gradgrad', + dtypes=[torch.cdouble]),), + promotes_integers_to_float=True, + handles_complex_extremals=False), ] # Common operator groupings