Merged
10 changes: 0 additions & 10 deletions test/cpp/test_tensor.cpp
@@ -119,16 +119,6 @@ TEST_F(TensorTest, TestSize) {
});
}

TEST_F(TensorTest, TestRelu) {
  at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
  at::Tensor output = input.relu();
  ForEachDevice([&](const torch::lazy::BackendDevice& device) {
    XLATensorPtr dev_input = XLATensor::Create(input, device);
    XLATensorPtr dev_output = XLATensor::relu(dev_input);
    AllClose(output, dev_output);
  });
}

Collaborator Author: This codegen change removes the relu function from tensor_methods, so would it be okay to remove these tests here?

Collaborator: It is OK.
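For reference, relu can still be exercised end to end at the ATen level, where the call goes through the dispatcher and the generated lowering. A rough sketch of such a test, assuming the AtenXlaTensorTest fixture and CopyToDevice helper used by the existing test/cpp/test_aten_xla_tensor.cpp tests (a similar test may already exist there):

// Sketch only: exercises relu through the dispatcher rather than the removed
// XLATensor::relu tensor method. Fixture and helper names are assumed from
// the existing ATen-level C++ tests.
TEST_F(AtenXlaTensorTest, TestRelu) {
  torch::Tensor input =
      torch::rand({2, 1, 4, 6}, torch::TensorOptions(torch::kFloat));
  torch::Tensor output = torch::relu(input);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_input = CopyToDevice(input, device);
    torch::Tensor xla_output = torch::relu(xla_input);
    AllClose(output, xla_output);
  });
}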

TEST_F(TensorTest, TestRrelu) {
at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
float lower = 0.125;
12 changes: 0 additions & 12 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2559,18 +2559,6 @@ at::Tensor XLANativeFunctions::reflection_pad2d_backward(
torch::lazy::ToVector<int64_t>(padding)));
}

at::Tensor XLANativeFunctions::relu(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  return bridge::AtenFromXlaTensor(XLATensor::relu(bridge::GetXlaTensor(self)));
}

at::Tensor& XLANativeFunctions::relu_(at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
  XLATensor::relu_(self_tensor);
  return self;
}

at::Tensor XLANativeFunctions::remainder(const at::Tensor& self,
const at::Tensor& other) {
XLA_FN_COUNTER("xla::");
20 changes: 0 additions & 20 deletions torch_xla/csrc/ops/ops.cpp
@@ -127,26 +127,6 @@ torch::lazy::NodePtr SignOp(const torch::lazy::Value& input) {
GetXlaShape(input), std::move(lower_fn));
}

torch::lazy::NodePtr ReluOp(const torch::lazy::Value& input) {
  auto lower_fn = [](const XlaNode& node,
                     LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
    xla::XlaOp xla_output = BuildRelu(xla_input);
    return node.ReturnOp(xla_output, loctx);
  };
  auto lower_for_shape_fn =
      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    XLA_CHECK_EQ(operands.size(), 1) << "Unexpected number of operands";
    return BuildRelu(operands[0]);
  };
  return GenericOp(torch::lazy::OpKind(at::aten::relu), {input},
                   [&]() {
                     return InferOutputShape({GetXlaShape(input)},
                                             lower_for_shape_fn);
                   },
                   std::move(lower_fn));
}

torch::lazy::NodePtr Prelu(const torch::lazy::Value& input,
const torch::lazy::Value& weight) {
auto lower_fn = [](const XlaNode& node,
2 changes: 0 additions & 2 deletions torch_xla/csrc/ops/ops.h
@@ -86,8 +86,6 @@ torch::lazy::NodePtr SgnOp(const torch::lazy::Value& input);

torch::lazy::NodePtr SignOp(const torch::lazy::Value& input);

torch::lazy::NodePtr ReluOp(const torch::lazy::Value& input);

torch::lazy::NodePtr Min(const torch::lazy::Value& input,
const torch::lazy::Value& other);

6 changes: 6 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -211,6 +211,12 @@ torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const {
return ReturnOp(BuildReciprocal(xla_input), loctx);
}

torch_xla::XlaOpVector Relu::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_output = BuildRelu(xla_input);
  return ReturnOp(xla_output, loctx);
}

torch_xla::XlaOpVector Round::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(xla::RoundToEven(xla_input), loctx);
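The ReluOp removed from ops.cpp and the Relu::Lower added here both call the same BuildRelu helper, so the emitted HLO is unchanged; only the plumbing around it moves to generated code. For context, the lowering amounts to a max against zero. A minimal sketch against the public XLA builder API (not the actual helper; include paths depend on the XLA/TensorFlow checkout in use):

// relu(x) = max(x, 0), with the zero broadcast to x's shape and element type.
// Include paths are version-dependent and assumed here.
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp ReluSketch(xla::XlaOp input) {
  return xla::Max(input, xla::ZerosLike(input));
}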
9 changes: 9 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -207,6 +207,15 @@ xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}

xla::Shape ReluOutputShape(const torch::lazy::Value& input) {
  auto lower_for_shape_fn =
      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    XLA_CHECK_EQ(operands.size(), 1) << "Unexpected number of operands";
    return BuildRelu(operands[0]);
  };
  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
}

xla::Shape RoundOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -82,6 +82,8 @@ xla::Shape MinimumOutputShape(const torch::lazy::Value& input,

xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input);

xla::Shape ReluOutputShape(const torch::lazy::Value& input);

xla::Shape RoundOutputShape(const torch::lazy::Value& input);

xla::Shape RsqrtOutputShape(const torch::lazy::Value& input);
3 changes: 0 additions & 3 deletions torch_xla/csrc/tensor.h
@@ -966,9 +966,6 @@ class XLATensor : public c10::intrusive_ptr_target {
const XLATensorPtr& input,
std::vector<int64_t> padding);

static XLATensorPtr relu(const XLATensorPtr& input);
static void relu_(XLATensorPtr& input);

static XLATensorPtr remainder(const XLATensorPtr& input,
const XLATensorPtr& other);
static XLATensorPtr remainder(const XLATensorPtr& input,
8 changes: 0 additions & 8 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2230,14 +2230,6 @@ XLATensorPtr XLATensor::reflection_pad2d_backward(
grad_output->GetIrValue(), input->GetIrValue(), std::move(padding)));
}

XLATensorPtr XLATensor::relu(const XLATensorPtr& input) {
  return input->CreateFrom(ReluOp(input->GetIrValue()));
}

void XLATensor::relu_(XLATensorPtr& input) {
  input->SetInPlaceIrValue(ReluOp(input->GetIrValue()));
}

XLATensorPtr XLATensor::remainder(const XLATensorPtr& input,
const XLATensorPtr& other) {
return input->CreateFrom(Remainder(input->GetIrValue(), other->GetIrValue()));
3 changes: 1 addition & 2 deletions xla_native_functions.yaml
@@ -33,6 +33,7 @@ full_codegen:
- maximum
- minimum
- reciprocal
- relu
- round
- rsqrt
- selu
@@ -253,8 +254,6 @@ supported:
- random_.to
- reflection_pad2d
- reflection_pad2d_backward
- relu
- relu_
- remainder.Scalar
- remainder.Tensor
- repeat
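Moving relu from the supported list to full_codegen means the IR node class and the ATen-facing wrapper are generated at build time (under torch_xla/csrc/generated/), which is why the hand-written ReluOp, XLATensor::relu/relu_, and the aten_xla_type.cpp entries can all be deleted; only Relu::Lower and ReluOutputShape stay hand-written. Very roughly, and purely as an illustration of how the generated node hooks into those hand-written functions (the real generated code differs in its details):

// Illustrative sketch of the generated node class (LazyIr.h), not the actual
// generated code; constructor details and hashing are simplified.
class Relu : public XlaNode {
 public:
  Relu(const torch::lazy::Value& self)
      : XlaNode(torch::lazy::OpKind(at::aten::relu), {self},
                // Shape comes from the hand-written ReluOutputShape.
                [&]() { return ReluOutputShape(self); },
                /*num_outputs=*/1) {}

  // Implemented by the hand-written Relu::Lower in ops_lower_fn.cpp.
  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};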