
Implement copysign #46396

Closed · wants to merge 38 commits

Changes from 28 commits · 38 commits
310bc5e
Implement copysign
ejguan Oct 15, 2020
c4d9155
Update on "[WIP] Implement copysign"
ejguan Oct 15, 2020
f04d5cd
Update on "[WIP] Implement copysign"
ejguan Oct 15, 2020
acde2be
Update on "[WIP] Implement copysign"
ejguan Oct 15, 2020
cc8db7e
Update on "[WIP] Implement copysign"
ejguan Oct 16, 2020
74569e2
Update on "[WIP] Implement copysign"
ejguan Oct 16, 2020
f9ff9f0
Update on "[WIP] Implement copysign"
ejguan Oct 16, 2020
1678724
Update on "[WIP] Implement copysign"
ejguan Oct 19, 2020
4c9af86
Update on "[WIP] Implement copysign"
ejguan Oct 19, 2020
26ec525
Update on "[WIP] Implement copysign"
ejguan Oct 19, 2020
46f7102
Update on "[WIP] Implement copysign"
ejguan Oct 19, 2020
9cbe66b
Update on "[WIP] Implement copysign"
ejguan Oct 19, 2020
493ea3e
Update on "[WIP] Implement copysign"
ejguan Oct 19, 2020
7c22820
Update on "[WIP] Implement copysign"
ejguan Oct 19, 2020
79b56c2
Update on "[WIP] Implement copysign"
ejguan Oct 19, 2020
35367dd
Update on "[WIP] Implement copysign"
ejguan Oct 20, 2020
5ed0a2c
Update on "[WIP] Implement copysign"
ejguan Oct 20, 2020
3bad25c
Update on "[WIP] Implement copysign"
ejguan Oct 20, 2020
5588954
Update on "Implement copysign"
ejguan Oct 20, 2020
0bf84f1
Update on "Implement copysign"
ejguan Oct 20, 2020
cc21795
Update on "Implement copysign"
ejguan Oct 21, 2020
61915d5
Update on "Implement copysign"
ejguan Oct 21, 2020
8462fe9
Update on "Implement copysign"
ejguan Oct 21, 2020
4e092ee
Update on "Implement copysign"
ejguan Oct 21, 2020
46698db
Update on "Implement copysign"
ejguan Oct 21, 2020
75fde66
Update on "Implement copysign"
ejguan Oct 21, 2020
e6ab664
Update on "Implement copysign"
ejguan Oct 21, 2020
155655c
Update on "Implement copysign"
ejguan Oct 22, 2020
b62da43
Update on "Implement copysign"
ejguan Oct 23, 2020
5391b39
Update on "Implement copysign"
ejguan Oct 23, 2020
c35f0ff
Update on "Implement copysign"
ejguan Oct 23, 2020
ac5dc13
Update on "Implement copysign"
ejguan Oct 27, 2020
f4dc913
Update on "Implement copysign"
ejguan Oct 27, 2020
abc20c4
Update on "Implement copysign"
ejguan Oct 28, 2020
78b8b6b
Update on "Implement copysign"
ejguan Oct 29, 2020
01392fb
Update on "Implement copysign"
ejguan Nov 2, 2020
c6e3e8a
Update on "Implement copysign"
ejguan Nov 2, 2020
e3e5eed
Update on "Implement copysign"
ejguan Nov 3, 2020
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -239,6 +239,7 @@ _(aten, combinations) \
_(aten, _conj) \
_(aten, conj) \
_(aten, complex) \
_(aten, copysign) \
_(aten, polar) \
_(aten, constant_pad_nd) \
_(aten, contiguous) \
18 changes: 18 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -48,6 +48,7 @@ DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
@@ -121,6 +122,23 @@ Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) {
return add_relu_impl(self, self, other, alpha);
}

Tensor& copysign_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_float_op(result, self, other);
copysign_stub(iter.device_type(), iter);
return result;
}

Tensor copysign(const Tensor& self, const Tensor& other) {
Tensor result;
auto iter = TensorIterator::binary_float_op(result, self, other);
copysign_stub(iter.device_type(), iter);
return iter.output();
}

Tensor copysign(const Tensor& self, Scalar other) {
return at::copysign(self, wrapped_scalar_tensor(other));
}

Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_float_op(result, self, other);
div_stub(iter.device_type(), iter);
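Because every overload above routes through TensorIterator::binary_float_op, integral inputs produce a floating-point result. A minimal Python-level sketch of the intended behavior (illustrative only; assumes the dispatcher entries registered later in this diff):

import torch

a = torch.tensor([-1, 2, -3])     # int64 magnitudes
b = torch.tensor([1, -1, 1])      # int64 signs

out = torch.copysign(a, b)        # binary_float_op promotes integral inputs to float
print(out)         # tensor([ 1., -2.,  3.])
print(out.dtype)   # torch.float32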
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -70,5 +70,6 @@ DECLARE_DISPATCH(binary_fn, lcm_stub);
DECLARE_DISPATCH(binary_fn, hypot_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);

}} // namespace at::native
9 changes: 9 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -773,6 +773,14 @@ void heaviside_kernel(TensorIterator& iter) {
});
}

void copysign_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "copysign_cpu", [&]() {
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
return std::copysign(a, b);
});
});
}

} // namespace

REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -812,6 +820,7 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);

} // namespace native
} // namespace at
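Assuming the kernel inherits std::copysign's IEEE-754 handling of signed zeros and NaNs (which the dispatch above suggests), the sign bit of the second argument is honored even when it carries -0.0 or when the first argument is NaN. A small illustrative check:

import torch

x = torch.tensor([1.0, 1.0, float('nan')])
s = torch.tensor([-0.0, 0.0, -1.0])

out = torch.copysign(x, s)
print(out[:2])               # tensor([-1.,  1.])  -- -0.0 in `s` counts as negative
print(torch.signbit(out))    # tensor([ True, False,  True])  -- the NaN picks up the sign bit too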
12 changes: 12 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -1,10 +1,13 @@
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>

#include <c10/cuda/CUDAMathCompat.h>


// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
@@ -108,6 +111,14 @@ void heaviside_kernel_cuda(TensorIterator& iter) {
});
}

void copysign_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "copysign_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return c10::cuda::compat::copysign(a, b);
});
});
}

REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -118,5 +129,6 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel_cuda);

}} // namespace at::native
16 changes: 16 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -843,6 +843,22 @@
dispatch:
CPU, CUDA: bitwise_not_out

- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: copysign

- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: copysign_out

- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: copysign

- func: logical_not(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
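Since both the Tensor and Scalar overloads declare variants: function, method, the op should be reachable as torch.copysign(...) and as a Tensor method. A short sketch (illustrative):

import torch

t = torch.tensor([-3.0, 4.0])
print(torch.copysign(t, -1.0))                   # Scalar overload: tensor([-3., -4.])
print(t.copysign(torch.tensor([1.0, -1.0])))     # method form, Tensor overload: tensor([ 3., -4.])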
7 changes: 7 additions & 0 deletions c10/cuda/CUDAMathCompat.h
@@ -42,6 +42,13 @@ __MATH_FUNCTIONS_DECL__ double ceil(double x) {
return ::ceil(x);
}

__MATH_FUNCTIONS_DECL__ float copysign(float x, float y) {
return ::copysignf(x, y);
}
__MATH_FUNCTIONS_DECL__ double copysign(double x, double y) {
return ::copysign(x, y);
}

__MATH_FUNCTIONS_DECL__ float floor(float x) {
return ::floorf(x);
}
1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -247,6 +247,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: contiguous
.. automethod:: copy_
.. automethod:: conj
.. automethod:: copysign
.. automethod:: cos
.. automethod:: cos_
.. automethod:: cosh
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -277,6 +277,7 @@ Pointwise Ops
clamp
clip
conj
copysign
cos
cosh
deg2rad
37 changes: 37 additions & 0 deletions test/test_autograd.py
@@ -6709,6 +6709,43 @@ def test_GRU_grad_and_gradgrad(self, device):
mod = torch.nn.GRU(hsize, hsize, bias=bias).to(device).to(torch.float64)
self._test_rnn_mod(mod, inp)

def test_copysign_grad(self, device):
# Input is 0
x = torch.tensor([0, 0, 0], dtype=torch.float, device=device, requires_grad=True)
y = torch.tensor([-1, 0, 1], dtype=torch.float, device=device, requires_grad=True)
out = torch.copysign(x, y)
out.sum().backward()
self.assertEqual(x.grad.tolist(), [0., 0., 0.])
self.assertEqual(y.grad.tolist(), [0.] * 3)

# Input is -0
x = torch.tensor([0, 0, 0], dtype=torch.float, device=device)
x = - torch.abs(x)
x.requires_grad_()
y = torch.tensor([-1, 0, 1], dtype=torch.float, device=device, requires_grad=True)
out = torch.copysign(x, y)
out.sum().backward()
self.assertEqual(x.grad.tolist(), [0., 0., 0.])
self.assertEqual(y.grad.tolist(), [0.] * 3)

# Other is 0
x = torch.tensor([-1, 0, 1], dtype=torch.float, device=device, requires_grad=True)
y = torch.tensor([0, 0, 0], dtype=torch.float, device=device, requires_grad=True)
out = torch.copysign(x, y)
out.sum().backward()
self.assertEqual(x.grad.tolist(), [-1., 0., 1.])
self.assertEqual(y.grad.tolist(), [0.] * 3)

# Other is -0
x = torch.tensor([-1, 0, 1], dtype=torch.float, device=device, requires_grad=True)
y = torch.tensor([0, 0, 0], dtype=torch.float, device=device)
y = - torch.abs(y)
y.requires_grad_()
out = torch.copysign(x, y)
out.sum().backward()
self.assertEqual(x.grad.tolist(), [1., 0., -1.])
self.assertEqual(y.grad.tolist(), [0.] * 3)

@deviceCountAtLeast(1)
def test_grad_assignment(self, devices):
x = torch.randn(5, 5, device=devices[0])
70 changes: 70 additions & 0 deletions test/test_torch.py
@@ -17224,6 +17224,76 @@ def test_reciprocal_complex_extremal(self, device, dtype):

self.compare_with_numpy(torch.reciprocal, np.reciprocal, vals, device, dtype)

def test_copysign_bool(self, device):
# Promoted to float
# Only 0 and 1
x = torch.tensor([0, 0, 1, 1], device=device, dtype=torch.bool)
y = torch.tensor([0, 1, 0, 1], device=device, dtype=torch.bool)
res = torch.tensor([0, 0, 1, 1], device=device, dtype=torch.float)
self.assertEqual(torch.copysign(x, y), res)

def _test_copysign(self, device, dtype, result_dtype):
# Scalar
x = torch.tensor([-1, 0, -0, 1], dtype=dtype)
y = torch.tensor(-1.)
res = torch.tensor([-1, -0, -0, -1], dtype=result_dtype)
self.assertEqual(torch.copysign(x, y), res)

x = torch.tensor(-1.)
y = torch.tensor([-1, 0, -0, 1], dtype=dtype)
res = torch.tensor([-1, 1, 1, 1], dtype=result_dtype)
self.assertEqual(torch.copysign(x, y), res)

# Constant
x = torch.tensor([-1, 0, -0, 1], dtype=dtype)
res = torch.tensor([-1, -0, -0, -1], dtype=result_dtype)
self.assertEqual(torch.copysign(x, -1.), res)

# Normal
res = torch.tensor([[1, 0, 0, 1] * 2, [-1, -0, -0, -1] * 2], dtype=result_dtype)
x = torch.tensor([-1, -0, 0, 1] * 2, dtype=dtype)
y = torch.tensor([[0] * 4, [1] * 4], dtype=dtype).reshape(-1)
self.assertEqual(torch.copysign(x, y), res[0])

y = torch.copysign(y, -1)
self.assertEqual(torch.copysign(x, y), res[1])

# Broadcast
res = torch.tensor([[[1, 0, 0, 1]] * 2, [[-1, -0, -0, -1]] * 2], dtype=result_dtype)

# LHS
x = torch.tensor([-1, -0, 0, 1] * 2, dtype=dtype).reshape(2, 1, 4)
y = torch.tensor([[0] * 4, [1] * 4], dtype=dtype)
z = torch.copysign(x, y)
for i in range(2):
self.assertEqual(z[i], res[0])

y = torch.copysign(y, -1)
z = torch.copysign(x, y)
for i in range(2):
self.assertEqual(z[i], res[1])

# RHS
x = torch.tensor([-1, 0, -0, 1] * 2, dtype=dtype).reshape(2, 4)
y = torch.tensor([[0] * 4, [1] * 4], dtype=dtype).reshape(2, 1, 4)
z = torch.copysign(x, y)
for i in range(2):
self.assertEqual(z[i], res[0])

y = torch.copysign(y, -1)
z = torch.copysign(x, y)
for i in range(2):
self.assertEqual(z[i], res[1])

@dtypes(torch.half, torch.float, torch.double)
def test_copysign_floating(self, device, dtype):
self._test_copysign(device, dtype, dtype)

@dtypes(torch.int8, torch.short, torch.int, torch.long)
def test_copysign_integral(self, device, dtype):
# Promoted to float
self._test_copysign(device, dtype, torch.float)

@dtypes(torch.bfloat16, torch.float)
def test_div(self, device, dtype):
for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_),
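The bool case exercised by test_copysign_bool relies on the same promotion path: boolean inputs are computed in float and the result stays float. Roughly (illustrative):

import torch

x = torch.tensor([False, True])
y = torch.tensor([True, False])
out = torch.copysign(x, y)
print(out, out.dtype)   # tensor([0., 1.]) torch.float32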
7 changes: 7 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -340,6 +340,13 @@
- name: _conj(Tensor self) -> Tensor
self: grad.conj()

- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
self: copysign_tensor_self_backward(grad, self, result)
other: zeros_like(other)

- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor
self: copysign_tensor_self_backward(grad, self, result)

- name: cos(Tensor self) -> Tensor
self: grad * -self.sin().conj()

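The other: zeros_like(other) entry means the sign source never receives gradient, while copysign_tensor_self_backward effectively multiplies the incoming gradient by the per-element sign flip applied to self. A minimal check, consistent with the autograd test earlier in the diff (illustrative):

import torch

x = torch.tensor([-2.0, 3.0], requires_grad=True)
y = torch.tensor([1.0, -1.0], requires_grad=True)

torch.copysign(x, y).sum().backward()
print(x.grad)   # tensor([-1., -1.])  -- both elements had their sign flipped
print(y.grad)   # tensor([0., 0.])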
6 changes: 6 additions & 0 deletions tools/autograd/templates/python_torch_functions.cpp
@@ -8,6 +8,12 @@

#include <Python.h>

// Undefine the copysign macro so that at::copysign works as intended with MSVC
// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
#ifdef _MSC_VER
#undef copysign
#endif // _MSC_VER

#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/Dtype.h"
6 changes: 6 additions & 0 deletions tools/autograd/templates/python_variable_methods.cpp
@@ -2,6 +2,12 @@

#include <Python.h>

// Undefine the copysign macro so that at::copysign works as intended with MSVC
// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
#ifdef _MSC_VER
#undef copysign
#endif // _MSC_VER

#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/Size.h"
13 changes: 13 additions & 0 deletions torch/_tensor_docs.py
@@ -912,6 +912,19 @@ def add_docstr_all(method, docstr):
See :func:`torch.conj`
""")

add_docstr_all('copysign',
r"""
copysign(other) -> Tensor

Returns a new tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise.

When :attr:`other` is a tensor, the shape of :attr:`other` must be
:ref:`broadcastable <broadcasting-semantics>` with the shape of the underlying
tensor.

See :func:`torch.copysign`
""")

add_docstr_all('cos',
r"""
cos() -> Tensor
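Since the method docstring points at the broadcasting semantics, a small broadcast sketch of the method form (illustrative):

import torch

x = torch.arange(1.0, 5.0).reshape(1, 4)     # shape (1, 4): magnitudes
signs = torch.tensor([[-1.0], [1.0]])        # shape (2, 1): signs
print(x.copysign(signs))
# tensor([[-1., -2., -3., -4.],
#         [ 1.,  2.,  3.,  4.]])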
49 changes: 49 additions & 0 deletions torch/_torch_docs.py
@@ -1870,6 +1870,55 @@ def merge_dicts(*dicts):
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
""".format(**common_args))

add_docstr(torch.copysign,
r"""
copysign(input, other, *, out=None) -> Tensor

Returns a new tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise.

.. math::
\text{out}_{i} = \begin{cases}
-|\text{input}_{i}| & \text{if } \text{other}_{i} < 0 \text{ or } \text{other}_{i} = -0 \\
|\text{input}_{i}| & \text{if } \text{other}_{i} > 0 \text{ or } \text{other}_{i} = +0 \\
\end{cases}
""" + r"""

Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
and integer and float inputs. Integral and boolean inputs are promoted to the
default floating point dtype.

Args:
input (Tensor): magnitudes of the output values.
other (Tensor or Number): values whose signs are applied to the magnitudes of :attr:`input`.
If the shape of :attr:`input` is different from the shape of
:attr:`other`, they must be broadcastable to a common shape
(which becomes the shape of the output).

Keyword args:
{out}

Example::

>>> a = torch.randn(5)
>>> a
tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244])
>>> torch.copysign(a, 1)
tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244])
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.7079, 0.2778, -1.0249, 0.5719],
[-0.0059, -0.2600, -0.4475, -1.3948],
[ 0.3667, -0.9567, -2.5757, -0.1751],
[ 0.2046, -0.0742, 0.2998, -0.1054]])
>>> b = torch.randn(4)
>>> b
tensor([ 0.2373, 0.3120, 0.3190, -1.1128])
>>> torch.copysign(a, b)
tensor([[ 0.7079, 0.2778, 1.0249, -0.5719],
[ 0.0059, 0.2600, 0.4475, -1.3948],
[ 0.3667, 0.9567, 2.5757, -0.1751],
[ 0.2046, 0.0742, 0.2998, -0.1054]])

""".format(**common_args))

add_docstr(torch.cos,
r"""
cos(input, *, out=None) -> Tensor