Implement copysign (Minor Fix)
ghstack-source-id: caeb4463c61473183ac336f7d6eecb4fd05483c0
Pull Request resolved: #46396
ejguan committed Nov 2, 2020
1 parent 1cc1da5 commit 63d382a
Showing 20 changed files with 283 additions and 0 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -239,6 +239,7 @@ _(aten, combinations) \
_(aten, _conj) \
_(aten, conj) \
_(aten, complex) \
_(aten, copysign) \
_(aten, polar) \
_(aten, constant_pad_nd) \
_(aten, contiguous) \
26 changes: 26 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -49,6 +49,7 @@ DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(igamma_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
@@ -122,6 +123,31 @@ Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) {
return add_relu_impl(self, self, other, alpha);
}

Tensor& copysign_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_float_op(result, self, other);
copysign_stub(iter.device_type(), iter);
return result;
}

Tensor copysign(const Tensor& self, const Tensor& other) {
Tensor result;
auto iter = TensorIterator::binary_float_op(result, self, other);
copysign_stub(iter.device_type(), iter);
return iter.output();
}

Tensor& copysign_(Tensor& self, const Tensor& other) {
return native::copysign_out(self, self, other);
}

Tensor copysign(const Tensor& self, Scalar other) {
return native::copysign(self, wrapped_scalar_tensor(other));
}

Tensor& copysign_(Tensor& self, Scalar other) {
return native::copysign_(self, wrapped_scalar_tensor(other));
}

Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_float_op(result, self, other);
div_stub(iter.device_type(), iter);
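The new native functions above route both the tensor-tensor and the tensor-scalar calls through `TensorIterator::binary_float_op` and the dispatched `copysign_stub`. A minimal sketch of the resulting Python semantics (the values in the comments follow IEEE `copysign`, they are not output copied from a run):

```python
import torch

x = torch.tensor([-1.5, 0.0, 2.0])
y = torch.tensor([1.0, -1.0, -0.0])

# Each element keeps the magnitude of x and takes the sign of y,
# so the expected values are [1.5, -0.0, -2.0].
z = torch.copysign(x, y)

# The Scalar overload goes through wrapped_scalar_tensor(), so this matches
# copysign against a tensor filled with -1.0: expected [-1.5, -0.0, -2.0].
w = torch.copysign(x, -1.0)
print(z, w)
```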
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -71,5 +71,6 @@ DECLARE_DISPATCH(binary_fn, hypot_stub);
DECLARE_DISPATCH(binary_fn, igamma_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);

}} // namespace at::native
9 changes: 9 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -800,6 +800,14 @@ void heaviside_kernel(TensorIterator& iter) {
});
}

void copysign_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
return std::copysign(a, b);
});
});
}

} // namespace

REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -840,6 +848,7 @@ REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
REGISTER_DISPATCH(igamma_stub, &igamma_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);

} // namespace native
} // namespace at
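Since the entry points use `TensorIterator::binary_float_op` while the CPU kernel dispatches over the floating types plus `BFloat16` and `Half`, integer and bool inputs are promoted to floating point before the kernel runs. A hedged sketch of the expected dtype behavior, assuming the default dtype is `torch.float32`:

```python
import torch

# Integer inputs go through the binary_float_op promotion path,
# so the result is floating point rather than integral.
a = torch.tensor([1, -2, 3], dtype=torch.int64)
b = torch.tensor([-1, 1, -1], dtype=torch.int32)
print(torch.copysign(a, b).dtype)   # expected: torch.float32

# Reduced-precision floating types are dispatched directly by the kernel.
h = torch.tensor([1.0, -2.0], dtype=torch.bfloat16)
print(torch.copysign(h, h.neg()).dtype)   # expected: torch.bfloat16
```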
18 changes: 18 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -5,6 +5,15 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>

#if defined(__CUDACC__)
#include <cuda.h>
#include <cuda_fp16.h>
#include <c10/cuda/CUDAMathCompat.h>
#elif defined(__HIPCC__)
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <c10/hip/HIPMathCompat.h>
#endif

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
@@ -116,6 +125,14 @@ void heaviside_kernel_cuda(TensorIterator& iter) {
});
}

void copysign_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return c10::cuda::compat::copysign(a, b);
});
});
}

REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -127,5 +144,6 @@ REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
REGISTER_DISPATCH(igamma_stub, &igamma_kernel_cuda);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel_cuda);

}} // namespace at::native
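The CUDA path mirrors the CPU kernel but goes through `gpu_kernel_with_scalars` and `c10::cuda::compat::copysign`, again covering `Half` and `BFloat16`. A hedged sanity-check sketch (assumes a CUDA device is available) that compares the CUDA result against the CPU result on signed zeros and non-finite values:

```python
import torch

if torch.cuda.is_available():
    special = torch.tensor([0.0, -0.0, float("inf"), float("-inf"), float("nan")])
    signs = torch.tensor([-1.0, 1.0, -1.0, 1.0, -1.0])

    cpu_out = torch.copysign(special, signs)
    cuda_out = torch.copysign(special.cuda(), signs.cuda()).cpu()

    # Compare via copysign(1, .) so that 0.0 and -0.0 are distinguished,
    # mirroring the trick used in the test_copysign test below.
    ones = torch.ones_like(cpu_out)
    assert torch.equal(torch.copysign(ones, cpu_out),
                       torch.copysign(ones, cuda_out))
```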
28 changes: 28 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -843,6 +843,34 @@
dispatch:
CPU, CUDA: bitwise_not_out

- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: copysign

- func: copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
use_c10_dispatcher: full
variants: method
dispatch:
CPU, CUDA: copysign_

- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: copysign_out

- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: copysign

- func: copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
use_c10_dispatcher: full
variants: method
dispatch:
CPU, CUDA: copysign_

- func: logical_not(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
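The five schema entries register `copysign` as both a free function and a `Tensor` method, with in-place, `out=`, and Scalar-argument variants dispatched to CPU and CUDA. A brief sketch of the call surface this is intended to expose (illustrative only):

```python
import torch

t = torch.tensor([3.0, -4.0])
s = torch.tensor([-1.0, 1.0])

torch.copysign(t, s)                           # copysign.Tensor (function variant)
t.copysign(s)                                  # method variant
t.copysign(-2)                                 # copysign.Scalar
torch.copysign(t, s, out=torch.empty_like(t))  # copysign.out
t.copysign_(s)                                 # copysign_.Tensor (in-place)
t.copysign_(-2)                                # copysign_.Scalar (in-place)
```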
7 changes: 7 additions & 0 deletions c10/cuda/CUDAMathCompat.h
@@ -42,6 +42,13 @@ __MATH_FUNCTIONS_DECL__ double ceil(double x) {
return ::ceil(x);
}

__MATH_FUNCTIONS_DECL__ float copysign(float x, float y) {
return ::copysignf(x, y);
}
__MATH_FUNCTIONS_DECL__ double copysign(double x, double y) {
return ::copysign(x, y);
}

__MATH_FUNCTIONS_DECL__ float floor(float x) {
return ::floorf(x);
}
2 changes: 2 additions & 0 deletions docs/source/tensors.rst
@@ -249,6 +249,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: contiguous
.. automethod:: copy_
.. automethod:: conj
.. automethod:: copysign
.. automethod:: copysign_
.. automethod:: cos
.. automethod:: cos_
.. automethod:: cosh
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -279,6 +279,7 @@ Pointwise Ops
clamp
clip
conj
copysign
cos
cosh
deg2rad
33 changes: 33 additions & 0 deletions test/test_autograd.py
@@ -6713,6 +6713,39 @@ def test_GRU_grad_and_gradgrad(self, device):
mod = torch.nn.GRU(hsize, hsize, bias=bias).to(device).to(torch.float64)
self._test_rnn_mod(mod, inp)

def test_copysign_subgradient(self, device):
# Input is 0.0
x = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float, device=device, requires_grad=True)
y = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
out = torch.copysign(x, y)
out.sum().backward()
self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
self.assertEqual(y.grad.tolist(), [0.0] * 3)

# Input is -0.0
x = torch.tensor([-0.0, -0.0, -0.0], dtype=torch.float, device=device, requires_grad=True)
y = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
out = torch.copysign(x, y)
out.sum().backward()
self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0])
self.assertEqual(y.grad.tolist(), [0.0] * 3)

# Other is 0.0
x = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
y = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float, device=device, requires_grad=True)
out = torch.copysign(x, y)
out.sum().backward()
self.assertEqual(x.grad.tolist(), [-1.0, 0.0, 1.0])
self.assertEqual(y.grad.tolist(), [0.0] * 3)

# Other is -0.0
x = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float, device=device, requires_grad=True)
y = torch.tensor([-0.0, -0.0, -0.0], dtype=torch.float, device=device, requires_grad=True)
out = torch.copysign(x, y)
out.sum().backward()
self.assertEqual(x.grad.tolist(), [1.0, 0.0, -1.0])
self.assertEqual(y.grad.tolist(), [0.0] * 3)

@deviceCountAtLeast(1)
def test_grad_assignment(self, devices):
x = torch.randn(5, 5, device=devices[0])
61 changes: 61 additions & 0 deletions test/test_torch.py
@@ -17292,6 +17292,67 @@ def test_reciprocal_complex_extremal(self, device, dtype):

self.compare_with_numpy(torch.reciprocal, np.reciprocal, vals, device, dtype)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@dtypes(*product(torch.testing.get_all_dtypes(include_complex=False),
torch.testing.get_all_dtypes(include_complex=False)))
def test_copysign(self, device, dtypes):
def _test_copysign_numpy(a, b):
torch_result = torch.copysign(a, b)

if a.dtype == torch.bfloat16:
np_a = a.to(torch.float).cpu().numpy()
else:
np_a = a.cpu().numpy()

if b.dtype == torch.bfloat16:
np_b = b.to(torch.float).cpu().numpy()
else:
np_b = b.cpu().numpy()
expected = torch.from_numpy(np.copysign(np_a, np_b))
# Handle type-promotion differences between PyTorch and NumPy
# (applies when either argument is bool, an integral dtype, or bfloat16)
types = [torch.bool, torch.bfloat16] + torch.testing.get_all_int_dtypes()
if a.dtype in types or b.dtype in types:
promoted_type = torch.promote_types(torch_result.dtype, expected.dtype)
torch_result = torch_result.to(promoted_type)
expected = expected.to(promoted_type)

# Verify Value
self.assertEqual(torch_result, expected)
# Verify Sign
# Use a second copysign to verify the correctness of 0.0 and -0.0, since
# self.assertEqual(0.0, -0.0) is always True. So, we use 1 as the
# magnitude to verify the sign between torch and numpy results, elementwise.
self.assertEqual(torch.copysign(torch.tensor(1.0), torch_result),
torch.copysign(torch.tensor(1.0), expected))

# Compare Result with NumPy
# Type promotion
a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
_test_copysign_numpy(a, b)

# Broadcast
a = make_tensor((10, 1, 10), device=device, dtype=dtypes[0], low=-9, high=9)
b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
_test_copysign_numpy(a, b)

a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
b = make_tensor((10, 1, 10), device=device, dtype=dtypes[1], low=-9, high=9)
_test_copysign_numpy(a, b)

# 0.0/-0.0/inf/-inf/nan
cases = [0.0, -0.0, float('inf'), float('-inf'), float('nan')]
if dtypes[0] in torch.testing.get_all_fp_dtypes(include_bfloat16=False):
b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9)
for case in cases:
_test_copysign_numpy(torch.tensor([case], dtype=dtypes[0]), b)

if dtypes[1] in torch.testing.get_all_fp_dtypes():
a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9)
for case in cases:
_test_copysign_numpy(a, torch.tensor([case], dtype=dtypes[1]))

@dtypes(torch.bfloat16, torch.float)
def test_div(self, device, dtype):
for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_),
7 changes: 7 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -340,6 +340,13 @@
- name: _conj(Tensor self) -> Tensor
self: grad.conj()

- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
self: copysign_tensor_self_backward(grad, self, result)
other: zeros_like(other)

- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor
self: copysign_tensor_self_backward(grad, self, result)

- name: cos(Tensor self) -> Tensor
self: grad * -self.sin().conj()

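Per the entries above, only `self` receives a nontrivial gradient via `copysign_tensor_self_backward` (defined in the autograd helpers, not shown in this diff), while `other` always gets zeros since it only contributes its sign. A hedged autograd sketch consistent with the subgradient choices exercised in `test_copysign_subgradient`:

```python
import torch

x = torch.tensor([-2.0, 3.0], requires_grad=True)
y = torch.tensor([1.0, -1.0], requires_grad=True)

torch.copysign(x, y).sum().backward()

# Both elements have their sign flipped (copysign(-2, 1) = 2, copysign(3, -1) = -3),
# so the gradient w.r.t. x is -1 at each position.
print(x.grad)   # expected: tensor([-1., -1.])
# `other` only contributes its sign, so its gradient is identically zero.
print(y.grad)   # expected: tensor([0., 0.])
```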
6 changes: 6 additions & 0 deletions tools/autograd/templates/python_torch_functions.cpp
@@ -8,6 +8,12 @@

#include <Python.h>

// Undefine the copysign macro so that at::copysign works as intended with MSVC
// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
#ifdef _MSC_VER
#undef copysign
#endif // _MSC_VER

#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/Dtype.h"
6 changes: 6 additions & 0 deletions tools/autograd/templates/python_variable_methods.cpp
@@ -2,6 +2,12 @@

#include <Python.h>

// Undefine the copysign macro so that at::copysign works as intended with MSVC
// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
#ifdef _MSC_VER
#undef copysign
#endif // _MSC_VER

#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/Size.h"
13 changes: 13 additions & 0 deletions torch/_tensor_docs.py
@@ -939,6 +939,19 @@ def add_docstr_all(method, docstr):
See :func:`torch.conj`
""")

add_docstr_all('copysign',
r"""
copysign(other) -> Tensor
See :func:`torch.copysign`
""")

add_docstr_all('copysign_', r"""
copysign_(other) -> Tensor
In-place version of :meth:`~Tensor.copysign`
""")

add_docstr_all('cos',
r"""
cos() -> Tensor