From 0fb92d976e5c391d4748624c14f5d48f28bf6667 Mon Sep 17 00:00:00 2001 From: ErjiaGuan Date: Fri, 16 Oct 2020 18:56:20 -0400 Subject: [PATCH] Implement copysign (fix docs) ghstack-source-id: edccfde9095eab988ba254c001b784021fa6fe87 Pull Request resolved: https://github.com/pytorch/pytorch/pull/46396 --- aten/src/ATen/core/aten_interned_strings.h | 1 + aten/src/ATen/native/BinaryOps.cpp | 18 ++++++ aten/src/ATen/native/BinaryOps.h | 1 + aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 9 +++ .../ATen/native/cuda/BinaryMiscOpsKernels.cu | 35 +++++++++++ aten/src/ATen/native/native_functions.yaml | 16 +++++ docs/source/torch.rst | 1 + test/test_autograd.py | 9 +++ test/test_torch.py | 60 +++++++++++++++++++ tools/autograd/derivatives.yaml | 7 +++ torch/_torch_docs.py | 49 +++++++++++++++ torch/csrc/autograd/FunctionsManual.cpp | 5 ++ torch/csrc/autograd/FunctionsManual.h | 1 + torch/overrides.py | 1 + 14 files changed, 213 insertions(+) diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index da259e82990a..830647386fe4 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -239,6 +239,7 @@ _(aten, combinations) \ _(aten, _conj) \ _(aten, conj) \ _(aten, complex) \ +_(aten, copysign) \ _(aten, polar) \ _(aten, constant_pad_nd) \ _(aten, contiguous) \ diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index f8af756773c9..f02b436812ad 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -48,6 +48,7 @@ DEFINE_DISPATCH(lcm_stub); DEFINE_DISPATCH(hypot_stub); DEFINE_DISPATCH(nextafter_stub); DEFINE_DISPATCH(heaviside_stub); +DEFINE_DISPATCH(copysign_stub); static Tensor wrapped_scalar_tensor(Scalar scalar) { auto tensor = scalar_to_tensor(scalar); @@ -121,6 +122,23 @@ Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) { return add_relu_impl(self, self, other, alpha); } +Tensor& copysign_out(Tensor& result, const Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_float_op(result, self, other); + copysign_stub(iter.device_type(), iter); + return result; +} + +Tensor copysign(const Tensor& self, const Tensor& other) { + Tensor result; + auto iter = TensorIterator::binary_float_op(result, self, other); + copysign_stub(iter.device_type(), iter); + return iter.output(); +} + +Tensor copysign(const Tensor& self, Scalar other) { + return at::copysign(self, wrapped_scalar_tensor(other)); +} + Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { auto iter = TensorIterator::binary_float_op(result, self, other); div_stub(iter.device_type(), iter); diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index 7640c8bd84ac..0cc1fb719cee 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -70,5 +70,6 @@ DECLARE_DISPATCH(binary_fn, lcm_stub); DECLARE_DISPATCH(binary_fn, hypot_stub); DECLARE_DISPATCH(binary_fn, nextafter_stub); DECLARE_DISPATCH(binary_fn, heaviside_stub); +DECLARE_DISPATCH(binary_fn, copysign_stub); }} // namespace at::native diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index fce8c348919b..59b81d7b947f 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -773,6 +773,14 @@ void heaviside_kernel(TensorIterator& iter) { }); } +void copysign_kernel(TensorIterator& iter) { + 
AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "copysign_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
+      return std::copysign(a, b);
+    });
+  });
+}
+
 } // namespace
 
 REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -812,6 +820,7 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
 REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
 REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
+REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
index fc9aa74f91f4..9c098f7e9eee 100644
--- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
+++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -108,6 +108,40 @@ void heaviside_kernel_cuda(TensorIterator& iter) {
   });
 }
 
+template <typename scalar_t, typename accscalar_t>
+struct CopySignScalarFunctor {
+  CopySignScalarFunctor(accscalar_t b_): b(b_) {}
+  __device__ scalar_t operator() (scalar_t a) const {
+    return ::copysign(a, b);
+  }
+  private:
+    accscalar_t b;
+};
+
+template <typename scalar_t>
+struct CopySignFunctor {
+  __device__ scalar_t operator() (scalar_t a, scalar_t b) const {
+    return ::copysign(a, b);
+  }
+};
+
+void copysign_kernel_cuda(TensorIterator& iter) {
+  if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && iter.is_cpu_scalar(2)) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "copysign_cuda", [&]() {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      auto b = iter.scalar_value<accscalar_t>(2);
+      iter.remove_operand(2);
+      CopySignScalarFunctor<scalar_t, accscalar_t> f(b);
+      gpu_kernel(iter, f);
+    });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "copysign_cuda", [&]() {
+      CopySignFunctor<scalar_t> f;
+      gpu_kernel_with_scalars(iter, f);
+    });
+  }
+}
+
 REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
 REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
 REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -118,5 +152,6 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
 REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
 REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
 REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
+REGISTER_DISPATCH(copysign_stub, &copysign_kernel_cuda);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3de4079a7d20..8756c2b0baab 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -843,6 +843,22 @@
   dispatch:
     CPU, CUDA: bitwise_not_out
 
+- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
+  use_c10_dispatcher: full
+  variants: function
+  dispatch:
+    CPU, CUDA: copysign
+
+- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+ dispatch: + CPU, CUDA: copysign_out + +- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: full + variants: function + dispatch: + DefaultBackend: copysign + - func: logical_not(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method diff --git a/docs/source/torch.rst b/docs/source/torch.rst index b3c8410300c6..360a316b8b95 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -277,6 +277,7 @@ Pointwise Ops clamp clip conj + copysign cos cosh deg2rad diff --git a/test/test_autograd.py b/test/test_autograd.py index 7235abbb22b5..0861f05767fa 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -6749,6 +6749,15 @@ def test_copy_(self, device): z = x.to(torch.bfloat16) self.assertTrue(z.requires_grad) + def test_copysign(self, device): + x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.float, device=device, requires_grad=True) + y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.float, device=device).reshape(-1) + y.requires_grad_() + out = torch.copysign(x, y) + out.sum().backward() + self.assertEqual(x.grad.tolist(), [1., 0., 0., -1.] + [-1., 0., 0., 1.] * 3) + self.assertEqual(y.grad.tolist(), [0.] * 16) + @onlyCUDA def test_simple_reentrant_cross_device(self, device): class ReentrantFunc(Function): diff --git a/test/test_torch.py b/test/test_torch.py index 2e7ae4edbbae..660beace88e4 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -660,6 +660,66 @@ def test_copy_transpose(self): self.assertEqual(y[:, 0], range(100)) self.assertEqual(y[:, 40], range(4000, 4100)) + def test_copysign(self): + # Float or Integral promoted to Float + res = torch.tensor([-1, -0, -0, -1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1], dtype=torch.float) + types = [torch.short, torch.int, torch.long, torch.float] + for t in types: + x = torch.tensor([-1, 0, -0, 1] * 4, dtype=t) + y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=t).reshape(-1) + z = torch.copysign(x, y) + self.assertEqual(z, res) + + # Half + x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.half) + y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.half).reshape(-1) + res = torch.tensor([-1, -0, -0, -1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1], dtype=torch.half) + self.assertEqual(torch.copysign(x, y), res) + + # Bool + x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.bool) + y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.bool).reshape(-1) + res = torch.tensor([1, 0, 0, 1] * 4, dtype=torch.float).reshape(-1) + self.assertEqual(torch.copysign(x, y), res) + + # Broadcast + res = torch.tensor([[-1, -0, -0, -1], [1, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 1]], dtype=torch.float) + + # LHS + x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.float).reshape(4, 1, 4) + y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.float) + z = torch.copysign(x, y) + for i in range(4): + self.assertEqual(z[i], res) + + x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.float).reshape(1, 4, 4) + z = torch.copysign(x, y) + self.assertEqual(z, res.reshape(1, 4, 4)) + + # RHS + x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.float).reshape(4, 4) + y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.float).reshape(4, 1, 4) + z = torch.copysign(x, y) + for i in range(4): + for j in range(4): + self.assertEqual(z[i][j], res[i]) + + # Scalar + x = torch.tensor([-1, 0, -0, 1], dtype=torch.float) + y = torch.tensor(-1.) 
+        res = torch.tensor([-1, -0, -0, -1], dtype=torch.float)
+        self.assertEqual(torch.copysign(x, y), res)
+
+        x = torch.tensor(-1.)
+        y = torch.tensor([-1, 0, -0, 1], dtype=torch.float)
+        res = torch.tensor([-1, 1, 1, 1], dtype=torch.float)
+        self.assertEqual(torch.copysign(x, y), res)
+
+        # Constant
+        x = torch.tensor([-1, 0, -0, 1], dtype=torch.float)
+        res = torch.tensor([-1, -0, -0, -1], dtype=torch.float)
+        self.assertEqual(torch.copysign(x, -1.), res)
+
     def test_device(self):
         cpu = torch.device('cpu')
         self.assertEqual('cpu', str(cpu))
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index eddc30a8604b..2e566c39f153 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -340,6 +340,13 @@
 - name: _conj(Tensor self) -> Tensor
   self: grad.conj()
 
+- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
+  self: copysign_tensor_self_backward(grad, self, other)
+  other: zeros_like(other)
+
+- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor
+  self: copysign_tensor_self_backward(grad, self, at::scalar_to_tensor(other))
+
 - name: cos(Tensor self) -> Tensor
   self: grad * -self.sin().conj()
 
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 54c5915593d8..9f5bd4184757 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -1870,6 +1870,56 @@ def merge_dicts(*dicts):
     tensor([-1 - 1j, -2 - 2j, 3 + 3j])
 """.format(**common_args))
 
+add_docstr(torch.copysign,
+           r"""
+copysign(input, other, *, out=None) -> Tensor
+
+Returns a new tensor with the magnitudes of :attr:`input` and the signs of :attr:`other`, elementwise.
+
+.. math::
+    \text{out}_{i} = \begin{cases}
+        -|\text{input}_{i}| & \text{if } \text{other}_{i} < 0 \\
+         |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0 \\
+    \end{cases}
+""" + r"""
+
+Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`, and integer and float inputs.
+Integral and boolean inputs are promoted to the default floating point dtype.
+
+Args:
+    input (Tensor): the tensor that provides the magnitudes of the output.
+    other (Tensor or Number): the value(s) that provide the sign of the output.
+        If the shape of :attr:`input` is different from the shape of
+        :attr:`other`, they must be broadcastable to a common shape
+        (which becomes the shape of the output).
+
+Keyword args:
+    {out}
+
+Example::
+
+    >>> a = torch.randn(5)
+    >>> a
+    tensor([-1.2557, -0.0026, -0.5387,  0.4740, -0.9244])
+    >>> torch.copysign(a, 1)
+    tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244])
+    >>> a = torch.randn(4, 4)
+    >>> a
+    tensor([[ 0.7079,  0.2778, -1.0249,  0.5719],
+            [-0.0059, -0.2600, -0.4475, -1.3948],
+            [ 0.3667, -0.9567, -2.5757, -0.1751],
+            [ 0.2046, -0.0742,  0.2998, -0.1054]])
+    >>> b = torch.randn(4)
+    >>> b
+    tensor([ 0.2373,  0.3120,  0.3190, -1.1128])
+    >>> torch.copysign(a, b)
+    tensor([[ 0.7079,  0.2778,  1.0249, -0.5719],
+            [ 0.0059,  0.2600,  0.4475, -1.3948],
+            [ 0.3667,  0.9567,  2.5757, -0.1751],
+            [ 0.2046,  0.0742,  0.2998, -0.1054]])
+
+""".format(**common_args))
+
 add_docstr(torch.cos, r"""
 cos(input, *, out=None) -> Tensor
 
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index f807e2b1dff0..58ff71573e09 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -49,6 +49,11 @@ void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Variable> t) {
   std::copy(t.begin(), t.end(), out.begin() + range.first);
 }
 
+Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & other) {
+  auto result = grad * self.sign() * (other.ge(0) * 2 - 1);
+  return result;
+}
+
 Tensor not_implemented(const char* name) {
   throw std::runtime_error(
       std::string("the derivative for '") + name + "' is not implemented");
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 00171cbbf656..b6c275f5778d 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -32,6 +32,7 @@ struct IndexRangeGenerator {
 bool any_variable_defined(variable_list& variables);
 void copy_range(variable_list& out, IndexRange range, const at::Tensor & t);
 void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Variable> t);
+at::Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & other);
 at::Tensor not_implemented(const char* name);
 at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result);
 at::Tensor maybe_multiply(const at::Tensor & t, const at::Scalar & s);
diff --git a/torch/overrides.py b/torch/overrides.py
index 53a8b39c8e31..2a125bdb48ea 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -285,6 +285,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.clone: lambda input: -1,
     torch.combinations: lambda input, r=2, with_replacement=False: -1,
     torch.complex: lambda real, imag: -1,
+    torch.copysign: lambda input, other, out=None: -1,
    torch.polar: lambda abs, ang: -1,
     torch.conj: lambda input, out=None: -1,
     torch.constant_pad_nd: lambda input, pad, value=0: -1,
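
A minimal usage sketch of the operator this patch adds (illustrative only, not part of the patch): it assumes the patch is applied to a PyTorch build, and the expected values follow the tests above and the backward rule in copysign_tensor_self_backward / derivatives.yaml.

    # Illustrative sketch, assuming torch.copysign from this patch is available.
    import torch

    x = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
    y = torch.tensor([1.0, -1.0, -0.5])

    out = torch.copysign(x, y)    # magnitudes from x, signs from y
    print(out)                    # expected values: [ 1., -0., -2.]

    out.sum().backward()
    # Backward per copysign_tensor_self_backward: grad * sign(x) * (+1 if y >= 0 else -1),
    # which is 0 where x == 0; y receives a zero gradient (other: zeros_like(other)).
    print(x.grad)                 # expected values: [-1.,  0., -1.]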