
[numpy] torch.sqrt : promote integer inputs to float #47293
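PR summary: this change makes torch.sqrt promote integer (and bool) inputs to a floating-point dtype before computing, in line with NumPy's np.sqrt, which returns a float array for integer input. A minimal sketch of the intended user-visible behavior, assuming a build that includes this change:

    import numpy as np
    import torch

    x = torch.arange(5)                        # dtype=torch.int64
    print(torch.sqrt(x).dtype)                 # torch.float32 (the default float dtype)
    print(np.sqrt(np.arange(5)).dtype)         # float64 -- NumPy also promotes, to double

    # Floating-point and complex inputs keep their dtype, as before.
    print(torch.sqrt(torch.arange(5, dtype=torch.float64)).dtype)   # torch.float64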

4 changes: 2 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -378,8 +378,8 @@ Tensor& arctanh_out(Tensor& result, const Tensor& self)
 Tensor arctanh(const Tensor& self) { return self.atanh(); }
 Tensor& arctanh_(Tensor& self) { return self.atanh_(); }
 
-Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sqrt_stub); }
-Tensor sqrt(const Tensor& self) { return unary_op_impl(self, at::sqrt_out); }
+Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sqrt_stub); }
+Tensor sqrt(const Tensor& self) { return unary_op_impl_float(self, sqrt_stub); }
 Tensor& sqrt_(Tensor& self) { return unary_op_impl_(self, at::sqrt_out); }
 
 Tensor square(const Tensor& self) { return at::pow(self, 2); }
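The CPU entry points now route through the float-promoting unary-op helpers (unary_op_impl_float / unary_op_impl_float_out) instead of the plain ones, so the computation dtype is floating point even when the input is integral. A rough illustration of the effect on the out= overload (a hedged sketch, not taken from this PR's tests):

    import torch

    x = torch.arange(4)                           # int64 input
    out = torch.empty(4, dtype=torch.float32)     # a floating-point out tensor is accepted
    torch.sqrt(x, out=out)                        # computed in float32, written into `out`
    print(out)                                    # tensor([0.0000, 1.0000, 1.4142, 1.7321])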
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -78,7 +78,7 @@ __host__ __device__ static inline c10::complex<T> rsqrt_wrapper(c10::complex<T>
 }
 
 void rsqrt_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "rsqrt_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "rsqrt_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float.
       return rsqrt_wrapper(a);
@@ -87,7 +87,7 @@ void rsqrt_kernel_cuda(TensorIterator& iter) {
 }
 
 void sqrt_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "sqrt_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "sqrt_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::sqrt(a);
     });
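The CUDA kernels now select their template type from iter.common_dtype() rather than iter.dtype(). With integer inputs promoted, the dtype the kernel should compute in (a floating-point type) can differ from the inputs' dtype, and dispatching on the common dtype keeps the instantiation consistent with that. A hypothetical session showing the promoted result on a CUDA build (illustrative only):

    import torch

    if torch.cuda.is_available():
        x = torch.arange(10, device="cuda")       # int64 input on the GPU
        y = torch.sqrt(x)                         # dispatched on the float computation dtype
        print(y.dtype, y.device)                  # torch.float32 cuda:0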
3 changes: 1 addition & 2 deletions test/test_torch.py
@@ -21216,8 +21216,7 @@ def __init__(self,
         self.dtypes = dtypes
         self.replace_inf_with_nan = replace_inf_with_nan
 
-torch_op_tests = [_TorchMathTestMeta('sqrt'),
-                  _TorchMathTestMeta('erf', ref_backend='scipy'),
+torch_op_tests = [_TorchMathTestMeta('erf', ref_backend='scipy'),
                   _TorchMathTestMeta('erfc', ref_backend='scipy'),
                   _TorchMathTestMeta('exp'),
                   _TorchMathTestMeta('expm1'),
5 changes: 3 additions & 2 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1165,8 +1165,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     } break;
 
     case aten::sqrt: {
-      return computeOneOperand(
-          "aten_sqrt", v, [](const ExprHandle& a) { return sqrt(a); });
+      return computeOneOperand("aten_sqrt", v, [](const ExprHandle& a) {
+        return sqrt(promoteIntegerToFloat(a));
+      });
     } break;
 
     case aten::rsqrt: {
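The tensor-expression (NNC) codegen applies the same promotion inside generated kernels, so a scripted and potentially fused graph containing aten::sqrt on integer inputs agrees with eager mode. A small hedged check of that equivalence (promoteIntegerToFloat itself is internal to the codegen and not exposed in Python):

    import torch

    @torch.jit.script
    def f(x):
        return torch.sqrt(x)

    x = torch.arange(6)
    # The scripted path should match eager mode, including the promoted dtype.
    assert f(x).dtype == torch.sqrt(x).dtype == torch.float32
    assert torch.allclose(f(x), torch.sqrt(x))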
32 changes: 31 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -436,7 +436,37 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                    ref=np.nan_to_num,
                    dtypes=all_types_and(torch.half, torch.bool),
                    dtypesIfCPU=None,
-                   dtypesIfCUDA=None)
+                   dtypesIfCUDA=None),
+    UnaryUfuncInfo('sqrt',

Review comment (Collaborator): It's great to see another OpInfo! With this ported, this PR can also remove the TorchMathTestMeta for sqrt in test_torch.py (see torch_op_tests; I can't link the line because the file is too big to render on GitHub).

+                   ref=np.sqrt,
+                   domain=(0, float('inf')),

Review comment (Collaborator): Metanote: it's a little unfortunate that we don't have dtype-specific domains so we can't test taking the sqrt of negative complex values easily.

+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   decorators=(precisionOverride({torch.bfloat16: 7e-2}),),
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/47358
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_MACOS),
+                       # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                dtypes=[torch.bfloat16]),
+                       # RuntimeError: sqrt does not support automatic differentiation for outputs with complex dtype.
+                       SkipInfo('TestGradients', 'test_fn_grad',
+                                dtypes=[torch.cdouble]),

Review comment (Collaborator): Metanote: if this happens again we should add a property for whether a function supports complex autograd or not, and, if complex autograd isn't supported, skip the complex autograd tests.

+                       SkipInfo('TestGradients', 'test_fn_gradgrad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_method_grad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_method_gradgrad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_inplace_grad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_inplace_gradgrad',
+                                dtypes=[torch.cdouble]),),
+                   promotes_integers_to_float=True,
+                   handles_complex_extremals=False),
 ]
 
 # Common operator groupings
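The new UnaryUfuncInfo entry is what drives the generated tests: np.sqrt is the reference, the input domain is restricted to non-negative values, the dtype lists include integer and bool types, and promotes_integers_to_float=True tells the framework to expect a floating-point result for those inputs. A rough sketch of the kind of comparison this enables (the real TestUnaryUfuncs logic is considerably more thorough; the helper name below is illustrative only):

    import numpy as np
    import torch

    def check_against_numpy_reference(dtype=torch.int32):
        t = torch.arange(1, 8, dtype=dtype)        # values inside the (0, inf) domain
        expected = np.sqrt(t.numpy())              # NumPy reference (float64 for ints)
        actual = torch.sqrt(t)                     # promoted to the default float dtype
        # promotes_integers_to_float=True means a floating-point result is expected here.
        assert actual.dtype.is_floating_point
        assert torch.allclose(actual.double(), torch.from_numpy(expected), atol=1e-5)

    check_against_numpy_reference()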