diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index e84ad93de37d..4a1aa4e9f0d2 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -239,6 +239,7 @@ _(aten, combinations) \
 _(aten, _conj) \
 _(aten, conj) \
 _(aten, complex) \
+_(aten, copysign) \
 _(aten, polar) \
 _(aten, constant_pad_nd) \
 _(aten, contiguous) \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index b7916ba3f9c8..abef9d3946eb 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -49,6 +49,7 @@ DEFINE_DISPATCH(hypot_stub);
 DEFINE_DISPATCH(igamma_stub);
 DEFINE_DISPATCH(nextafter_stub);
 DEFINE_DISPATCH(heaviside_stub);
+DEFINE_DISPATCH(copysign_stub);
 
 static Tensor wrapped_scalar_tensor(Scalar scalar) {
   auto tensor = scalar_to_tensor(scalar);
@@ -122,6 +123,31 @@ Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) {
   return add_relu_impl(self, self, other, alpha);
 }
 
+Tensor& copysign_out(Tensor& result, const Tensor& self, const Tensor& other) {
+  auto iter = TensorIterator::binary_float_op(result, self, other);
+  copysign_stub(iter.device_type(), iter);
+  return result;
+}
+
+Tensor copysign(const Tensor& self, const Tensor& other) {
+  Tensor result;
+  auto iter = TensorIterator::binary_float_op(result, self, other);
+  copysign_stub(iter.device_type(), iter);
+  return iter.output();
+}
+
+Tensor& copysign_(Tensor& self, const Tensor& other) {
+  return native::copysign_out(self, self, other);
+}
+
+Tensor copysign(const Tensor& self, Scalar other) {
+  return native::copysign(self, wrapped_scalar_tensor(other));
+}
+
+Tensor& copysign_(Tensor& self, Scalar other) {
+  return native::copysign_(self, wrapped_scalar_tensor(other));
+}
+
 Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
   auto iter = TensorIterator::binary_float_op(result, self, other);
   div_stub(iter.device_type(), iter);
diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h
index ee3f023fedc5..fa4b3c3d659c 100644
--- a/aten/src/ATen/native/BinaryOps.h
+++ b/aten/src/ATen/native/BinaryOps.h
@@ -71,5 +71,6 @@ DECLARE_DISPATCH(binary_fn, hypot_stub);
 DECLARE_DISPATCH(binary_fn, igamma_stub);
 DECLARE_DISPATCH(binary_fn, nextafter_stub);
 DECLARE_DISPATCH(binary_fn, heaviside_stub);
+DECLARE_DISPATCH(binary_fn, copysign_stub);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 652f3ee063e1..57fe4555e46d 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -800,6 +800,14 @@ void heaviside_kernel(TensorIterator& iter) {
   });
 }
 
+void copysign_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
+      return std::copysign(a, b);
+    });
+  });
+}
+
 } // namespace
 
 REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -840,6 +848,7 @@ REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
 REGISTER_DISPATCH(igamma_stub, &igamma_kernel);
 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
 REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
+REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
index 2f53c2bb08d7..f02a789c49d0
--- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -5,6 +5,15 @@
 #include
 #include
 
+#if defined(__CUDACC__)
+#include
+#include
+#include
+#elif defined(__HIPCC__)
+#include
+#include
+#include
+#endif
 
 // NOTE: CUDA on Windows requires that the enclosing function
 // of a __device__ lambda not have internal linkage.
@@ -116,6 +125,14 @@ void heaviside_kernel_cuda(TensorIterator& iter) {
   });
 }
 
+void copysign_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cuda", [&]() {
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return c10::cuda::compat::copysign(a, b);
+    });
+  });
+}
+
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
 REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
 REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -127,5 +144,6 @@ REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
 REGISTER_DISPATCH(igamma_stub, &igamma_kernel_cuda);
 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
 REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
+REGISTER_DISPATCH(copysign_stub, &copysign_kernel_cuda);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index c98b9423b25c..0362e281a959 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -843,6 +843,34 @@
   dispatch:
     CPU, CUDA: bitwise_not_out
 
+- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    CPU, CUDA: copysign
+
+- func: copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: method
+  dispatch:
+    CPU, CUDA: copysign_
+
+- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: copysign_out
+
+- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+  dispatch:
+    CPU, CUDA: copysign
+
+- func: copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: method
+  dispatch:
+    CPU, CUDA: copysign_
+
 - func: logical_not(Tensor self) -> Tensor
   use_c10_dispatcher: full
   variants: function, method
diff --git a/c10/cuda/CUDAMathCompat.h b/c10/cuda/CUDAMathCompat.h
index 7652ca0f639d..1fb0c3ec29c2 100644
--- a/c10/cuda/CUDAMathCompat.h
+++ b/c10/cuda/CUDAMathCompat.h
@@ -42,6 +42,13 @@ __MATH_FUNCTIONS_DECL__ double ceil(double x) {
   return ::ceil(x);
 }
 
+__MATH_FUNCTIONS_DECL__ float copysign(float x, float y) {
+  return ::copysignf(x, y);
+}
+__MATH_FUNCTIONS_DECL__ double copysign(double x, double y) {
+  return ::copysign(x, y);
+}
+
 __MATH_FUNCTIONS_DECL__ float floor(float x) {
   return ::floorf(x);
 }
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 3bc806067870..5ae801685587 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -249,6 +249,8 @@ view of a storage and defines numeric operations on it.
    .. automethod:: contiguous
    .. automethod:: copy_
    .. automethod:: conj
+   .. automethod:: copysign
+   .. automethod:: copysign_
    .. automethod:: cos
    .. automethod:: cos_
    .. automethod:: cosh
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index e36a3f944a7a..ad03897efd11 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -279,6 +279,7 @@ Pointwise Ops
     clamp
     clip
     conj
+    copysign
     cos
     cosh
     deg2rad
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 5dc5c94c3e53..6eee37c557b8 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6713,6 +6713,39 @@ def test_GRU_grad_and_gradgrad(self, device):
         mod = torch.nn.GRU(hsize, hsize, bias=bias).to(device).to(torch.float64)
         self._test_rnn_mod(mod, inp)
 
+    def test_copysign_subgradient(self, device):
+        # Input is 0.0
+        x = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float, device=device, requires_grad=True)
+        y = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
+        out = torch.copysign(x, y)
+        out.sum().backward()
+        self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
+        self.assertEqual(y.grad.tolist(), [0.0] * 3)
+
+        # Input is -0.0
+        x = torch.tensor([-0.0, -0.0, -0.0], dtype=torch.float, device=device, requires_grad=True)
+        y = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
+        out = torch.copysign(x, y)
+        out.sum().backward()
+        self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
+        self.assertEqual(y.grad.tolist(), [0.0] * 3)
+
+        # Other is 0.0
+        x = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
+        y = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float, device=device, requires_grad=True)
+        out = torch.copysign(x, y)
+        out.sum().backward()
+        self.assertEqual(x.grad.tolist(), [-1.0, 0.0, 1.0])
+        self.assertEqual(y.grad.tolist(), [0.0] * 3)
+
+        # Other is -0.0
+        x = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
+        y = torch.tensor([-0.0, -0.0, -0.0], dtype=torch.float, device=device, requires_grad=True)
+        out = torch.copysign(x, y)
+        out.sum().backward()
+        self.assertEqual(x.grad.tolist(), [1.0, 0.0, -1.0])
+        self.assertEqual(y.grad.tolist(), [0.0] * 3)
+
     @deviceCountAtLeast(1)
     def test_grad_assignment(self, devices):
         x = torch.randn(5, 5, device=devices[0])
diff --git a/test/test_torch.py b/test/test_torch.py
index d927446140cf..ab61f5c7882f 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -17292,6 +17292,67 @@ def test_reciprocal_complex_extremal(self, device, dtype):
 
         self.compare_with_numpy(torch.reciprocal, np.reciprocal, vals, device, dtype)
 
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False),
+                     torch.testing.get_all_dtypes(include_complex=False)))
+    def test_copysign(self, device, dtypes):
+        def _test_copysign_numpy(a, b):
+            torch_result = torch.copysign(a, b)
+
+            if a.dtype == torch.bfloat16:
+                np_a = a.to(torch.float).cpu().numpy()
+            else:
+                np_a = a.cpu().numpy()
+
+            if b.dtype == torch.bfloat16:
+                np_b = b.to(torch.float).cpu().numpy()
+            else:
+                np_b = b.cpu().numpy()
+            expected = torch.from_numpy(np.copysign(np_a, np_b))
+            # To handle inconsistencies of type promotion between PyTorch and NumPy.
+            # Applied when either argument has integral precision or is bfloat16.
+            types = [torch.bool, torch.bfloat16] + torch.testing.get_all_int_dtypes()
+            if a.dtype in types or b.dtype in types:
+                promoted_type = torch.promote_types(torch_result.dtype, expected.dtype)
+                torch_result = torch_result.to(promoted_type)
+                expected = expected.to(promoted_type)
+
+            # Verify Value
+            self.assertEqual(torch_result, expected)
+            # Verify Sign
+            # Use double copysign to verify the correctness of 0.0 and -0.0, since
+            # self.assertEqual(0.0, -0.0) is always True. So, we use 1 as the
+            # magnitude to verify the sign between torch and numpy results, elementwise.
+            self.assertEqual(torch.copysign(torch.tensor(1.0), torch_result),
+                             torch.copysign(torch.tensor(1.0), expected))
+
+        # Compare Result with NumPy
+        # Type promotion
+        a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
+        b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
+        _test_copysign_numpy(a, b)
+
+        # Broadcast
+        a = make_tensor((10, 1, 10), device=device, dtype=dtypes[0], low=-9, high=9)
+        b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
+        _test_copysign_numpy(a, b)
+
+        a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
+        b = make_tensor((10, 1, 10), device=device, dtype=dtypes[1], low=-9, high=9)
+        _test_copysign_numpy(a, b)
+
+        # 0.0/-0.0/inf/-inf/nan
+        cases = [0.0, -0.0, float('inf'), float('-inf'), float('nan')]
+        if dtypes[0] in torch.testing.get_all_fp_dtypes(include_bfloat16=False):
+            b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
+            for case in cases:
+                _test_copysign_numpy(torch.tensor([case], dtype=dtypes[0]), b)
+
+        if dtypes[1] in torch.testing.get_all_fp_dtypes():
+            a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
+            for case in cases:
+                _test_copysign_numpy(a, torch.tensor([case], dtype=dtypes[1]))
+
     @dtypes(torch.bfloat16, torch.float)
     def test_div(self, device, dtype):
         for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_),
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 8cbcab35685e..b68b45ad1112 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -340,6 +340,13 @@
 - name: _conj(Tensor self) -> Tensor
   self: grad.conj()
 
+- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
+  self: copysign_tensor_self_backward(grad, self, result)
+  other: zeros_like(other)
+
+- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor
+  self: copysign_tensor_self_backward(grad, self, result)
+
 - name: cos(Tensor self) -> Tensor
   self: grad * -self.sin().conj()
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index 5f01f4859c51..e05e6fbe1975 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -8,6 +8,12 @@
 
 #include
 
+// Undefine the copysign macro so that at::copysign works as intended with MSVC
+// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
+#ifdef _MSC_VER
+#undef copysign
+#endif // _MSC_VER
+
 #include "torch/csrc/autograd/python_variable.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
 #include "torch/csrc/Dtype.h"
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 3fd833f3af9f..4b441d6f3616 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -2,6 +2,12 @@
 
 #include
 
+// Undefine the copysign macro so that at::copysign works as intended with MSVC
+// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
+#ifdef _MSC_VER
+#undef copysign
+#endif // _MSC_VER
+
 #include "torch/csrc/DynamicTypes.h"
 #include "torch/csrc/Exceptions.h"
 #include "torch/csrc/Size.h"
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 12dd77497454..182adf8ba725 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -939,6 +939,19 @@ def add_docstr_all(method, docstr):
 See :func:`torch.conj`
 """)
 
+add_docstr_all('copysign',
+               r"""
+copysign(other) -> Tensor
+
+See :func:`torch.copysign`
+""")
+
+add_docstr_all('copysign_', r"""
+copysign_(other) -> Tensor
+
+In-place version of :meth:`~Tensor.copysign`
+""")
+
 add_docstr_all('cos',
                r"""
 cos() -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 4f0de3335414..99074f8182a2 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -1937,6 +1937,53 @@ def merge_dicts(*dicts):
     tensor([-1 - 1j, -2 - 2j, 3 + 3j])
 """.format(**common_args))
 
+add_docstr(torch.copysign,
+           r"""
+copysign(input, other, *, out=None) -> Tensor
+
+Creates a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise.
+
+.. math::
+    \text{out}_{i} = \begin{cases}
+        -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\
+         |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\
+    \end{cases}
+""" + r"""
+
+Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
+and integer and float inputs.
+
+Args:
+    input (Tensor): magnitudes.
+    other (Tensor or Number): contains value(s) whose signbit(s) are
+        applied to the magnitudes in :attr:`input`.
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.randn(5)
+    >>> a
+    tensor([-1.2557, -0.0026, -0.5387,  0.4740, -0.9244])
+    >>> torch.copysign(a, 1)
+    tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244])
+    >>> a = torch.randn(4, 4)
+    >>> a
+    tensor([[ 0.7079,  0.2778, -1.0249,  0.5719],
+            [-0.0059, -0.2600, -0.4475, -1.3948],
+            [ 0.3667, -0.9567, -2.5757, -0.1751],
+            [ 0.2046, -0.0742,  0.2998, -0.1054]])
+    >>> b = torch.randn(4)
+    >>> b
+    tensor([ 0.2373,  0.3120,  0.3190, -1.1128])
+    >>> torch.copysign(a, b)
+    tensor([[ 0.7079,  0.2778,  1.0249, -0.5719],
+            [ 0.0059,  0.2600,  0.4475, -1.3948],
+            [ 0.3667,  0.9567,  2.5757, -0.1751],
+            [ 0.2046,  0.0742,  0.2998, -0.1054]])
+
+""".format(**common_args))
+
 add_docstr(torch.cos,
            r"""
 cos(input, *, out=None) -> Tensor
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index d3752bce04cc..aed5dd52999d 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -49,6 +49,12 @@ void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
   std::copy(t.begin(), t.end(), out.begin() + range.first);
 }
 
+Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result) {
+  auto ratio = result / self;
+  ratio.masked_fill_(self == 0, 0);
+  return grad * ratio;
+}
+
 Tensor not_implemented(const char* name) {
   throw std::runtime_error(
       std::string("the derivative for '") + name + "' is not implemented");
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 46f26610c127..6fd6d6bc418b 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -32,6 +32,7 @@ struct IndexRangeGenerator {
 bool any_variable_defined(variable_list& variables);
 void copy_range(variable_list& out, IndexRange range, const at::Tensor & t);
 void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t);
+at::Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result);
 at::Tensor not_implemented(const char* name);
 at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result);
 at::Tensor maybe_multiply(const at::Tensor & t, const at::Scalar & s);
diff --git a/torch/overrides.py b/torch/overrides.py
index 4fd334c911d8..1cb721825f44 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -283,6 +283,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.clone: lambda input: -1,
         torch.combinations: lambda input, r=2, with_replacement=False: -1,
         torch.complex: lambda real, imag: -1,
+        torch.copysign: lambda input, other, out=None: -1,
        torch.polar: lambda abs, ang: -1,
        torch.conj: lambda input, out=None: -1,
        torch.constant_pad_nd: lambda input, pad, value=0: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 9b67d31e7d16..61e7b420a535 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -706,6 +706,15 @@ def method_tests():
         ('cosh', (S, S, S), NO_ARGS, '', (True,)),
         ('cosh', (), NO_ARGS, 'scalar', (True,)),
         ('conj', (S, S, S), NO_ARGS),
+        ('copysign', (S, S, S), ((S, S, S),), '', (False,)),
+        ('copysign', (S, S, S), ((S, S),), 'broadcast_rhs', (False,)),
+        ('copysign', (S, S), ((S, S, S),), 'broadcast_lhs', (False,)),
+        ('copysign', (S, 1, S), ((M, S),), 'broadcast_all', (False,)),
+        ('copysign', (S, S), (3.14,), 'scalar', (False,)),
+        ('copysign', (S, S), (0.0,), 'scalar_pos_zero', (False,)),
+        # TorchScript does not recognize -0.0: Issue #46848
+        # https://github.com/pytorch/pytorch/issues/46848
+        # ('copysign', (S, S), (-0.0,), 'scalar_neg_zero', (False,)),
         ('real', (S, S, S), NO_ARGS, 'complex'),
         ('imag', (S, S, S), NO_ARGS, 'complex'),
         ('view_as_real', (S, S, S), NO_ARGS, 'complex'),