Implement copysign (Fix Test)
ghstack-source-id: c1a5bac31d691283301d70fa185e689f923b090a
Pull Request resolved: #46396
ejguan committed Oct 29, 2020
1 parent ecdbea7 commit fb57aba
Showing 20 changed files with 275 additions and 0 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -239,6 +239,7 @@ _(aten, combinations) \
_(aten, _conj) \
_(aten, conj) \
_(aten, complex) \
_(aten, copysign) \
_(aten, polar) \
_(aten, constant_pad_nd) \
_(aten, contiguous) \
26 changes: 26 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -48,6 +48,7 @@ DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
@@ -121,6 +122,31 @@ Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) {
return add_relu_impl(self, self, other, alpha);
}

Tensor& copysign_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_float_op(result, self, other);
  copysign_stub(iter.device_type(), iter);
  return result;
}

Tensor copysign(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_float_op(result, self, other);
  copysign_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor& copysign_(Tensor& self, const Tensor& other) {
  return native::copysign_out(self, self, other);
}

Tensor copysign(const Tensor& self, Scalar other) {
  return native::copysign(self, wrapped_scalar_tensor(other));
}

Tensor& copysign_(Tensor& self, Scalar other) {
  return native::copysign_(self, wrapped_scalar_tensor(other));
}

Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_float_op(result, self, other);
  div_stub(iter.device_type(), iter);
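
All of these wrappers build a `TensorIterator::binary_float_op`, the same promotion path used by true division, so integer inputs produce a floating-point result. A minimal Python sketch of the behavior these entry points are expected to expose (illustrative only, not part of the diff):

```python
import torch

# Functional, out= and in-place entry points all dispatch to the same copysign_stub.
a = torch.tensor([1, -2, 3])           # int64 magnitudes
b = torch.tensor([-1.0, 1.0, -1.0])

out = torch.copysign(a, b)             # binary_float_op promotes to the default float dtype
print(out, out.dtype)                  # tensor([-1.,  2., -3.]) torch.float32

buf = torch.empty(3)
torch.copysign(a, b, out=buf)          # copysign.out overload writes into buf

c = torch.tensor([1.0, -2.0, 3.0])
c.copysign_(-1)                        # Scalar overload wraps -1 in a 0-dim tensor
print(c)                               # tensor([-1., -2., -3.])
```
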
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -70,5 +70,6 @@ DECLARE_DISPATCH(binary_fn, lcm_stub);
DECLARE_DISPATCH(binary_fn, hypot_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);

}} // namespace at::native
9 changes: 9 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -787,6 +787,14 @@ void heaviside_kernel(TensorIterator& iter) {
});
}

void copysign_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
      return std::copysign(a, b);
    });
  });
}

} // namespace

REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -826,6 +834,7 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);

} // namespace native
} // namespace at
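
The CPU kernel defers to `std::copysign`, which reads the sign bit of `other` rather than comparing it with zero, so a negative zero still flips the sign — matching the `other <= -0.0` branch of the docstring added below. A quick illustrative check (assumes the op is built and registered as above):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])

# -0.0 compares equal to 0.0, but its sign bit is set, so the result is negated
print(torch.copysign(x, torch.tensor(-0.0)))   # tensor([-1., -2., -3.])
print(torch.copysign(x, torch.tensor(0.0)))    # tensor([1., 2., 3.])

# A NaN magnitude stays NaN; only its sign bit is replaced
print(torch.copysign(torch.tensor([float('nan')]), torch.tensor([-1.0])))  # tensor([nan])
```
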
18 changes: 18 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -5,6 +5,15 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>

#if defined(__CUDACC__)
#include <cuda.h>
#include <cuda_fp16.h>
#include <c10/cuda/CUDAMathCompat.h>
#elif defined(__HIPCC__)
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <c10/hip/HIPMathCompat.h>
#endif

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
@@ -108,6 +117,14 @@ void heaviside_kernel_cuda(TensorIterator& iter) {
});
}

void copysign_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cuda", [&]() {
    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
      return c10::cuda::compat::copysign(a, b);
    });
  });
}

REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -118,5 +135,6 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel_cuda);

}} // namespace at::native
28 changes: 28 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -843,6 +843,34 @@
  dispatch:
    CPU, CUDA: bitwise_not_out

- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
  use_c10_dispatcher: full
  variants: function, method
  dispatch:
    CPU, CUDA: copysign

- func: copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  use_c10_dispatcher: full
  variants: method
  dispatch:
    CPU, CUDA: copysign_

- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: copysign_out

- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor
  use_c10_dispatcher: full
  variants: function, method
  dispatch:
    CPU, CUDA: copysign

- func: copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  use_c10_dispatcher: full
  variants: method
  dispatch:
    CPU, CUDA: copysign_

- func: logical_not(Tensor self) -> Tensor
  use_c10_dispatcher: full
  variants: function, method
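
These schema entries register `copysign` as both a free function and a `Tensor` method (plus in-place, `out=`, and Scalar overloads), with broadcasting handled by `TensorIterator`. A small sketch of the function/method symmetry (illustrative only):

```python
import torch

a = torch.randn(3, 1)
b = torch.randn(1, 4)

# Function and method variants broadcast to a common (3, 4) shape and agree elementwise.
assert torch.equal(torch.copysign(a, b), a.copysign(b))
print(torch.copysign(a, b).shape)   # torch.Size([3, 4])
```
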
7 changes: 7 additions & 0 deletions c10/cuda/CUDAMathCompat.h
@@ -42,6 +42,13 @@ __MATH_FUNCTIONS_DECL__ double ceil(double x) {
return ::ceil(x);
}

__MATH_FUNCTIONS_DECL__ float copysign(float x, float y) {
  return ::copysignf(x, y);
}
__MATH_FUNCTIONS_DECL__ double copysign(double x, double y) {
  return ::copysign(x, y);
}

__MATH_FUNCTIONS_DECL__ float floor(float x) {
return ::floorf(x);
}
2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -249,6 +249,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: contiguous
.. automethod:: copy_
.. automethod:: conj
.. automethod:: copysign
.. automethod:: copysign_
.. automethod:: cos
.. automethod:: cos_
.. automethod:: cosh
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -277,6 +277,7 @@ Pointwise Ops
clamp
clip
conj
copysign
cos
cosh
deg2rad
33 changes: 33 additions & 0 deletions test/test_autograd.py
@@ -6705,6 +6705,39 @@ def test_GRU_grad_and_gradgrad(self, device):
    mod = torch.nn.GRU(hsize, hsize, bias=bias).to(device).to(torch.float64)
    self._test_rnn_mod(mod, inp)

def test_copysign_subgradient(self, device):
    # Input is 0.0
    x = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float, device=device, requires_grad=True)
    y = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
    out = torch.copysign(x, y)
    out.sum().backward()
    self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
    self.assertEqual(y.grad.tolist(), [0.0] * 3)

    # Input is -0.0
    x = torch.tensor([-0.0, -0.0, -0.0], dtype=torch.float, device=device, requires_grad=True)
    y = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
    out = torch.copysign(x, y)
    out.sum().backward()
    self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
    self.assertEqual(y.grad.tolist(), [0.0] * 3)

    # Other is 0.0
    x = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
    y = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float, device=device, requires_grad=True)
    out = torch.copysign(x, y)
    out.sum().backward()
    self.assertEqual(x.grad.tolist(), [-1.0, 0.0, 1.0])
    self.assertEqual(y.grad.tolist(), [0.0] * 3)

    # Other is -0.0
    x = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
    y = torch.tensor([-0.0, -0.0, -0.0], dtype=torch.float, device=device, requires_grad=True)
    out = torch.copysign(x, y)
    out.sum().backward()
    self.assertEqual(x.grad.tolist(), [1.0, 0.0, -1.0])
    self.assertEqual(y.grad.tolist(), [0.0] * 3)

@deviceCountAtLeast(1)
def test_grad_assignment(self, devices):
    x = torch.randn(5, 5, device=devices[0])
54 changes: 54 additions & 0 deletions test/test_torch.py
@@ -17272,6 +17272,60 @@ def test_reciprocal_complex_extremal(self, device, dtype):

    self.compare_with_numpy(torch.reciprocal, np.reciprocal, vals, device, dtype)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@dtypes(torch.bool, *(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes()))
def test_copysign(self, device, dtype):
    def _test_copysign_numpy(a, b):
        torch_result = torch.copysign(a, b)

        if a.dtype == torch.bfloat16:
            np_a = a.to(torch.float).cpu().numpy()
        else:
            np_a = a.cpu().numpy()

        if b.dtype == torch.bfloat16:
            np_b = b.to(torch.float).cpu().numpy()
        else:
            np_b = b.cpu().numpy()
        expected = torch.from_numpy(np.copysign(np_a, np_b))

        # To handle inconsistencies of type promotion between PyTorch and NumPy,
        # applied when either argument has integral precision or is bfloat16
        types = [torch.bool, torch.bfloat16] + torch.testing.get_all_int_dtypes()
        if a.dtype in types or b.dtype in types:
            promoted_type = torch.promote_types(torch_result.dtype, expected.dtype)
            torch_result = torch_result.to(promoted_type)
            expected = expected.to(promoted_type)

        # Use a second copysign to verify the correctness of 0.0 and -0.0,
        # since self.assertEqual treats 0.0 and -0.0 as equal
        self.assertEqual(torch.copysign(torch.tensor(1.0), torch_result),
                         torch.copysign(torch.tensor(1.0), expected))

    # Compare with NumPy
    # Type promotion
    type_list = [torch.bool] + torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes()
    for dtype2 in type_list:
        a = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
        b = make_tensor((10, 10), device=device, dtype=dtype2, low=-9, high=9)
        _test_copysign_numpy(a, b)

        # Broadcast
        a = make_tensor((10, 1, 10), device=device, dtype=dtype, low=-9, high=9)
        b = make_tensor((10, 10), device=device, dtype=dtype2, low=-9, high=9)
        _test_copysign_numpy(a, b)

        a = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
        b = make_tensor((10, 1, 10), device=device, dtype=dtype2, low=-9, high=9)
        _test_copysign_numpy(a, b)

    # 0.0/-0.0/inf/-inf/nan
    a = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9)
    b_s = [0.0, -0.0, float('inf'), float('-inf'), float('nan')]
    for dtype2 in torch.testing.get_all_fp_dtypes():
        for b in b_s:
            _test_copysign_numpy(a, torch.tensor(b, dtype=dtype2))

@dtypes(torch.bfloat16, torch.float)
def test_div(self, device, dtype):
    for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_),
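
The helper above compares results through a second `copysign` because a plain equality check cannot tell signed zeros apart; the trick in isolation looks like this (plain Python, for illustration):

```python
import math

print(0.0 == -0.0)               # True: equality hides the sign bit
print(math.copysign(1.0, 0.0))   # 1.0
print(math.copysign(1.0, -0.0))  # -1.0: copysign exposes the sign bit, so the
                                 # comparison in _test_copysign_numpy catches -0.0 vs 0.0
```
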
7 changes: 7 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -340,6 +340,13 @@
- name: _conj(Tensor self) -> Tensor
  self: grad.conj()

- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
  self: copysign_tensor_self_backward(grad, self, result)
  other: zeros_like(other)

- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor
  self: copysign_tensor_self_backward(grad, self, result)

- name: cos(Tensor self) -> Tensor
  self: grad * -self.sin().conj()

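
With `other` assigned a zero gradient, the backward for `self` effectively multiplies the incoming gradient by sign(other) * sign(self), and is zero at self == 0 as the autograd test above checks. A hedged gradcheck sketch, keeping inputs away from zero so the finite-difference comparison is well posed:

```python
import torch

x = torch.tensor([-2.0, -1.0, 1.5, 3.0], dtype=torch.double, requires_grad=True)
y = torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=torch.double)

# copysign(x, y) = |x| * sign(y) is piecewise linear in x, so the analytic
# gradient (+/-1 per element) should match finite differences here.
torch.autograd.gradcheck(lambda t: torch.copysign(t, y), (x,))

out = torch.copysign(x, y)
out.sum().backward()
print(x.grad)   # tensor([-1.,  1., -1.,  1.], dtype=torch.float64)
```
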
6 changes: 6 additions & 0 deletions tools/autograd/templates/python_torch_functions.cpp
@@ -8,6 +8,12 @@

#include <Python.h>

// Undefine the copysign macro so that at::copysign works as intended with MSVC
// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
#ifdef _MSC_VER
#undef copysign
#endif // _MSC_VER

#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/Dtype.h"
6 changes: 6 additions & 0 deletions tools/autograd/templates/python_variable_methods.cpp
@@ -2,6 +2,12 @@

#include <Python.h>

// Undefine the copysign macro so that at::copysign works as intended with MSVC
// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
#ifdef _MSC_VER
#undef copysign
#endif // _MSC_VER

#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/Size.h"
13 changes: 13 additions & 0 deletions torch/_tensor_docs.py
@@ -939,6 +939,19 @@ def add_docstr_all(method, docstr):
See :func:`torch.conj`
""")

add_docstr_all('copysign',
               r"""
copysign(other) -> Tensor

See :func:`torch.copysign`
""")

add_docstr_all('copysign_',
               r"""
copysign_(other) -> Tensor

In-place version of :meth:`~Tensor.copysign`
""")

add_docstr_all('cos',
r"""
cos() -> Tensor
47 changes: 47 additions & 0 deletions torch/_torch_docs.py
@@ -1903,6 +1903,53 @@ def merge_dicts(*dicts):
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
""".format(**common_args))

add_docstr(torch.copysign,
           r"""
copysign(input, other, *, out=None) -> Tensor

Create a new float tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise.

.. math::
    \text{out}_{i} = \begin{cases}
        -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\
         |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\
    \end{cases}
""" + r"""

Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
and integer and float inputs.

Args:
    input (Tensor): magnitudes.
    other (Tensor or Number): contains value(s) whose signbit(s) are
        applied to the magnitudes in :attr:`input`.

Keyword args:
    {out}

Example::

    >>> a = torch.randn(5)
    >>> a
    tensor([-1.2557, -0.0026, -0.5387,  0.4740, -0.9244])
    >>> torch.copysign(a, 1)
    tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244])
    >>> a = torch.randn(4, 4)
    >>> a
    tensor([[ 0.7079,  0.2778, -1.0249,  0.5719],
            [-0.0059, -0.2600, -0.4475, -1.3948],
            [ 0.3667, -0.9567, -2.5757, -0.1751],
            [ 0.2046, -0.0742,  0.2998, -0.1054]])
    >>> b = torch.randn(4)
    >>> b
    tensor([ 0.2373,  0.3120,  0.3190, -1.1128])
    >>> torch.copysign(a, b)
    tensor([[ 0.7079,  0.2778,  1.0249, -0.5719],
            [ 0.0059,  0.2600,  0.4475, -1.3948],
            [ 0.3667,  0.9567,  2.5757, -0.1751],
            [ 0.2046,  0.0742,  0.2998, -0.1054]])
""".format(**common_args))

add_docstr(torch.cos,
r"""
cos(input, *, out=None) -> Tensor
