diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index d12955e501af..d01cb4facef9 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1405,21 +1405,22 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::kthvalue( bridge::AtenFromXlaTensor(std::get<1>(results))); } -at::Tensor XLANativeFunctions::leaky_relu(const at::Tensor& self, - const at::Scalar& negative_slope) { - TORCH_LAZY_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(tensor_methods::leaky_relu( - bridge::GetXlaTensor(self), negative_slope.to<double>())); -} - at::Tensor XLANativeFunctions::leaky_relu_backward( const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& negative_slope, bool self_is_result) { TORCH_LAZY_FN_COUNTER("xla::"); XLA_CHECK(!self_is_result || negative_slope.to<double>() >= 0.0); - return bridge::AtenFromXlaTensor(tensor_methods::leaky_relu_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), - negative_slope.to<double>())); + auto common_device = torch_xla::bridge::GetXlaDevice(self); + XLA_CHECK(common_device); + auto node_negative_slope = + torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen( + negative_slope, *common_device); + torch::lazy::NodePtr node = torch::lazy::MakeNode<LeakyReluBackward>( + bridge::GetXlaTensor(grad_output)->GetIrValue(), + bridge::GetXlaTensor(self)->GetIrValue(), node_negative_slope, + self_is_result); + return torch_xla::bridge::AtenFromXlaTensor( + torch_xla::XLATensor::Create(std::move(node), *common_device)); } at::Tensor XLANativeFunctions::lerp(const at::Tensor& self, diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 6971853fc5bf..090774cba55f 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -165,8 +165,8 @@ xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input, return xla::Select(Between(input, min_val, max_val), grad_output, zero); } -xla::XlaOp BuildLeakyRelu(xla::XlaOp input, double 
negative_slope_value) { - return BuildLeakyReluBackward(input, input, negative_slope_value); +xla::XlaOp BuildLeakyRelu(xla::XlaOp input, xla::XlaOp negative_slope) { + return BuildLeakyReluBackward(input, input, negative_slope); } std::vector<xla::XlaOp> BuildRrelu(xla::XlaOp input, const at::Scalar& lower, @@ -188,7 +188,9 @@ std::vector<xla::XlaOp> BuildRrelu(xla::XlaOp input, const at::Scalar& lower, noise = xla::Select(xla::Gt(input, zero), one, slope); output = input * noise; } else { - double negative_slope = (lower.to<double>() + upper.to<double>()) / 2; + xla::XlaOp negative_slope = + XlaHelpers::ScalarValue((lower.to<double>() + upper.to<double>()) / 2, + shape.element_type(), input.builder()); noise = xla::Broadcast(zero, shape.dimensions()); output = BuildLeakyRelu(input, negative_slope); } @@ -214,11 +216,9 @@ xla::XlaOp BuildRreluBackward(xla::XlaOp grad_output, xla::XlaOp input, } xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input, - double negative_slope_value) { + xla::XlaOp negative_slope) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp zero = xla::Zero(input.builder(), input_shape.element_type()); - xla::XlaOp negative_slope = XlaHelpers::ScalarValue( - negative_slope_value, input_shape.element_type(), input.builder()); return xla::Select(xla::Gt(input, zero), grad_output, negative_slope * grad_output); } diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h index 90533c268445..62f615135388 100644 --- a/torch_xla/csrc/elementwise.h +++ b/torch_xla/csrc/elementwise.h @@ -43,10 +43,10 @@ xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input, // Computes the leaky rectified linear unit: // LeakyReLU(x) = max(0, input) + negative_slope ∗ min(0, input). 
-xla::XlaOp BuildLeakyRelu(xla::XlaOp input, double negative_slope); +xla::XlaOp BuildLeakyRelu(xla::XlaOp input, xla::XlaOp negative_slope); xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input, - double negative_slope_value); + xla::XlaOp negative_slope); // Computes the sigmoid function using Tanh // Sigmoid(x) = (tanh(x ∗ 0.5) + 1) ∗ 0.5 diff --git a/torch_xla/csrc/ops/leaky_relu.cpp b/torch_xla/csrc/ops/leaky_relu.cpp deleted file mode 100644 index 9a964b115a11..000000000000 --- a/torch_xla/csrc/ops/leaky_relu.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include "torch_xla/csrc/ops/leaky_relu.h" - -#include "torch_xla/csrc/elementwise.h" -#include "torch_xla/csrc/lowering_context.h" - -namespace torch_xla { - -LeakyRelu::LeakyRelu(const torch::lazy::Value& input, double negative_slope) - : XlaNode(torch::lazy::OpKind(at::aten::leaky_relu), {input}, - GetXlaShape(input), - /*num_outputs=*/1, torch::lazy::MHash(negative_slope)), - negative_slope_(negative_slope) {} - -torch::lazy::NodePtr LeakyRelu::Clone(torch::lazy::OpList operands) const { - return torch::lazy::MakeNode<LeakyRelu>(operands.at(0), negative_slope_); -} - -XlaOpVector LeakyRelu::Lower(LoweringContext* loctx) const { - xla::XlaOp input = loctx->GetOutputOp(operand(0)); - xla::XlaOp output = BuildLeakyRelu(input, negative_slope_); - return ReturnOp(output, loctx); -} - -std::string LeakyRelu::ToString() const { - std::stringstream ss; - ss << XlaNode::ToString() << ", negative_slope=" << negative_slope_; - return ss.str(); -} - -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/leaky_relu.h b/torch_xla/csrc/ops/leaky_relu.h deleted file mode 100644 index d5e574b0c72b..000000000000 --- a/torch_xla/csrc/ops/leaky_relu.h +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include <string> - -#include "torch_xla/csrc/ir.h" - -namespace torch_xla { - -class LeakyRelu : public XlaNode { - public: - LeakyRelu(const torch::lazy::Value& input, double negative_slope); - - torch::lazy::NodePtr 
Clone(torch::lazy::OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; - - double negative_slope() const { return negative_slope_; } - - private: - double negative_slope_; -}; - -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/leaky_relu_backward.cpp b/torch_xla/csrc/ops/leaky_relu_backward.cpp deleted file mode 100644 index d5cced14a847..000000000000 --- a/torch_xla/csrc/ops/leaky_relu_backward.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "torch_xla/csrc/ops/leaky_relu_backward.h" - -#include "torch_xla/csrc/elementwise.h" -#include "torch_xla/csrc/lowering_context.h" - -namespace torch_xla { - -LeakyReluBackward::LeakyReluBackward(const torch::lazy::Value& grad_output, - const torch::lazy::Value& input, - double negative_slope) - : XlaNode(torch::lazy::OpKind(at::aten::leaky_relu_backward), - {grad_output, input}, GetXlaShape(input), - /*num_outputs=*/1, torch::lazy::MHash(negative_slope)), - negative_slope_(negative_slope) {} - -torch::lazy::NodePtr LeakyReluBackward::Clone( - torch::lazy::OpList operands) const { - return torch::lazy::MakeNode<LeakyReluBackward>( - operands.at(0), operands.at(1), negative_slope_); -} - -XlaOpVector LeakyReluBackward::Lower(LoweringContext* loctx) const { - xla::XlaOp grad_output = loctx->GetOutputOp(operand(0)); - xla::XlaOp input = loctx->GetOutputOp(operand(1)); - xla::XlaOp output = - BuildLeakyReluBackward(grad_output, input, negative_slope_); - return ReturnOp(output, loctx); -} - -std::string LeakyReluBackward::ToString() const { - std::stringstream ss; - ss << XlaNode::ToString() << ", negative_slope=" << negative_slope_; - return ss.str(); -} - -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/leaky_relu_backward.h b/torch_xla/csrc/ops/leaky_relu_backward.h deleted file mode 100644 index df962d5d29b0..000000000000 --- a/torch_xla/csrc/ops/leaky_relu_backward.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -#include <string> - -#include 
"torch_xla/csrc/ir.h" - -namespace torch_xla { - -class LeakyReluBackward : public XlaNode { - public: - LeakyReluBackward(const torch::lazy::Value& grad_output, - const torch::lazy::Value& input, double negative_slope); - - torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; - - double negative_slope() const { return negative_slope_; } - - private: - double negative_slope_; -}; - -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index 2c965285a084..5038fc913df2 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -404,6 +404,21 @@ torch_xla::XlaOpVector Isnan::Lower(LoweringContext* loctx) const { return ReturnOp(xla::IsNan(xla_input), loctx); } +torch_xla::XlaOpVector LeakyRelu::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + xla::XlaOp negative_slope = loctx->GetOutputOp(operand(1)); + return ReturnOp(BuildLeakyRelu(xla_input, negative_slope), loctx); +} + +torch_xla::XlaOpVector LeakyReluBackward::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_grad_output = loctx->GetOutputOp(operand(0)); + xla::XlaOp xla_input = loctx->GetOutputOp(operand(1)); + xla::XlaOp negative_slope = loctx->GetOutputOp(operand(2)); + return ReturnOp( + BuildLeakyReluBackward(xla_grad_output, xla_input, negative_slope), + loctx); +} + torch_xla::XlaOpVector Logdet::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); return ReturnOp(xla::LogDet(xla_input), loctx); diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index bb2d6a210e85..15c9c43a3e4f 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -475,6 +475,30 @@ xla::Shape IsnanOutputShape(const torch::lazy::Value& input) 
{ return isnan_shape; } +xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& negative_slope) { + auto lower_for_shape_fn = + [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp { + XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands"; + return BuildLeakyRelu(operands[0], operands[1]); + }; + return InferOutputShape({GetXlaShape(input), GetXlaShape(negative_slope)}, + lower_for_shape_fn); +} + +xla::Shape LeakyReluBackwardOutputShape( + const torch::lazy::Value& grad_output, const torch::lazy::Value& input, + const torch::lazy::Value& negative_slope, bool self_is_result) { + auto lower_for_shape_fn = + [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp { + XLA_CHECK_EQ(operands.size(), 3) << "Unexpected number of operands"; + return BuildLeakyReluBackward(operands[0], operands[1], operands[2]); + }; + return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input), + GetXlaShape(negative_slope)}, + lower_for_shape_fn); +} + xla::Shape LeScalarOutputShape(const torch::lazy::Value& self, const torch::lazy::Value& other) { auto lower_for_shape_fn = diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index 9d8778fb7df1..f1989c5e8875 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -161,6 +161,13 @@ xla::Shape InverseOutputShape(const torch::lazy::Value& input); xla::Shape IsnanOutputShape(const torch::lazy::Value& input); +xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input, + const torch::lazy::Value& negative_slope); + +xla::Shape LeakyReluBackwardOutputShape( + const torch::lazy::Value& grad_output, const torch::lazy::Value& input, + const torch::lazy::Value& negative_slope, bool self_is_result); + xla::Shape LeScalarOutputShape(const torch::lazy::Value& self, const torch::lazy::Value& other); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index ae622e2f552a..2fa581e8c018 100644 --- 
a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -63,8 +63,6 @@ #include "torch_xla/csrc/ops/index_select.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/kth_value.h" -#include "torch_xla/csrc/ops/leaky_relu.h" -#include "torch_xla/csrc/ops/leaky_relu_backward.h" #include "torch_xla/csrc/ops/linear_interpolation.h" #include "torch_xla/csrc/ops/linspace.h" #include "torch_xla/csrc/ops/log_softmax.h" @@ -1407,18 +1405,6 @@ XLATensorPtr hardtanh_backward(const XLATensorPtr& grad_output, grad_output->GetIrValue(), input->GetIrValue(), min_val, max_val)); } -XLATensorPtr leaky_relu(const XLATensorPtr& input, double negative_slope) { - return input->CreateFrom( - torch::lazy::MakeNode<LeakyRelu>(input->GetIrValue(), negative_slope)); -} - -XLATensorPtr leaky_relu_backward(const XLATensorPtr& grad_output, - const XLATensorPtr& input, - double negative_slope) { - return grad_output->CreateFrom(torch::lazy::MakeNode<LeakyReluBackward>( - grad_output->GetIrValue(), input->GetIrValue(), negative_slope)); -} - XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end, const XLATensorPtr& weight) { return input->CreateFrom( diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index a0c8e01817c0..a4a073a6b645 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -444,11 +444,6 @@ XLATensorPtr hardtanh_backward(const XLATensorPtr& grad_output, const at::Scalar& min_val, const at::Scalar& max_val); -XLATensorPtr leaky_relu(const XLATensorPtr& input, double negative_slope); -XLATensorPtr leaky_relu_backward(const XLATensorPtr& grad_output, - const XLATensorPtr& input, - double negative_slope); - XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end, const XLATensorPtr& weight); XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end, diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index 727bc9076d3e..a503a0576030 100644 --- 
a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -48,6 +48,7 @@ full_codegen: - hardswish_backward - inverse - isnan + - leaky_relu - le.Scalar - le.Tensor - logdet @@ -92,6 +93,7 @@ ir_gen: - bitwise_and.Tensor - bitwise_or.Tensor - bitwise_xor.Tensor + - leaky_relu_backward supported: - __ilshift__.Scalar - __ilshift__.Tensor @@ -192,7 +194,6 @@ supported: - index_select - kl_div - kthvalue - - leaky_relu - leaky_relu_backward - lerp.Scalar - lerp.Tensor