Implement copysign (fix docs)
ghstack-source-id: edccfde9095eab988ba254c001b784021fa6fe87
Pull Request resolved: #46396
ejguan committed Oct 16, 2020
1 parent 7b788d1 commit 0fb92d9
Showing 14 changed files with 213 additions and 0 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -239,6 +239,7 @@ _(aten, combinations) \
_(aten, _conj) \
_(aten, conj) \
_(aten, complex) \
_(aten, copysign) \
_(aten, polar) \
_(aten, constant_pad_nd) \
_(aten, contiguous) \
18 changes: 18 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -48,6 +48,7 @@ DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
  auto tensor = scalar_to_tensor(scalar);
@@ -121,6 +122,23 @@ Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) {
  return add_relu_impl(self, self, other, alpha);
}

Tensor& copysign_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_float_op(result, self, other);
  copysign_stub(iter.device_type(), iter);
  return result;
}

Tensor copysign(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_float_op(result, self, other);
  copysign_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor copysign(const Tensor& self, Scalar other) {
  return at::copysign(self, wrapped_scalar_tensor(other));
}

Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_float_op(result, self, other);
  div_stub(iter.device_type(), iter);
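For reference, a minimal Python sketch (illustrative, not part of this diff; assumes a build that includes this commit) of how the three entry points above surface to users: the functional form, the out= variant backed by copysign_out, and the Scalar overload that goes through wrapped_scalar_tensor:

import torch

x = torch.tensor([-1.5, 0.0, 2.0])
y = torch.tensor([ 1.0, -3.0, -0.5])

# Tensor-Tensor overload: magnitude from x, sign from y -> [1.5, -0.0, -2.0]
z = torch.copysign(x, y)

# out= variant writes into a preallocated tensor via copysign_out
out = torch.empty_like(x)
torch.copysign(x, y, out=out)

# Scalar overload: the Python number is wrapped into a 0-dim tensor internally
w = torch.copysign(x, -1.0)   # -> [-1.5, -0.0, -2.0]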
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -70,5 +70,6 @@ DECLARE_DISPATCH(binary_fn, lcm_stub);
DECLARE_DISPATCH(binary_fn, hypot_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);

}} // namespace at::native
9 changes: 9 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -773,6 +773,14 @@ void heaviside_kernel(TensorIterator& iter) {
  });
}

void copysign_kernel(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "copysign_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
      return std::copysign(a, b);
    });
  });
}

} // namespace

REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -812,6 +820,7 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);

} // namespace native
} // namespace at
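A small sanity check (hypothetical, not part of this commit) of the promotion behavior implied by TensorIterator::binary_float_op together with the dispatch above: integral inputs come out as the default floating point dtype.

import torch

a = torch.tensor([-3, 0, 5])          # int64
b = torch.tensor([ 1, -1, -1])        # int64
r = torch.copysign(a, b)
print(r.dtype)                        # torch.float32 (default floating point dtype)
print(r)                              # tensor([ 3., -0., -5.])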
35 changes: 35 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -108,6 +108,40 @@ void heaviside_kernel_cuda(TensorIterator& iter) {
  });
}

template<typename scalar_t, typename accscalar_t>
struct CopySignScalarFunctor {
    CopySignScalarFunctor(accscalar_t b_): b(b_) {}
    __device__ scalar_t operator() (scalar_t a) const {
      return ::copysign(a, b);
    }
  private:
    accscalar_t b;
};

template<typename scalar_t>
struct CopySignFunctor {
  __device__ scalar_t operator() (scalar_t a, scalar_t b) const {
    return ::copysign(a, b);
  }
};

void copysign_kernel_cuda(TensorIterator& iter) {
  // Fast path: if the second operand is a 0-dim CPU scalar, lift it out of the
  // iterator and capture it in the functor instead of broadcasting a tensor.
  if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && iter.is_cpu_scalar(2)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "copysign_cuda", [&]() {
      using accscalar_t = at::acc_type<scalar_t, true>;
      auto b = iter.scalar_value<accscalar_t>(2);
      iter.remove_operand(2);
      CopySignScalarFunctor<scalar_t, decltype(b)> f(b);
      gpu_kernel(iter, f);
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "copysign_cuda", [&]() {
      CopySignFunctor<scalar_t> f;
      gpu_kernel_with_scalars(iter, f);
    });
  }
}

REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -118,5 +152,6 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel_cuda);

}} // namespace at::native
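A rough sketch (assuming a CUDA device is available; not part of the diff) of the two paths in copysign_kernel_cuda: the tensor-tensor path goes through gpu_kernel_with_scalars, while a CPU scalar second operand is lifted out of the iterator and captured by CopySignScalarFunctor.

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, device='cuda')
    y = torch.randn(1024, device='cuda')

    # Tensor-Tensor: CopySignFunctor via gpu_kernel_with_scalars
    z = torch.copysign(x, y)

    # Python number: wrapped into a 0-dim CPU tensor, so is_cpu_scalar(2)
    # holds and the CopySignScalarFunctor branch is taken
    w = torch.copysign(x, -2.5)

    assert torch.equal(w, torch.copysign(x, torch.full_like(x, -2.5)))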
16 changes: 16 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -843,6 +843,22 @@
  dispatch:
    CPU, CUDA: bitwise_not_out

- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
  use_c10_dispatcher: full
  variants: function
  dispatch:
    CPU, CUDA: copysign

- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: copysign_out

- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor
  use_c10_dispatcher: full
  variants: function
  dispatch:
    DefaultBackend: copysign

- func: logical_not(Tensor self) -> Tensor
  use_c10_dispatcher: full
  variants: function, method
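As a reading aid (an interpretation, not part of the diff): the schemas above declare only `variants: function`, so this commit exposes `torch.copysign(...)` but no `Tensor.copysign` method; the `.out` overload backs the `out=` keyword, and the `.Scalar` overload is registered under DefaultBackend because it simply re-dispatches to the Tensor overload.

import torch

t = torch.tensor([-2.0, 3.0])

torch.copysign(t, -1.0)        # function variant: available
# t.copysign(-1.0)             # would need `variants: function, method`,
                               # which this commit does not declare

# The Scalar overload wraps the number and re-dispatches, so it agrees
# with the Tensor overload:
assert torch.equal(torch.copysign(t, -1.0),
                   torch.copysign(t, torch.tensor(-1.0)))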
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -277,6 +277,7 @@ Pointwise Ops
clamp
clip
conj
copysign
cos
cosh
deg2rad
9 changes: 9 additions & 0 deletions test/test_autograd.py
@@ -6749,6 +6749,15 @@ def test_copy_(self, device):
        z = x.to(torch.bfloat16)
        self.assertTrue(z.requires_grad)

    def test_copysign(self, device):
        x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.float, device=device, requires_grad=True)
        y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.float, device=device).reshape(-1)
        y.requires_grad_()
        out = torch.copysign(x, y)
        out.sum().backward()
        self.assertEqual(x.grad.tolist(), [1., 0., 0., -1.] + [-1., 0., 0., 1.] * 3)
        self.assertEqual(y.grad.tolist(), [0.] * 16)

    @onlyCUDA
    def test_simple_reentrant_cross_device(self, device):
        class ReentrantFunc(Function):
60 changes: 60 additions & 0 deletions test/test_torch.py
@@ -660,6 +660,66 @@ def test_copy_transpose(self):
        self.assertEqual(y[:, 0], range(100))
        self.assertEqual(y[:, 40], range(4000, 4100))

    def test_copysign(self):
        # Float or Integral promoted to Float
        res = torch.tensor([-1, -0, -0, -1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1], dtype=torch.float)
        types = [torch.short, torch.int, torch.long, torch.float]
        for t in types:
            x = torch.tensor([-1, 0, -0, 1] * 4, dtype=t)
            y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=t).reshape(-1)
            z = torch.copysign(x, y)
            self.assertEqual(z, res)

        # Half
        x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.half)
        y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.half).reshape(-1)
        res = torch.tensor([-1, -0, -0, -1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1], dtype=torch.half)
        self.assertEqual(torch.copysign(x, y), res)

        # Bool
        x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.bool)
        y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.bool).reshape(-1)
        res = torch.tensor([1, 0, 0, 1] * 4, dtype=torch.float).reshape(-1)
        self.assertEqual(torch.copysign(x, y), res)

        # Broadcast
        res = torch.tensor([[-1, -0, -0, -1], [1, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 1]], dtype=torch.float)

        # LHS
        x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.float).reshape(4, 1, 4)
        y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.float)
        z = torch.copysign(x, y)
        for i in range(4):
            self.assertEqual(z[i], res)

        x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.float).reshape(1, 4, 4)
        z = torch.copysign(x, y)
        self.assertEqual(z, res.reshape(1, 4, 4))

        # RHS
        x = torch.tensor([-1, 0, -0, 1] * 4, dtype=torch.float).reshape(4, 4)
        y = torch.tensor([[-1] * 4, [-0] * 4, [0] * 4, [1] * 4], dtype=torch.float).reshape(4, 1, 4)
        z = torch.copysign(x, y)
        for i in range(4):
            for j in range(4):
                self.assertEqual(z[i][j], res[i])

        # Scalar
        x = torch.tensor([-1, 0, -0, 1], dtype=torch.float)
        y = torch.tensor(-1.)
        res = torch.tensor([-1, -0, -0, -1], dtype=torch.float)
        self.assertEqual(torch.copysign(x, y), res)

        x = torch.tensor(-1.)
        y = torch.tensor([-1, 0, -0, 1], dtype=torch.float)
        res = torch.tensor([-1, 1, 1, 1], dtype=torch.float)
        self.assertEqual(torch.copysign(x, y), res)

        # Constant
        x = torch.tensor([-1, 0, -0, 1], dtype=torch.float)
        res = torch.tensor([-1, -0, -0, -1], dtype=torch.float)
        self.assertEqual(torch.copysign(x, -1.), res)

    def test_device(self):
        cpu = torch.device('cpu')
        self.assertEqual('cpu', str(cpu))
7 changes: 7 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -340,6 +340,13 @@
- name: _conj(Tensor self) -> Tensor
  self: grad.conj()

- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
  self: copysign_tensor_self_backward(grad, self, other)
  other: zeros_like(other)

- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor
  self: copysign_tensor_self_backward(grad, self, at::scalar_to_tensor(other))

- name: cos(Tensor self) -> Tensor
  self: grad * -self.sin().conj()

49 changes: 49 additions & 0 deletions torch/_torch_docs.py
@@ -1870,6 +1870,55 @@ def merge_dicts(*dicts):
tensor([-1 - 1j, -2 - 2j, 3 + 3j])
""".format(**common_args))

add_docstr(torch.copysign,
r"""
copysign(input, other, *, out=None) -> Tensor

Returns a new tensor with the magnitude of :attr:`input` and the sign of
:attr:`other`, element-wise.

.. math::
    \text{out}_{i} = \begin{cases}
        -|\text{input}_{i}| & \text{if } \text{other}_{i} < 0 \\
         |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0 \\
    \end{cases}

""" + r"""
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`
and integer and float inputs. Integral and boolean inputs are always
promoted to the default floating point type.

Args:
    input (Tensor): values whose magnitudes are used for the output.
    other (Tensor or Number): values whose signs are used for the output.
        If the shape of :attr:`input` is different from the shape of
        :attr:`other`, they must be broadcastable to a common shape
        (which becomes the shape of the output).

Keyword args:
    {out}

Example::

    >>> a = torch.randn(5)
    >>> a
    tensor([-1.2557, -0.0026, -0.5387,  0.4740, -0.9244])
    >>> torch.copysign(a, 1)
    tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244])
    >>> a = torch.randn(4, 4)
    >>> a
    tensor([[ 0.7079,  0.2778, -1.0249,  0.5719],
            [-0.0059, -0.2600, -0.4475, -1.3948],
            [ 0.3667, -0.9567, -2.5757, -0.1751],
            [ 0.2046, -0.0742,  0.2998, -0.1054]])
    >>> b = torch.randn(4)
    >>> b
    tensor([ 0.2373,  0.3120,  0.3190, -1.1128])
    >>> torch.copysign(a, b)
    tensor([[ 0.7079,  0.2778,  1.0249, -0.5719],
            [ 0.0059,  0.2600,  0.4475, -1.3948],
            [ 0.3667,  0.9567,  2.5757, -0.1751],
            [ 0.2046,  0.0742,  0.2998, -0.1054]])
""".format(**common_args))

add_docstr(torch.cos,
r"""
cos(input, *, out=None) -> Tensor
5 changes: 5 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -49,6 +49,11 @@ void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
  std::copy(t.begin(), t.end(), out.begin() + range.first);
}

// d/d(self) copysign(self, other) is sign(self) when other >= 0 and -sign(self)
// otherwise; the gradient w.r.t. other is zero (see derivatives.yaml).
Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & other) {
  auto result = grad * self.sign() * (other.ge(0) * 2 - 1);
  return result;
}

Tensor not_implemented(const char* name) {
  throw std::runtime_error(
      std::string("the derivative for '") + name + "' is not implemented");
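A hand-check (hypothetical, not in the PR) of the backward rule above: the gradient w.r.t. self is sign(self) flipped by the sign of other, and the gradient w.r.t. other is zero.

import torch

x = torch.tensor([-2.0, 3.0], requires_grad=True)
y = torch.tensor([ 1.0, -4.0], requires_grad=True)
torch.copysign(x, y).sum().backward()
print(x.grad)   # tensor([-1., -1.])  == sign(x) * (+1 if y >= 0 else -1)
print(y.grad)   # tensor([0., 0.])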
1 change: 1 addition & 0 deletions torch/csrc/autograd/FunctionsManual.h
@@ -32,6 +32,7 @@ struct IndexRangeGenerator {
bool any_variable_defined(variable_list& variables);
void copy_range(variable_list& out, IndexRange range, const at::Tensor & t);
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<at::Tensor> t);
at::Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & other);
at::Tensor not_implemented(const char* name);
at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result);
at::Tensor maybe_multiply(const at::Tensor & t, const at::Scalar & s);
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -285,6 +285,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.clone: lambda input: -1,
torch.combinations: lambda input, r=2, with_replacement=False: -1,
torch.complex: lambda real, imag: -1,
torch.copysign: lambda input, other, out=None: -1,
torch.polar: lambda abs, ang: -1,
torch.conj: lambda input, out=None: -1,
torch.constant_pad_nd: lambda input, pad, value=0: -1,
