Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sign: port to structured #57588

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 7 additions & 8 deletions aten/src/ATen/native/UnaryOps.cpp
Expand Up @@ -102,6 +102,12 @@ TORCH_META_FUNC(floor) (const Tensor& self) {
build_unary_op(maybe_get_output(), self);
}

TORCH_META_FUNC(sign) (const Tensor& self) {
TORCH_CHECK(!self.is_complex(),
"Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
build_unary_op(maybe_get_output(), self);
}

} // namespace meta

namespace native {
Expand Down Expand Up @@ -144,6 +150,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(reciprocal)
CREATE_UNARY_TORCH_IMPL_FUNC(round)
CREATE_UNARY_TORCH_IMPL_FUNC(rsqrt)
CREATE_UNARY_TORCH_IMPL_FUNC(sigmoid)
CREATE_UNARY_TORCH_IMPL_FUNC(sign)
CREATE_UNARY_TORCH_IMPL_FUNC(sin)
CREATE_UNARY_TORCH_IMPL_FUNC(sinc)
CREATE_UNARY_TORCH_IMPL_FUNC(sinh)
Expand Down Expand Up @@ -402,14 +409,6 @@ Tensor special_erfc(const Tensor& self) { return self.erfc(); }
Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); }
Tensor special_erfinv(const Tensor& self) { return self.erfinv(); }

Tensor& sign_out(const Tensor& self, Tensor& result) {
TORCH_CHECK(!self.is_complex(),
"Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
return unary_op_impl_out(result, self, sign_stub);
}
Tensor sign(const Tensor& self) { return unary_op_impl(self, at::sign_out); }
Tensor& sign_(Tensor& self) { return unary_op_impl_(self, at::sign_out); }

Tensor& sgn_out(const Tensor& self, Tensor& result) {
if (self.is_complex()) {
return unary_op_impl_out(result, self, sgn_stub);
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -6435,18 +6435,22 @@

- func: sign(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: sign.out
variants: function, method
dispatch:
CompositeExplicitAutograd: sign

- func: sign_(Tensor(a!) self) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured_delegate: sign.out
variants: method
dispatch:
CompositeExplicitAutograd: sign_

- func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: sign_out

Expand Down