21 changes: 11 additions & 10 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1405,21 +1405,22 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::kthvalue(
bridge::AtenFromXlaTensor(std::get<1>(results)));
}

at::Tensor XLANativeFunctions::leaky_relu(const at::Tensor& self,
const at::Scalar& negative_slope) {
TORCH_LAZY_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(tensor_methods::leaky_relu(
bridge::GetXlaTensor(self), negative_slope.to<double>()));
}

at::Tensor XLANativeFunctions::leaky_relu_backward(
const at::Tensor& grad_output, const at::Tensor& self,
const at::Scalar& negative_slope, bool self_is_result) {
TORCH_LAZY_FN_COUNTER("xla::");
XLA_CHECK(!self_is_result || negative_slope.to<double>() >= 0.0);
return bridge::AtenFromXlaTensor(tensor_methods::leaky_relu_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
negative_slope.to<double>()));
auto common_device = torch_xla::bridge::GetXlaDevice(self);
XLA_CHECK(common_device);
auto node_negative_slope =
torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
negative_slope, *common_device);
torch::lazy::NodePtr node = torch::lazy::MakeNode<LeakyReluBackward>(
bridge::GetXlaTensor(grad_output)->GetIrValue(),
bridge::GetXlaTensor(self)->GetIrValue(), node_negative_slope,
self_is_result);
return torch_xla::bridge::AtenFromXlaTensor(
torch_xla::XLATensor::Create(std::move(node), *common_device));
}

at::Tensor XLANativeFunctions::lerp(const at::Tensor& self,
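Taken together with the yaml change at the bottom of this diff, the hand-written `XLANativeFunctions::leaky_relu` goes away (the forward is now fully codegen'd), while `leaky_relu_backward` stays hand-written, presumably to keep the `self_is_result` sanity check. It now constructs the generated `LeakyReluBackward` node directly, lifting `negative_slope` into the graph via `GetIrValueForScalarFromCodegen` instead of burning it into the node as a `double`.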
12 changes: 6 additions & 6 deletions torch_xla/csrc/elementwise.cpp
@@ -165,8 +165,8 @@ xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input,
return xla::Select(Between(input, min_val, max_val), grad_output, zero);
}

xla::XlaOp BuildLeakyRelu(xla::XlaOp input, double negative_slope_value) {
return BuildLeakyReluBackward(input, input, negative_slope_value);
xla::XlaOp BuildLeakyRelu(xla::XlaOp input, xla::XlaOp negative_slope) {
return BuildLeakyReluBackward(input, input, negative_slope);
}

std::vector<xla::XlaOp> BuildRrelu(xla::XlaOp input, const at::Scalar& lower,
@@ -188,7 +188,9 @@ std::vector<xla::XlaOp> BuildRrelu(xla::XlaOp input, const at::Scalar& lower,
noise = xla::Select(xla::Gt(input, zero), one, slope);
output = input * noise;
} else {
double negative_slope = (lower.to<double>() + upper.to<double>()) / 2;
xla::XlaOp negative_slope =
XlaHelpers::ScalarValue((lower.to<double>() + upper.to<double>()) / 2,
shape.element_type(), input.builder());
noise = xla::Broadcast(zero, shape.dimensions());
output = BuildLeakyRelu(input, negative_slope);
}
@@ -214,11 +216,9 @@ xla::XlaOp BuildRreluBackward(xla::XlaOp grad_output, xla::XlaOp input,
}

xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
double negative_slope_value) {
xla::XlaOp negative_slope) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp zero = xla::Zero(input.builder(), input_shape.element_type());
xla::XlaOp negative_slope = XlaHelpers::ScalarValue(
negative_slope_value, input_shape.element_type(), input.builder());
return xla::Select(xla::Gt(input, zero), grad_output,
negative_slope * grad_output);
}
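For readability, this is roughly how the two builders in `elementwise.cpp` read once the hunks above are applied (a sketch assembled from the diff, not new code): the only functional change is that `negative_slope` arrives as an `xla::XlaOp` operand instead of being materialized from a `double` inside the backward builder.

```cpp
xla::XlaOp BuildLeakyRelu(xla::XlaOp input, xla::XlaOp negative_slope) {
  // Forward pass reuses the backward formula with grad_output == input.
  return BuildLeakyReluBackward(input, input, negative_slope);
}

xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
                                  xla::XlaOp negative_slope) {
  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp zero = xla::Zero(input.builder(), input_shape.element_type());
  // Positive side passes the gradient through; negative side scales it by the
  // slope, which may now itself be a graph input.
  return xla::Select(xla::Gt(input, zero), grad_output,
                     negative_slope * grad_output);
}
```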
4 changes: 2 additions & 2 deletions torch_xla/csrc/elementwise.h
@@ -43,10 +43,10 @@ xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input,

// Computes the leaky rectified linear unit:
// LeakyReLU(x) = max(0, input) + negative_slope ∗ min(0, input).
xla::XlaOp BuildLeakyRelu(xla::XlaOp input, double negative_slope);
xla::XlaOp BuildLeakyRelu(xla::XlaOp input, xla::XlaOp negative_slope);

xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
double negative_slope_value);
xla::XlaOp negative_slope);

// Computes the sigmoid function using Tanh
// Sigmoid(x) = (tanh(x ∗ 0.5) + 1) ∗ 0.5
30 changes: 0 additions & 30 deletions torch_xla/csrc/ops/leaky_relu.cpp

This file was deleted.

25 changes: 0 additions & 25 deletions torch_xla/csrc/ops/leaky_relu.h

This file was deleted.

36 changes: 0 additions & 36 deletions torch_xla/csrc/ops/leaky_relu_backward.cpp

This file was deleted.

26 changes: 0 additions & 26 deletions torch_xla/csrc/ops/leaky_relu_backward.h

This file was deleted.

15 changes: 15 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -404,6 +404,21 @@ torch_xla::XlaOpVector Isnan::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::IsNan(xla_input), loctx);
}

torch_xla::XlaOpVector LeakyRelu::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp negative_slope = loctx->GetOutputOp(operand(1));
return ReturnOp(BuildLeakyRelu(xla_input, negative_slope), loctx);
}

torch_xla::XlaOpVector LeakyReluBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_input = loctx->GetOutputOp(operand(1));
xla::XlaOp negative_slope = loctx->GetOutputOp(operand(2));
return ReturnOp(
BuildLeakyReluBackward(xla_grad_output, xla_input, negative_slope),
loctx);
}

torch_xla::XlaOpVector Logdet::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(xla::LogDet(xla_input), loctx);
24 changes: 24 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -475,6 +475,30 @@ xla::Shape IsnanOutputShape(const torch::lazy::Value& input) {
return isnan_shape;
}

xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& negative_slope) {
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands";
return BuildLeakyRelu(operands[0], operands[1]);
};
return InferOutputShape({GetXlaShape(input), GetXlaShape(negative_slope)},
lower_for_shape_fn);
}

xla::Shape LeakyReluBackwardOutputShape(
const torch::lazy::Value& grad_output, const torch::lazy::Value& input,
const torch::lazy::Value& negative_slope, bool self_is_result) {
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 3) << "Unexpected number of operands";
return BuildLeakyReluBackward(operands[0], operands[1], operands[2]);
};
return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input),
GetXlaShape(negative_slope)},
lower_for_shape_fn);
}

xla::Shape LeScalarOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other) {
auto lower_for_shape_fn =
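Both shape helpers follow the usual full-codegen pattern in this file: `InferOutputShape` runs the supplied lambda on placeholder operands inside a throwaway builder and returns only the inferred `xla::Shape`, so the forward and backward output shapes automatically track whatever `BuildLeakyRelu` and `BuildLeakyReluBackward` produce.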
7 changes: 7 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -161,6 +161,13 @@ xla::Shape InverseOutputShape(const torch::lazy::Value& input);

xla::Shape IsnanOutputShape(const torch::lazy::Value& input);

xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& negative_slope);

xla::Shape LeakyReluBackwardOutputShape(
const torch::lazy::Value& grad_output, const torch::lazy::Value& input,
const torch::lazy::Value& negative_slope, bool self_is_result);

xla::Shape LeScalarOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other);

14 changes: 0 additions & 14 deletions torch_xla/csrc/tensor_methods.cpp
@@ -63,8 +63,6 @@
#include "torch_xla/csrc/ops/index_select.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/kth_value.h"
#include "torch_xla/csrc/ops/leaky_relu.h"
#include "torch_xla/csrc/ops/leaky_relu_backward.h"
#include "torch_xla/csrc/ops/linear_interpolation.h"
#include "torch_xla/csrc/ops/linspace.h"
#include "torch_xla/csrc/ops/log_softmax.h"
@@ -1407,18 +1405,6 @@ XLATensorPtr hardtanh_backward(const XLATensorPtr& grad_output,
grad_output->GetIrValue(), input->GetIrValue(), min_val, max_val));
}

XLATensorPtr leaky_relu(const XLATensorPtr& input, double negative_slope) {
return input->CreateFrom(
torch::lazy::MakeNode<LeakyRelu>(input->GetIrValue(), negative_slope));
}

XLATensorPtr leaky_relu_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
double negative_slope) {
return grad_output->CreateFrom(torch::lazy::MakeNode<LeakyReluBackward>(
grad_output->GetIrValue(), input->GetIrValue(), negative_slope));
}

XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
const XLATensorPtr& weight) {
return input->CreateFrom(
5 changes: 0 additions & 5 deletions torch_xla/csrc/tensor_methods.h
@@ -444,11 +444,6 @@ XLATensorPtr hardtanh_backward(const XLATensorPtr& grad_output,
const at::Scalar& min_val,
const at::Scalar& max_val);

XLATensorPtr leaky_relu(const XLATensorPtr& input, double negative_slope);
XLATensorPtr leaky_relu_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
double negative_slope);

XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
const XLATensorPtr& weight);
XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
3 changes: 2 additions & 1 deletion xla_native_functions.yaml
@@ -48,6 +48,7 @@ full_codegen:
- hardswish_backward
- inverse
- isnan
- leaky_relu
- le.Scalar
- le.Tensor
- logdet
@@ -92,6 +93,7 @@ ir_gen:
- bitwise_and.Tensor
- bitwise_or.Tensor
- bitwise_xor.Tensor
- leaky_relu_backward
supported:
- __ilshift__.Scalar
- __ilshift__.Tensor
@@ -192,7 +194,6 @@ supported:
- index_select
- kl_div
- kthvalue
- leaky_relu
- leaky_relu_backward
- lerp.Scalar
- lerp.Tensor
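The yaml change summarizes the whole migration: `leaky_relu` moves from `supported` to `full_codegen`, so both its ATen binding and its IR node are now generated, while `leaky_relu_backward` is added to `ir_gen`, meaning only the `LeakyReluBackward` node class is generated and the hand-written entry in `aten_xla_type.cpp` (still listed under `supported`) constructs that node itself, which is what lets it keep the `self_is_result` check.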