36 changes: 15 additions & 21 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -934,29 +934,23 @@ TEST_F(AtenXlaTensorTest, TestQR) {
   }
 }

-TEST_F(AtenXlaTensorTest, TestSymEig) {
+TEST_F(AtenXlaTensorTest, TestLinalgEigh) {
   static const int dims[] = {4, 7};
   for (auto m : dims) {
-    for (bool eigenvectors : {true, false}) {
-      for (bool upper : {true, false}) {
-        torch::Tensor a =
-            torch::rand({m, m}, torch::TensorOptions(torch::kFloat));
-        torch::Tensor sym_a = a.mm(a.t());
-        auto b = torch::symeig(sym_a, eigenvectors, upper);
-        ForEachDevice([&](const torch::Device& device) {
-          torch::Tensor xla_a = CopyToDevice(sym_a, device);
-          auto xla_b = torch::symeig(xla_a, eigenvectors, upper);
-          AllClose(std::get<0>(b), std::get<0>(xla_b), /*rtol=*/3e-2,
-                   /*atol=*/1e-2);
-          if (eigenvectors) {
-            AllClose(std::get<1>(b).abs(), std::get<1>(xla_b).abs(),
-                     /*rtol=*/3e-2,
-                     /*atol=*/1e-2);
-          } else {
-            EXPECT_EQ(std::get<1>(b).sizes(), std::get<1>(xla_b).sizes());
-          }
-        });
-      }
+    for (std::string uplo : {"U", "L"}) {
+      torch::Tensor a =
+          torch::rand({m, m}, torch::TensorOptions(torch::kFloat));
+      torch::Tensor sym_a = a.mm(a.t());
+      auto b = torch::linalg_eigh(sym_a, uplo);
+      ForEachDevice([&](const torch::Device& device) {
+        torch::Tensor xla_a = CopyToDevice(sym_a, device);
+        auto xla_b = torch::linalg_eigh(xla_a, uplo);
+        AllClose(std::get<0>(b), std::get<0>(xla_b), /*rtol=*/3e-2,
+                 /*atol=*/1e-2);
+        AllClose(std::get<1>(b).abs(), std::get<1>(xla_b).abs(),
+                 /*rtol=*/3e-2,
+                 /*atol=*/1e-2);
+      });
     }
   }
 }
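Background on the rewritten test: `torch::symeig` returned `(eigenvalues, eigenvectors)` controlled by separate `eigenvectors` and `upper` flags, whereas `torch::linalg_eigh` always computes eigenvectors and selects the triangle to read with a one-character `uplo` string. The eigenvector comparison goes through `.abs()` because each eigenvector is only defined up to sign, which can differ between the CPU and XLA backends. A minimal standalone sketch of the new call (tensor sizes and the reconstruction check are illustrative, not taken from the test suite):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  torch::Tensor a = torch::rand({4, 4});
  torch::Tensor sym_a = a.mm(a.t());  // symmetrize so eigh is well-defined

  // Replaces torch::symeig(sym_a, /*eigenvectors=*/true, /*upper=*/false).
  // uplo selects the triangle to read: "L" (lower) or "U" (upper).
  auto [eigenvalues, eigenvectors] = torch::linalg_eigh(sym_a, /*uplo=*/"L");

  // Decomposition property: sym_a ~= V * diag(w) * V^T.
  torch::Tensor reconstructed =
      eigenvectors.mm(torch::diag(eigenvalues)).mm(eigenvectors.t());
  std::cout << (reconstructed - sym_a).abs().max() << std::endl;
}
```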
7 changes: 3 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2777,11 +2777,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::svd(
                          bridge::AtenFromXlaTensor(std::get<2>(results)));
 }

-std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::symeig(
-    const at::Tensor& self, bool eigenvectors, bool upper) {
+std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::linalg_eigh(
+    const at::Tensor& self, c10::string_view uplo) {
   TORCH_LAZY_FN_COUNTER("xla::");
-  auto results =
-      tensor_methods::symeig(bridge::GetXlaTensor(self), eigenvectors, upper);
+  auto results = tensor_methods::linalg_eigh(bridge::GetXlaTensor(self), uplo);
   return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
                          bridge::AtenFromXlaTensor(std::get<1>(results)));
 }
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/symeig.cpp
@@ -45,7 +45,7 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input, bool eigenvectors,
 }  // namespace

 SymEig::SymEig(const torch::lazy::Value& input, bool eigenvectors, bool lower)
-    : XlaNode(torch::lazy::OpKind(at::aten::symeig), {input},
+    : XlaNode(torch::lazy::OpKind(at::aten::linalg_eigh), {input},
               [&]() { return NodeOutputShape(input, eigenvectors, lower); },
               /*num_outputs=*/2, torch::lazy::MHash(eigenvectors, lower)),
       eigenvectors_(eigenvectors),
21 changes: 12 additions & 9 deletions torch_xla/csrc/tensor_methods.cpp
@@ -3,6 +3,7 @@
 #include <ATen/core/Reduction.h>

 #include <algorithm>
+#include <cctype>
 #include <functional>

 #include "absl/strings/str_cat.h"
@@ -2387,6 +2388,17 @@ XLATensorPtr sub(const XLATensorPtr& input, const XLATensorPtr& other,
                    logical_element_type);
 }

+std::tuple<XLATensorPtr, XLATensorPtr> linalg_eigh(const XLATensorPtr& input,
+                                                   c10::string_view uplo_str) {
+  // The SymEig IR node takes a `lower` flag rather than a "U"/"L" uplo
+  // string, so normalize the string and derive `lower` from it.
+  char uplo = std::toupper(uplo_str[0]);
+  bool lower = (uplo == 'L');
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<SymEig>(
+      input->GetIrValue(), /*eigenvectors=*/true, lower);
+  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
+                         input->CreateFrom(torch::lazy::Value(node, 1)));
+}
+
 XLATensorPtr sub(const XLATensorPtr& input, const at::Scalar& other,
                  const at::Scalar& alpha,
                  c10::optional<at::ScalarType> logical_element_type) {
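The uplo handling added above is small enough to verify in isolation. A standalone sketch of the same normalization (the function name is ad hoc, not from the codebase; it assumes ATen has already validated that `uplo` is a one-character "L"/"U" string in either case):

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Mirrors the conversion in linalg_eigh above: the ATen-level op takes a
// one-character uplo string, while the SymEig IR node wants a boolean
// `lower` flag.
bool UploToLower(const std::string& uplo) {
  char c = std::toupper(static_cast<unsigned char>(uplo[0]));
  return c == 'L';
}

int main() {
  assert(UploToLower("L") && UploToLower("l"));
  assert(!UploToLower("U") && !UploToLower("u"));
  return 0;
}
```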
@@ -2427,15 +2439,6 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> svd(
                          input->CreateFrom(torch::lazy::Value(node, 2)));
 }

-std::tuple<XLATensorPtr, XLATensorPtr> symeig(const XLATensorPtr& input,
-                                              bool eigenvectors, bool upper) {
-  // SymEig takes lower instead of upper, hence the negation.
-  torch::lazy::NodePtr node =
-      torch::lazy::MakeNode<SymEig>(input->GetIrValue(), eigenvectors, !upper);
-  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
-                         input->CreateFrom(torch::lazy::Value(node, 1)));
-}
-
 XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
                            const XLATensorPtr& output) {
   return mul(grad_output, rsub(pow(output, 2), 1, 1));
6 changes: 3 additions & 3 deletions torch_xla/csrc/tensor_methods.h
@@ -449,6 +449,9 @@ XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
 XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
                   const at::Scalar& weight);

+std::tuple<XLATensorPtr, XLATensorPtr> linalg_eigh(const XLATensorPtr& input,
+                                                   c10::string_view uplo);
+
 XLATensorPtr linspace(const at::Scalar& start, const at::Scalar& end,
                       const int64_t steps, at::ScalarType element_type,
                       const torch::lazy::BackendDevice& device);
@@ -805,9 +808,6 @@ XLATensorPtr sum(const XLATensorPtr& input, std::vector<int64_t> dimensions,
 std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> svd(
     const XLATensorPtr& input, bool some, bool compute_uv);

-std::tuple<XLATensorPtr, XLATensorPtr> symeig(const XLATensorPtr& input,
-                                              bool eigenvectors, bool upper);
-
 XLATensorPtr take(const XLATensorPtr& input, const XLATensorPtr& index);

 XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
2 changes: 1 addition & 1 deletion xla_native_functions.yaml
@@ -197,6 +197,7 @@ supported:
   - leaky_relu_backward
   - lerp.Scalar
   - lerp.Tensor
+  - linalg_eigh
   - linspace
   - log
   - log1p
@@ -307,7 +308,6 @@ supported:
   - sum
   - sum.dim_IntList
   - svd
-  - symeig
   - t
   - t_
   - tanh_backward
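For context on the yaml change: listing an operator under `supported:` tells the XLA codegen that torch_xla supplies a hand-written override for it, and the override in aten_xla_type.cpp must match the C++ signature derived from the upstream schema. A sketch of the expected shape (an assumption about the generated header, not a copy of it):

```cpp
// Derived from the upstream schema
//   linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
// The hand-written XLANativeFunctions::linalg_eigh above matches this.
static std::tuple<at::Tensor, at::Tensor> linalg_eigh(
    const at::Tensor& self, c10::string_view uplo);
```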