Implement copysign (Fix bugs)
ghstack-source-id: ad138297abbe2e65bfb661c8446356e003ab152c
Pull Request resolved: #46396
ejguan committed Oct 15, 2020
1 parent b64cf93 commit 1fbd8ab
Showing 10 changed files with 65 additions and 0 deletions.
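
For orientation, a minimal usage sketch of the op this commit adds (assuming a build that includes it): copysign returns the magnitude of `input` with the sign of `other`, elementwise.

import torch

a = torch.tensor([1.0, -2.0, 3.0])
b = torch.tensor([-1.0, 1.0, -4.0])
print(torch.copysign(a, b))  # tensor([-1.,  2., -3.])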
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -239,6 +239,7 @@ _(aten, combinations) \
_(aten, _conj) \
_(aten, conj) \
_(aten, complex) \
_(aten, copysign) \
_(aten, polar) \
_(aten, constant_pad_nd) \
_(aten, contiguous) \
18 changes: 18 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -48,6 +48,7 @@ DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);
DEFINE_DISPATCH(copysign_stub);

static Tensor wrapped_scalar_tensor(Scalar scalar) {
  auto tensor = scalar_to_tensor(scalar);
@@ -121,6 +122,23 @@ Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) {
  return add_relu_impl(self, self, other, alpha);
}

Tensor& copysign_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other);
  copysign_stub(iter.device_type(), iter);
  return result;
}

Tensor copysign(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  copysign_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor copysign(const Tensor& self, Scalar other) {
  return at::copysign(self, wrapped_scalar_tensor(other));
}

Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_float_op(result, self, other);
  div_stub(iter.device_type(), iter);
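
The three native entry points above (out variant, functional variant, Scalar overload) surface in Python as follows; a sketch, assuming the bindings generated from this commit:

import torch

x = torch.tensor([1.5, -0.5])
y = torch.tensor([-1.0, 1.0])

r = torch.copysign(x, y)       # functional form -> tensor([-1.5000,  0.5000])
out = torch.empty_like(x)
torch.copysign(x, y, out=out)  # out= variant writes into a preallocated tensor
torch.copysign(x, -1.0)        # Scalar overload -> tensor([-1.5000, -0.5000])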
1 change: 1 addition & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -70,5 +70,6 @@ DECLARE_DISPATCH(binary_fn, lcm_stub);
DECLARE_DISPATCH(binary_fn, hypot_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
DECLARE_DISPATCH(binary_fn, copysign_stub);

}} // namespace at::native
8 changes: 8 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -773,6 +773,13 @@ void heaviside_kernel(TensorIterator& iter) {
  });
}

void copysign_kernel(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "copysign_cpu", [&]() {
    cpu_kernel(iter,
      [](scalar_t a, scalar_t b) -> scalar_t { return std::copysign(a, b); });
  });
}

} // namespace

REGISTER_DISPATCH(add_stub, &add_kernel);
@@ -812,6 +819,7 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);

} // namespace native
} // namespace at
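
The CPU kernel defers to std::copysign, so IEEE sign-bit semantics carry over. Python's math.copysign wraps the same C function and serves as a scalar reference:

import math

math.copysign(0.0, -1.0)           # -0.0: the sign bit is copied even for zero
math.copysign(5.0, -0.0)           # -5.0: negative zero counts as negative
math.copysign(float("nan"), -2.0)  # NaN with its sign bit set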
9 changes: 9 additions & 0 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -108,6 +108,14 @@ void heaviside_kernel_cuda(TensorIterator& iter) {
  });
}

void copysign_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "copysign_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
      return ::copysign(a, b);
    });
  });
}

REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_cuda);
REGISTER_DISPATCH(mse_stub, &mse_kernel_cuda);
@@ -118,5 +126,6 @@ REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel_cuda);

}} // namespace at::native
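
A quick device-side check of the dtype coverage declared in the dispatch macro (a sketch; requires a CUDA build that includes this commit):

import torch

if torch.cuda.is_available():
    a = torch.tensor([1.0, -2.0], device="cuda", dtype=torch.float16)
    b = torch.tensor([-1.0, 1.0], device="cuda", dtype=torch.float16)
    print(torch.copysign(a, b))  # tensor([-1., 2.], device='cuda:0', dtype=torch.float16)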
14 changes: 14 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -791,6 +791,20 @@
  dispatch:
    CPU, CUDA: bitwise_not_out

- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
  use_c10_dispatcher: full
  variants: function
  dispatch:
    CPU, CUDA: copysign

- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: copysign_out

- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor
  use_c10_dispatcher: full
  variants: function

- func: logical_not(Tensor self) -> Tensor
  use_c10_dispatcher: full
  variants: function, method
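
Note that copysign.Scalar carries no dispatch block: as shown in BinaryOps.cpp above, the Scalar overload wraps the scalar into a tensor and reuses the Tensor kernel, so the two calls below should agree (a sketch):

import torch

t = torch.tensor([3.0, -4.0])
assert torch.equal(torch.copysign(t, -2.5),
                   torch.copysign(t, torch.tensor(-2.5)))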
7 changes: 7 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -340,6 +340,13 @@
- name: _conj(Tensor self) -> Tensor
  self: grad.conj()

- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
  self: copysign_tensor_backward(grad, self, other)
  other: non_differentiable

- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor
  self: copysign_tensor_backward(grad, self, at::scalar_to_tensor(other))

- name: cos(Tensor self) -> Tensor
  self: grad * -self.sin().conj()

5 changes: 5 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -49,6 +49,11 @@ void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
  std::copy(t.begin(), t.end(), out.begin() + range.first);
}

Tensor copysign_tensor_backward(Tensor grad, Tensor self, Tensor other) {
  auto result = grad * self.sign() * other.sign();
  return result;
}

Tensor not_implemented(const char* name) {
  throw std::runtime_error(
      std::string("the derivative for '") + name + "' is not implemented");
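
Per derivatives.yaml, only self is differentiable, with gradient grad * self.sign() * other.sign() as implemented in copysign_tensor_backward above. A sketch verifying the formula through autograd (away from self == 0, assuming a build with this commit):

import torch

x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
y = torch.tensor([-1.0, -1.0, 1.0])

torch.copysign(x, y).sum().backward()
print(x.grad)               # tensor([-1.,  1.,  1.])
print(x.sign() * y.sign())  # same values: the formula from copysign_tensor_backward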
1 change: 1 addition & 0 deletions torch/csrc/autograd/FunctionsManual.h
@@ -32,6 +32,7 @@ struct IndexRangeGenerator {
bool any_variable_defined(variable_list& variables);
void copy_range(variable_list& out, IndexRange range, const at::Tensor & t);
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<at::Tensor> t);
at::Tensor copysign_tensor_backward(Tensor grad, Tensor self, Tensor other);
at::Tensor not_implemented(const char* name);
at::Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result);
at::Tensor maybe_multiply(const at::Tensor & t, const at::Scalar & s);
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -288,6 +288,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.clone: lambda input: -1,
        torch.combinations: lambda input, r=2, with_replacement=False: -1,
        torch.complex: lambda real, imag: -1,
        torch.copysign: lambda input, other, out=None: -1,
        torch.polar: lambda abs, ang: -1,
        torch.conj: lambda input, out=None: -1,
        torch.constant_pad_nd: lambda input, pad, value=0: -1,
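
The lambda registered here is a dummy whose only job is to document copysign's Python signature for the __torch_function__ override tests; it can be inspected directly (a sketch):

import torch
from torch.overrides import get_testing_overrides

dummy = get_testing_overrides()[torch.copysign]
print(dummy(torch.ones(2), torch.ones(2)))  # -1: a placeholder, never run against real data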
