[numpy] torch.sqrt : promote integer inputs to float (#47293)
Summary:
Reference #42515

Pull Request resolved: #47293

Reviewed By: malfet

Differential Revision: D24855994

Pulled By: mruberry

fbshipit-source-id: 1e6752f2eeba6d638dea0bdea0c650cf722718c9
kshitij12345 authored and facebook-github-bot committed Nov 13, 2020
1 parent 7391edb commit 3649a2c
Showing 5 changed files with 39 additions and 9 deletions.
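
For reference, a minimal sketch of the user-visible change (illustrative values; assumes a build that includes this patch):

import torch

t = torch.tensor([1, 4, 9])        # integral (int64) input
print(torch.sqrt(t))               # tensor([1., 2., 3.])
print(torch.sqrt(t).dtype)         # torch.float32: integer inputs now promote to the
                                   # default float dtype, as np.sqrt does; previously
                                   # integral inputs raised a RuntimeError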
4 changes: 2 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -378,8 +378,8 @@ Tensor& arctanh_out(Tensor& result, const Tensor& self) { return at::atanh_out(r
 Tensor arctanh(const Tensor& self) { return self.atanh(); }
 Tensor& arctanh_(Tensor& self) { return self.atanh_(); }
 
-Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sqrt_stub); }
-Tensor sqrt(const Tensor& self) { return unary_op_impl(self, at::sqrt_out); }
+Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sqrt_stub); }
+Tensor sqrt(const Tensor& self) { return unary_op_impl_float(self, sqrt_stub); }
 Tensor& sqrt_(Tensor& self) { return unary_op_impl_(self, at::sqrt_out); }
 
 Tensor square(const Tensor& self) { return at::pow(self, 2); }
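
With sqrt_out now routed through unary_op_impl_float_out, the computation runs in the promoted float dtype even for the out= variant. A hedged sketch of the expected behavior (assuming TensorIterator's usual cast rules; the exact error text may differ):

import torch

t = torch.arange(5)                      # int64
out = torch.empty(5)                     # float32 output: fine
torch.sqrt(t, out=out)

bad = torch.empty(5, dtype=torch.int64)  # integral output should be rejected
try:
    torch.sqrt(t, out=bad)
except RuntimeError as e:
    print(e)  # expected along the lines of: result type Float can't be cast
              # to the desired output type Long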
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -78,7 +78,7 @@ __host__ __device__ static inline c10::complex<T> rsqrt_wrapper(c10::complex<T>
 }
 
 void rsqrt_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "rsqrt_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "rsqrt_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float.
       return rsqrt_wrapper(a);
@@ -87,7 +87,7 @@ void rsqrt_kernel_cuda(TensorIterator& iter) {
 }
 
 void sqrt_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "sqrt_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "sqrt_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::sqrt(a);
     });
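
On the CUDA side, dispatching on iter.common_dtype() rather than iter.dtype() instantiates the kernel for the promoted computation type instead of the raw input dtype. A quick hedged check (requires a CUDA build):

import torch

if torch.cuda.is_available():
    t = torch.tensor([1, 4, 9], device='cuda')  # int64 on GPU
    print(torch.sqrt(t).dtype)                  # expected: torch.float32, same as CPU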
3 changes: 1 addition & 2 deletions test/test_torch.py
@@ -20921,8 +20921,7 @@ def __init__(self,
         self.dtypes = dtypes
         self.replace_inf_with_nan = replace_inf_with_nan
 
-torch_op_tests = [_TorchMathTestMeta('sqrt'),
-                  _TorchMathTestMeta('erf', ref_backend='scipy'),
+torch_op_tests = [_TorchMathTestMeta('erf', ref_backend='scipy'),
                   _TorchMathTestMeta('erfc', ref_backend='scipy'),
                   _TorchMathTestMeta('exp'),
                   _TorchMathTestMeta('expm1'),
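
The legacy _TorchMathTestMeta('sqrt') entry is dropped because coverage moves to the OpInfo-based UnaryUfuncInfo('sqrt', ...) entry added below, which compares against the NumPy reference across dtypes. Roughly the kind of comparison those generated tests perform (illustrative, not the actual harness):

import numpy as np
import torch

t = torch.tensor([0, 1, 4, 9], dtype=torch.int32)
# torch promotes to float32, NumPy to float64; assert_allclose compares values only
np.testing.assert_allclose(torch.sqrt(t).numpy(), np.sqrt(t.numpy()), rtol=1e-6)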
5 changes: 3 additions & 2 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1146,8 +1146,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
     } break;
 
     case aten::sqrt: {
-      return computeOneOperand(
-          "aten_sqrt", v, [](const ExprHandle& a) { return sqrt(a); });
+      return computeOneOperand("aten_sqrt", v, [](const ExprHandle& a) {
+        return sqrt(promoteIntegerToFloat(a));
+      });
     } break;
 
     case aten::rsqrt: {
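
promoteIntegerToFloat makes the tensorexpr fuser's lowering of aten::sqrt match the new eager-mode promotion, so fused and unfused graphs agree on the result dtype. A hedged end-to-end sketch via TorchScript:

import torch

@torch.jit.script
def f(x):
    return torch.sqrt(x)

t = torch.tensor([1, 4, 9])   # int64
print(f(t).dtype)             # expected: torch.float32, matching eager mode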
32 changes: 31 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -436,7 +436,37 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                    ref=np.nan_to_num,
                    dtypes=all_types_and(torch.half, torch.bool),
                    dtypesIfCPU=None,
-                   dtypesIfCUDA=None)
+                   dtypesIfCUDA=None),
+    UnaryUfuncInfo('sqrt',
+                   ref=np.sqrt,
+                   domain=(0, float('inf')),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   decorators=(precisionOverride({torch.bfloat16: 7e-2}),),
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/47358
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_MACOS),
+                       # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                dtypes=[torch.bfloat16]),
+                       # RuntimeError: sqrt does not support automatic differentiation for outputs with complex dtype.
+                       SkipInfo('TestGradients', 'test_fn_grad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_fn_gradgrad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_method_grad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_method_gradgrad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_inplace_grad',
+                                dtypes=[torch.cdouble]),
+                       SkipInfo('TestGradients', 'test_inplace_gradgrad',
+                                dtypes=[torch.cdouble]),),
+                   promotes_integers_to_float=True,
+                   handles_complex_extremals=False),
 ]
 
 # Common operator groupings
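
promotes_integers_to_float=True tells the generated unary-ufunc tests to expect a floating-point result for bool and integer inputs. A minimal hedged illustration of the contract the flag encodes:

import torch

for dtype in (torch.bool, torch.uint8, torch.int32, torch.int64):
    t = torch.ones(3, dtype=dtype)
    assert torch.sqrt(t).dtype == torch.get_default_dtype()  # float32 by default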
