diff --git a/test/cpp/test_tensor.cpp b/test/cpp/test_tensor.cpp
index c148dd072e6f..1fb793b6d745 100644
--- a/test/cpp/test_tensor.cpp
+++ b/test/cpp/test_tensor.cpp
@@ -119,16 +119,6 @@ TEST_F(TensorTest, TestSize) {
   });
 }
 
-TEST_F(TensorTest, TestRelu) {
-  at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
-  at::Tensor output = input.relu();
-  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = XLATensor::Create(input, device);
-    XLATensorPtr dev_output = XLATensor::relu(dev_input);
-    AllClose(output, dev_output);
-  });
-}
-
 TEST_F(TensorTest, TestRrelu) {
   at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
   float lower = 0.125;
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 88b0abe887ce..0a97a1ddfd83 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2559,18 +2559,6 @@ at::Tensor XLANativeFunctions::reflection_pad2d_backward(
       torch::lazy::ToVector<int64_t>(padding)));
 }
 
-at::Tensor XLANativeFunctions::relu(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::relu(bridge::GetXlaTensor(self)));
-}
-
-at::Tensor& XLANativeFunctions::relu_(at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::relu_(self_tensor);
-  return self;
-}
-
 at::Tensor XLANativeFunctions::remainder(const at::Tensor& self,
                                          const at::Tensor& other) {
   XLA_FN_COUNTER("xla::");
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 8ee6ac6abfeb..01ea9e6c9430 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -127,26 +127,6 @@ torch::lazy::NodePtr SignOp(const torch::lazy::Value& input) {
                    GetXlaShape(input), std::move(lower_fn));
 }
 
-torch::lazy::NodePtr ReluOp(const torch::lazy::Value& input) {
-  auto lower_fn = [](const XlaNode& node,
-                     LoweringContext* loctx) -> XlaOpVector {
-    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
-    xla::XlaOp xla_output = BuildRelu(xla_input);
-    return node.ReturnOp(xla_output, loctx);
-  };
-  auto lower_for_shape_fn =
-      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
-    XLA_CHECK_EQ(operands.size(), 1) << "Unexpected number of operands";
-    return BuildRelu(operands[0]);
-  };
-  return GenericOp(torch::lazy::OpKind(at::aten::relu), {input},
-                   [&]() {
-                     return InferOutputShape({GetXlaShape(input)},
-                                             lower_for_shape_fn);
-                   },
-                   std::move(lower_fn));
-}
-
 torch::lazy::NodePtr Prelu(const torch::lazy::Value& input,
                            const torch::lazy::Value& weight) {
   auto lower_fn = [](const XlaNode& node,
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index ef5e23d75a3a..03257dd49eb2 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -86,8 +86,6 @@ torch::lazy::NodePtr SgnOp(const torch::lazy::Value& input);
 
 torch::lazy::NodePtr SignOp(const torch::lazy::Value& input);
 
-torch::lazy::NodePtr ReluOp(const torch::lazy::Value& input);
-
 torch::lazy::NodePtr Min(const torch::lazy::Value& input,
                          const torch::lazy::Value& other);
 
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 8b84d2722d09..6ae033753f29 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -211,6 +211,12 @@ torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildReciprocal(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Relu::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_output = BuildRelu(xla_input);
+  return ReturnOp(xla_output, loctx);
+}
+
 torch_xla::XlaOpVector Round::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::RoundToEven(xla_input), loctx);
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index c497f7d7dc76..bcb80ab54bdc 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -207,6 +207,15 @@ xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape ReluOutputShape(const torch::lazy::Value& input) {
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    XLA_CHECK_EQ(operands.size(), 1) << "Unexpected number of operands";
+    return BuildRelu(operands[0]);
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
+
 xla::Shape RoundOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h
index 3ee8aeb2c6c7..f687fb4420ed 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.h
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -82,6 +82,8 @@ xla::Shape MinimumOutputShape(const torch::lazy::Value& input,
 
 xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input);
 
+xla::Shape ReluOutputShape(const torch::lazy::Value& input);
+
 xla::Shape RoundOutputShape(const torch::lazy::Value& input);
 
 xla::Shape RsqrtOutputShape(const torch::lazy::Value& input);
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 2ccee7c82cef..465c893996ca 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -966,9 +966,6 @@ class XLATensor : public c10::intrusive_ptr_target {
       const XLATensorPtr& input, std::vector<int64_t> padding);
 
-  static XLATensorPtr relu(const XLATensorPtr& input);
-  static void relu_(XLATensorPtr& input);
-
   static XLATensorPtr remainder(const XLATensorPtr& input,
                                 const XLATensorPtr& other);
 
   static XLATensorPtr remainder(const XLATensorPtr& input,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index ce072f05046d..26641a769762 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2230,14 +2230,6 @@ XLATensorPtr XLATensor::reflection_pad2d_backward(
       grad_output->GetIrValue(), input->GetIrValue(), std::move(padding)));
 }
 
-XLATensorPtr XLATensor::relu(const XLATensorPtr& input) {
-  return input->CreateFrom(ReluOp(input->GetIrValue()));
-}
-
-void XLATensor::relu_(XLATensorPtr& input) {
-  input->SetInPlaceIrValue(ReluOp(input->GetIrValue()));
-}
-
 XLATensorPtr XLATensor::remainder(const XLATensorPtr& input,
                                   const XLATensorPtr& other) {
   return input->CreateFrom(Remainder(input->GetIrValue(), other->GetIrValue()));
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index 81e2322e37c9..bca9a11a978a 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -33,6 +33,7 @@ full_codegen:
   - maximum
   - minimum
   - reciprocal
+  - relu
   - round
   - rsqrt
   - selu
@@ -253,8 +254,6 @@ supported:
   - random_.to
   - reflection_pad2d
   - reflection_pad2d_backward
-  - relu
-  - relu_
   - remainder.Scalar
  - remainder.Tensor
  - repeat
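
Not part of the patch: a minimal Python smoke test, assuming a working torch_xla install, covering roughly what the removed C++ TestRelu exercised, now routed through the codegen'd Relu lowering and ReluOutputShape added above. The module path torch_xla.core.xla_model reflects the torch_xla API of this era.

# Sketch only, not from this change: run relu and relu_ on an XLA device and
# compare against the CPU reference, mirroring the removed TestRelu coverage.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
cpu_input = torch.rand(2, 1, 4, 6)

# Out-of-place relu dispatches through the generated Relu node.
xla_out = cpu_input.to(device).relu()
assert torch.allclose(cpu_input.relu(), xla_out.cpu())

# In-place relu_ should produce the same values through the same lowering.
xla_inplace = cpu_input.to(device)
xla_inplace.relu_()
assert torch.allclose(cpu_input.relu(), xla_inplace.cpu())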