From 1096b523dee364a027df93eeda7719b399786d20 Mon Sep 17 00:00:00 2001 From: Freey0 Date: Wed, 5 May 2021 10:38:10 +0800 Subject: [PATCH] Port sign to structured [ghstack-poisoned] --- aten/src/ATen/native/UnaryOps.cpp | 15 +++++++-------- aten/src/ATen/native/native_functions.yaml | 4 ++++ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index b027167046563b9..a09b1ead0cb5439 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -102,6 +102,12 @@ TORCH_META_FUNC(floor) (const Tensor& self) { build_unary_op(maybe_get_output(), self); } +TORCH_META_FUNC(sign) (const Tensor& self) { + TORCH_CHECK(!self.is_complex(), + "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead."); + build_unary_op(maybe_get_output(), self); +} + } // namespace meta namespace native { @@ -144,6 +150,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(reciprocal) CREATE_UNARY_TORCH_IMPL_FUNC(round) CREATE_UNARY_TORCH_IMPL_FUNC(rsqrt) CREATE_UNARY_TORCH_IMPL_FUNC(sigmoid) +CREATE_UNARY_TORCH_IMPL_FUNC(sign) CREATE_UNARY_TORCH_IMPL_FUNC(sin) CREATE_UNARY_TORCH_IMPL_FUNC(sinc) CREATE_UNARY_TORCH_IMPL_FUNC(sinh) @@ -402,14 +409,6 @@ Tensor special_erfc(const Tensor& self) { return self.erfc(); } Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); } Tensor special_erfinv(const Tensor& self) { return self.erfinv(); } -Tensor& sign_out(const Tensor& self, Tensor& result) { - TORCH_CHECK(!self.is_complex(), - "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead."); - return unary_op_impl_out(result, self, sign_stub); -} -Tensor sign(const Tensor& self) { return unary_op_impl(self, at::sign_out); } -Tensor& sign_(Tensor& self) { return unary_op_impl_(self, at::sign_out); } - Tensor& sgn_out(const Tensor& self, Tensor& result) { if (self.is_complex()) { return unary_op_impl_out(result, self, sgn_stub); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index cfb492e3c959448..5d97f1687de97c8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6435,18 +6435,22 @@ - func: sign(Tensor self) -> Tensor device_check: NoCheck # TensorIterator + structured_delegate: sign.out variants: function, method dispatch: CompositeExplicitAutograd: sign - func: sign_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator + structured_delegate: sign.out variants: method dispatch: CompositeExplicitAutograd: sign_ - func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: sign_out