From ca27f31d67537ac41532099833de389c970d555c Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 3 Nov 2020 11:22:17 -0600 Subject: [PATCH 1/6] init int -> float support for sqrt --- aten/src/ATen/native/UnaryOps.cpp | 4 ++-- aten/src/ATen/native/cuda/UnaryOpsKernel.cu | 4 ++-- torch/csrc/jit/tensorexpr/kernel.cpp | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 1aebfda85da0..b09a473f15e7 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -378,8 +378,8 @@ Tensor& arctanh_out(Tensor& result, const Tensor& self) { return at::atanh_out(r Tensor arctanh(const Tensor& self) { return self.atanh(); } Tensor& arctanh_(Tensor& self) { return self.atanh_(); } -Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sqrt_stub); } -Tensor sqrt(const Tensor& self) { return unary_op_impl(self, at::sqrt_out); } +Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sqrt_stub); } +Tensor sqrt(const Tensor& self) { return unary_op_impl_float(self, sqrt_stub); } Tensor& sqrt_(Tensor& self) { return unary_op_impl_(self, at::sqrt_out); } Tensor square(const Tensor& self) { return at::pow(self, 2); } diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index 4b1f0c1a6aa3..25dbf5a1a6ef 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -78,7 +78,7 @@ __host__ __device__ static inline c10::complex<T> rsqrt_wrapper(c10::complex<T> } void rsqrt_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "rsqrt_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "rsqrt_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { // In CUDA, ::rsqrt is 
overloaded for float and at::Half here is implicitly cast to float. return rsqrt_wrapper(a); @@ -87,7 +87,7 @@ void rsqrt_kernel_cuda(TensorIterator& iter) { } void sqrt_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "sqrt_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "sqrt_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::sqrt(a); }); diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 4705760b868d..eb3ed96c0d00 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1129,7 +1129,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::sqrt: { return computeOneOperand( - "aten_sqrt", v, [](const ExprHandle& a) { return sqrt(a); }); + "aten_sqrt", v, [](const ExprHandle& a) { return sqrt(promoteIntegerToFloat(a)); }); } break; case aten::rsqrt: { From 688079997cc365c4fa16b5912a5f0c3427b30b7d Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 3 Nov 2020 11:22:33 -0600 Subject: [PATCH 2/6] add entry to unaryop db --- .../_internal/common_methods_invocations.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 24ad7531700f..2c7fae971e88 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -419,7 +419,35 @@ def sample_inputs(self, device, dtype, requires_grad=False): ref=np.nan_to_num, dtypes=all_types_and(torch.half, torch.bool), dtypesIfCPU=None, - dtypesIfCUDA=None) + dtypesIfCUDA=None), + UnaryUfuncInfo('sqrt', + ref=np.sqrt, + domain=(0, float('inf')), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + 
dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + decorators=(precisionOverride({torch.bfloat16: 7e-2}),), + skips=( + # Investigate flaky behaviour + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cpu', dtypes=[torch.bfloat16, torch.cfloat, torch.cdouble]), + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + device_type='cuda', dtypes=[torch.bfloat16]), + # RuntimeError: sqrt does not support automatic differentiation for outputs with complex dtype. + SkipInfo('TestGradients', 'test_fn_grad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_fn_gradgrad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_method_grad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_method_gradgrad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_inplace_grad', + dtypes=[torch.cdouble]), + SkipInfo('TestGradients', 'test_inplace_gradgrad', + dtypes=[torch.cdouble]),), + promotes_integers_to_float=True, + handles_complex_extremals=False), ] # Common operator groupings From a587945d755b8d1e3e320092ae5977a76d51e7dd Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 4 Nov 2020 11:49:30 -0600 Subject: [PATCH 3/6] add references to the issues --- torch/testing/_internal/common_methods_invocations.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 2c7fae971e88..5a3f0774dbaf 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -428,11 +428,12 @@ def sample_inputs(self, device, dtype, requires_grad=False): dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), decorators=(precisionOverride({torch.bfloat16: 7e-2}),), skips=( - # Investigate flaky behaviour + # Reference: 
https://github.com/pytorch/pytorch/issues/47358 SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', - device_type='cpu', dtypes=[torch.bfloat16, torch.cfloat, torch.cdouble]), + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', - device_type='cuda', dtypes=[torch.bfloat16]), + dtypes=[torch.bfloat16]), # RuntimeError: sqrt does not support automatic differentiation for outputs with complex dtype. SkipInfo('TestGradients', 'test_fn_grad', dtypes=[torch.cdouble]), From 331f68bbc66df6eefc61e60694dd6a6904e00f99 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 4 Nov 2020 11:56:18 -0600 Subject: [PATCH 4/6] make clang-format happy --- torch/csrc/jit/tensorexpr/kernel.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index eb3ed96c0d00..fab24f07d6d0 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1128,8 +1128,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::sqrt: { - return computeOneOperand( - "aten_sqrt", v, [](const ExprHandle& a) { return sqrt(promoteIntegerToFloat(a)); }); + return computeOneOperand("aten_sqrt", v, [](const ExprHandle& a) { + return sqrt(promoteIntegerToFloat(a)); + }); } break; case aten::rsqrt: { From 44c4359e795bb5e317d85a4f59b7906d280313f1 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 6 Nov 2020 07:46:33 -0600 Subject: [PATCH 5/6] remove redundant test --- test/test_torch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index 92f04549992a..b1b56a917477 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -20995,8 +20995,7 @@ def __init__(self, self.dtypes = dtypes self.replace_inf_with_nan = replace_inf_with_nan 
-torch_op_tests = [_TorchMathTestMeta('sqrt'), - _TorchMathTestMeta('erf', ref_backend='scipy'), +torch_op_tests = [_TorchMathTestMeta('erf', ref_backend='scipy'), _TorchMathTestMeta('erfc', ref_backend='scipy'), _TorchMathTestMeta('exp'), _TorchMathTestMeta('expm1'), From 91b9a3a6af2662c4a9d82e1ff008bdd15aab2453 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 10 Nov 2020 21:18:39 -0600 Subject: [PATCH 6/6] skip complex reference numerics only on macOS --- torch/testing/_internal/common_methods_invocations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 17b8bd166a53..acca65c96a35 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -447,7 +447,8 @@ def sample_inputs(self, device, dtype, requires_grad=False): skips=( # Reference: https://github.com/pytorch/pytorch/issues/47358 SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', - device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_MACOS), # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', dtypes=[torch.bfloat16]),