diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 3c78805618eb..3064f2c95416 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -934,29 +934,23 @@ TEST_F(AtenXlaTensorTest, TestQR) {
   }
 }
 
-TEST_F(AtenXlaTensorTest, TestSymEig) {
+TEST_F(AtenXlaTensorTest, TestLinalgEigh) {
   static const int dims[] = {4, 7};
   for (auto m : dims) {
-    for (bool eigenvectors : {true, false}) {
-      for (bool upper : {true, false}) {
-        torch::Tensor a =
-            torch::rand({m, m}, torch::TensorOptions(torch::kFloat));
-        torch::Tensor sym_a = a.mm(a.t());
-        auto b = torch::symeig(sym_a, eigenvectors, upper);
-        ForEachDevice([&](const torch::Device& device) {
-          torch::Tensor xla_a = CopyToDevice(sym_a, device);
-          auto xla_b = torch::symeig(xla_a, eigenvectors, upper);
-          AllClose(std::get<0>(b), std::get<0>(xla_b), /*rtol=*/3e-2,
-                   /*atol=*/1e-2);
-          if (eigenvectors) {
-            AllClose(std::get<1>(b).abs(), std::get<1>(xla_b).abs(),
-                     /*rtol=*/3e-2,
-                     /*atol=*/1e-2);
-          } else {
-            EXPECT_EQ(std::get<1>(b).sizes(), std::get<1>(xla_b).sizes());
-          }
-        });
-      }
+    for (std::string uplo : {"U", "L"}) {
+      torch::Tensor a =
+          torch::rand({m, m}, torch::TensorOptions(torch::kFloat));
+      torch::Tensor sym_a = a.mm(a.t());
+      auto b = torch::linalg_eigh(sym_a, uplo);
+      ForEachDevice([&](const torch::Device& device) {
+        torch::Tensor xla_a = CopyToDevice(sym_a, device);
+        auto xla_b = torch::linalg_eigh(xla_a, uplo);
+        AllClose(std::get<0>(b), std::get<0>(xla_b), /*rtol=*/3e-2,
+                 /*atol=*/1e-2);
+        AllClose(std::get<1>(b).abs(), std::get<1>(xla_b).abs(),
+                 /*rtol=*/3e-2,
+                 /*atol=*/1e-2);
+      });
     }
   }
 }
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 9580fba5f606..08383af6243f 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2777,11 +2777,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::svd(
                          bridge::AtenFromXlaTensor(std::get<2>(results)));
 }
 
-std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::symeig(
-    const at::Tensor& self, bool eigenvectors, bool upper) {
+std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::linalg_eigh(
+    const at::Tensor& self, c10::string_view uplo) {
   TORCH_LAZY_FN_COUNTER("xla::");
-  auto results =
-      tensor_methods::symeig(bridge::GetXlaTensor(self), eigenvectors, upper);
+  auto results = tensor_methods::linalg_eigh(bridge::GetXlaTensor(self), uplo);
   return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
                          bridge::AtenFromXlaTensor(std::get<1>(results)));
 }
diff --git a/torch_xla/csrc/ops/symeig.cpp b/torch_xla/csrc/ops/symeig.cpp
index c6deba61b160..b33fff4fbca4 100644
--- a/torch_xla/csrc/ops/symeig.cpp
+++ b/torch_xla/csrc/ops/symeig.cpp
@@ -45,7 +45,7 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input, bool eigenvectors,
 }  // namespace
 
 SymEig::SymEig(const torch::lazy::Value& input, bool eigenvectors, bool lower)
-    : XlaNode(torch::lazy::OpKind(at::aten::symeig), {input},
+    : XlaNode(torch::lazy::OpKind(at::aten::linalg_eigh), {input},
               [&]() { return NodeOutputShape(input, eigenvectors, lower); },
               /*num_outputs=*/2, torch::lazy::MHash(eigenvectors, lower)),
       eigenvectors_(eigenvectors),
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 2fa581e8c018..a7af9038ed6e 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -3,6 +3,7 @@
 #include <ATen/core/Reduction.h>
 
 #include <algorithm>
+#include <cctype>
 #include <functional>
 
 #include "absl/strings/str_cat.h"
@@ -2387,6 +2388,17 @@ XLATensorPtr sub(const XLATensorPtr& input, const XLATensorPtr& other,
                          logical_element_type);
 }
 
+std::tuple<XLATensorPtr, XLATensorPtr> linalg_eigh(const XLATensorPtr& input,
+                                                   c10::string_view uplo_str) {
+  // The SymEig node takes a lower flag, so derive it from the uplo string.
+  char uplo = std::toupper(uplo_str[0]);
+  bool lower = (uplo == 'L');
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<SymEig>(
+      input->GetIrValue(), /*eigenvectors=*/true, lower);
+  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
+                         input->CreateFrom(torch::lazy::Value(node, 1)));
+}
+
 XLATensorPtr sub(const XLATensorPtr& input, const at::Scalar& other,
                  const at::Scalar& alpha,
                  c10::optional<at::ScalarType> logical_element_type) {
@@ -2427,15 +2439,6 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> svd(
                          input->CreateFrom(torch::lazy::Value(node, 2)));
 }
 
-std::tuple<XLATensorPtr, XLATensorPtr> symeig(const XLATensorPtr& input,
-                                              bool eigenvectors, bool upper) {
-  // SymEig takes lower instead of upper, hence the negation.
-  torch::lazy::NodePtr node =
-      torch::lazy::MakeNode<SymEig>(input->GetIrValue(), eigenvectors, !upper);
-  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
-                         input->CreateFrom(torch::lazy::Value(node, 1)));
-}
-
 XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
                            const XLATensorPtr& output) {
   return mul(grad_output, rsub(pow(output, 2), 1, 1));
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index a4a073a6b645..1130cee29dec 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -449,6 +449,9 @@ XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
 XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
                   const at::Scalar& weight);
 
+std::tuple<XLATensorPtr, XLATensorPtr> linalg_eigh(const XLATensorPtr& input,
+                                                   c10::string_view uplo);
+
 XLATensorPtr linspace(const at::Scalar& start, const at::Scalar& end,
                       const int64_t steps, at::ScalarType element_type,
                       const torch::lazy::BackendDevice& device);
@@ -805,9 +808,6 @@ XLATensorPtr sum(const XLATensorPtr& input, std::vector<int64_t> dimensions,
 std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> svd(
     const XLATensorPtr& input, bool some, bool compute_uv);
 
-std::tuple<XLATensorPtr, XLATensorPtr> symeig(const XLATensorPtr& input,
-                                              bool eigenvectors, bool upper);
-
 XLATensorPtr take(const XLATensorPtr& input, const XLATensorPtr& index);
 
 XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index a503a0576030..0236dcf46a1c 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -197,6 +197,7 @@ supported:
   - leaky_relu_backward
   - lerp.Scalar
   - lerp.Tensor
+  - linalg_eigh
   - linspace
   - log
   - log1p
@@ -307,7 +308,6 @@ supported:
   - sum
   - sum.dim_IntList
   - svd
-  - symeig
   - t
   - t_
   - tanh_backward
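
Migration note (editorial, not part of the patch): torch::symeig returned a
(values, vectors) pair gated by an eigenvectors bool and an upper bool, while
torch::linalg_eigh always computes eigenvectors and takes an uplo string ("L"
or "U") selecting which triangle of the input to read. A minimal sketch of the
caller-side migration, mirroring the updated test (tensor names are
illustrative):

    // Symmetrize a random matrix so eigh's assumptions hold.
    torch::Tensor a = torch::rand({4, 4}, torch::TensorOptions(torch::kFloat));
    torch::Tensor sym_a = a.mm(a.t());
    // Old (removed): torch::symeig(sym_a, /*eigenvectors=*/true, /*upper=*/true)
    auto result = torch::linalg_eigh(sym_a, /*uplo=*/"U");
    torch::Tensor eigenvalues = std::get<0>(result);
    torch::Tensor eigenvectors = std::get<1>(result);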