From ac3973e96086a42232cd726626b8165cd13b6ee9 Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Tue, 24 Jan 2023 14:04:46 +0200
Subject: [PATCH] Squashed commit of the following:

commit 61c7b32e062aa7bb9c06980b227610009657a6eb
Author: Ivan Yashchuk
Date:   Tue Jan 24 13:42:13 2023 +0200

    clang-format

commit 24540af59c3a10cd375977f1ee9e68fd7633b250
Author: Ivan Yashchuk
Date:   Tue Jan 24 13:40:24 2023 +0200

    Remove symeig, add linalg_eigh to the header

commit 7247ed8c601cf416c1177267c3d28caa0d0699db
Merge: b83ae644 70cc6295
Author: Ivan Yashchuk
Date:   Tue Jan 24 13:35:03 2023 +0200

    Merge remote-tracking branch 'upstream/master' into remove-deprecated-symeig

commit b83ae64473718616a7bc4103ef28a47f28084ffe
Merge: e4255897 281f3696
Author: Ivan Yashchuk
Date:   Wed Nov 23 11:01:53 2022 +0200

    Merge remote-tracking branch 'upstream/master' into remove-deprecated-symeig

commit e4255897ceac2233988776a7ff5b7cfd9fa92f56
Merge: 7eb9a56b a41a04bd
Author: Ivan Yashchuk
Date:   Sat Jan 29 16:45:49 2022 +0000

    Merge remote-tracking branch 'upstream/master' into remove-deprecated-symeig

commit 7eb9a56baedcb5ab392fd76207b9482e49fb3db1
Author: Ivan Yashchuk
Date:   Sat Jan 29 11:11:38 2022 +0000

    Revert changes to test/test_ops.py

commit 3ca7b743238e1e3dfc0ec539068dd7ad53d80c11
Author: Ivan Yashchuk
Date:   Thu Jan 27 11:15:13 2022 +0000

    clang-format-7

commit be7fd91174984f9da71ffe845fa0c95ac5a98d49
Author: Ivan Yashchuk
Date:   Thu Jan 27 10:30:26 2022 +0000

    Skip eigh OpInfo test

commit adae41f5401641ad7e3ae3c0991f31c5005d7630
Author: Ivan Yashchuk
Date:   Thu Jan 27 10:26:58 2022 +0000

    clang-format-7

commit a820fa7f41a1333b70eabfc7d645a76758d440be
Merge: b826a4b8 138a70f4
Author: Ivan Yashchuk
Date:   Thu Jan 27 10:17:17 2022 +0000

    Merge remote-tracking branch 'upstream/master' into remove-deprecated-symeig

commit b826a4b85af04a00888d1c3c18abb74d4ba984e8
Author: Ivan Yashchuk
Date:   Thu Jan 27 10:17:12 2022 +0000

    Revert "Try branch with linalg_qr symbol"

    This reverts commit d01e622afba7ffe219be972280528b1d5cf0e30b.
commit d01e622afba7ffe219be972280528b1d5cf0e30b
Author: Ivan Yashchuk
Date:   Thu Jan 13 01:41:21 2022 -0600

    Try branch with linalg_qr symbol

commit 83b1fa2297f8b2eeb61d51571d8bed3f536bb1c7
Author: Ivan Yashchuk
Date:   Fri Jan 14 03:46:35 2022 -0600

    clang-format-7

commit 518d47a0116fff114c516f3091e3ce2dc0abfe84
Merge: 37f3e0e3 7ff6a2c7
Author: Ivan Yashchuk
Date:   Fri Jan 14 03:45:16 2022 -0600

    Merge remote-tracking branch 'upstream/master' into remove-deprecated-symeig

commit 37f3e0e3fed565c917be8d478ed4340fad4e2d7f
Merge: d6c20a2b ddbc3330
Author: Ivan Yashchuk
Date:   Fri Jan 7 08:24:31 2022 +0000

    Merge remote-tracking branch 'upstream/master' into remove-deprecated-symeig

commit d6c20a2b11353adcc22df1252f09118203cc4be5
Author: Ivan Yashchuk
Date:   Mon Jan 3 14:34:57 2022 +0000

    linalg::eigh -> linalg_eigh

commit ac6db4d1bf57fe7fdc51def55b56abd8d880bb89
Author: Ivan Yashchuk
Date:   Mon Jan 3 14:23:13 2022 +0000

    Removed torch.symeig and added torch.linalg.eigh
---
 test/cpp/test_aten_xla_tensor.cpp | 36 +++++++++++++------------------
 torch_xla/csrc/aten_xla_type.cpp  |  7 +++---
 torch_xla/csrc/ops/symeig.cpp     |  2 +-
 torch_xla/csrc/tensor_methods.cpp | 21 ++++++++++--------
 torch_xla/csrc/tensor_methods.h   |  6 +++---
 xla_native_functions.yaml         |  2 +-
 6 files changed, 35 insertions(+), 39 deletions(-)

diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 3c78805618eb..3064f2c95416 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -934,29 +934,23 @@ TEST_F(AtenXlaTensorTest, TestQR) {
   }
 }
 
-TEST_F(AtenXlaTensorTest, TestSymEig) {
+TEST_F(AtenXlaTensorTest, TestLinalgEigh) {
   static const int dims[] = {4, 7};
   for (auto m : dims) {
-    for (bool eigenvectors : {true, false}) {
-      for (bool upper : {true, false}) {
-        torch::Tensor a =
-            torch::rand({m, m}, torch::TensorOptions(torch::kFloat));
-        torch::Tensor sym_a = a.mm(a.t());
-        auto b = torch::symeig(sym_a, eigenvectors, upper);
-        ForEachDevice([&](const torch::Device& device) {
-          torch::Tensor xla_a = CopyToDevice(sym_a, device);
-          auto xla_b = torch::symeig(xla_a, eigenvectors, upper);
-          AllClose(std::get<0>(b), std::get<0>(xla_b), /*rtol=*/3e-2,
-                   /*atol=*/1e-2);
-          if (eigenvectors) {
-            AllClose(std::get<1>(b).abs(), std::get<1>(xla_b).abs(),
-                     /*rtol=*/3e-2,
-                     /*atol=*/1e-2);
-          } else {
-            EXPECT_EQ(std::get<1>(b).sizes(), std::get<1>(xla_b).sizes());
-          }
-        });
-      }
+    for (std::string uplo : {"U", "L"}) {
+      torch::Tensor a =
+          torch::rand({m, m}, torch::TensorOptions(torch::kFloat));
+      torch::Tensor sym_a = a.mm(a.t());
+      auto b = torch::linalg_eigh(sym_a, uplo);
+      ForEachDevice([&](const torch::Device& device) {
+        torch::Tensor xla_a = CopyToDevice(sym_a, device);
+        auto xla_b = torch::linalg_eigh(xla_a, uplo);
+        AllClose(std::get<0>(b), std::get<0>(xla_b), /*rtol=*/3e-2,
+                 /*atol=*/1e-2);
+        AllClose(std::get<1>(b).abs(), std::get<1>(xla_b).abs(),
+                 /*rtol=*/3e-2,
+                 /*atol=*/1e-2);
+      });
     }
   }
 }
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 9580fba5f606..08383af6243f 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2777,11 +2777,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::svd(
                          bridge::AtenFromXlaTensor(std::get<2>(results)));
 }
 
-std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::symeig(
-    const at::Tensor& self, bool eigenvectors, bool upper) {
+std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::linalg_eigh(
+    const at::Tensor& self, c10::string_view uplo) {
   TORCH_LAZY_FN_COUNTER("xla::");
-  auto results =
-      tensor_methods::symeig(bridge::GetXlaTensor(self), eigenvectors, upper);
+  auto results = tensor_methods::linalg_eigh(bridge::GetXlaTensor(self), uplo);
   return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(results)),
                          bridge::AtenFromXlaTensor(std::get<1>(results)));
 }
diff --git a/torch_xla/csrc/ops/symeig.cpp b/torch_xla/csrc/ops/symeig.cpp
index c6deba61b160..b33fff4fbca4 100644
--- a/torch_xla/csrc/ops/symeig.cpp
+++ b/torch_xla/csrc/ops/symeig.cpp
@@ -45,7 +45,7 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input, bool eigenvectors,
 }  // namespace
 
 SymEig::SymEig(const torch::lazy::Value& input, bool eigenvectors, bool lower)
-    : XlaNode(torch::lazy::OpKind(at::aten::symeig), {input},
+    : XlaNode(torch::lazy::OpKind(at::aten::linalg_eigh), {input},
               [&]() { return NodeOutputShape(input, eigenvectors, lower); },
               /*num_outputs=*/2, torch::lazy::MHash(eigenvectors, lower)),
       eigenvectors_(eigenvectors),
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 2fa581e8c018..a7af9038ed6e 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -3,6 +3,7 @@
 #include 
 #include 
+#include <cctype>
 #include 
 
 #include "absl/strings/str_cat.h"
@@ -2387,6 +2388,17 @@ XLATensorPtr sub(const XLATensorPtr& input, const XLATensorPtr& other,
                  logical_element_type);
 }
 
+std::tuple<XLATensorPtr, XLATensorPtr> linalg_eigh(
+    const XLATensorPtr& input, c10::string_view uplo_str) {
+  // SymEig takes lower instead of upper, hence the conversion from UPLO.
+  char uplo = std::toupper(uplo_str[0]);
+  bool lower = (uplo == 'L');
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<SymEig>(
+      input->GetIrValue(), /*eigenvectors=*/true, lower);
+  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
+                         input->CreateFrom(torch::lazy::Value(node, 1)));
+}
+
 XLATensorPtr sub(const XLATensorPtr& input, const at::Scalar& other,
                  const at::Scalar& alpha,
                  c10::optional<at::ScalarType> logical_element_type) {
@@ -2427,15 +2439,6 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> svd(
                          input->CreateFrom(torch::lazy::Value(node, 2)));
 }
 
-std::tuple<XLATensorPtr, XLATensorPtr> symeig(const XLATensorPtr& input,
-                                              bool eigenvectors, bool upper) {
-  // SymEig takes lower instead of upper, hence the negation.
-  torch::lazy::NodePtr node =
-      torch::lazy::MakeNode<SymEig>(input->GetIrValue(), eigenvectors, !upper);
-  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
-                         input->CreateFrom(torch::lazy::Value(node, 1)));
-}
-
 XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
                            const XLATensorPtr& output) {
   return mul(grad_output, rsub(pow(output, 2), 1, 1));
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index a4a073a6b645..1130cee29dec 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -449,6 +449,9 @@ XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
 XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
                   const at::Scalar& weight);
 
+std::tuple<XLATensorPtr, XLATensorPtr> linalg_eigh(const XLATensorPtr& input,
+                                                   c10::string_view uplo);
+
 XLATensorPtr linspace(const at::Scalar& start, const at::Scalar& end,
                       const int64_t steps, at::ScalarType element_type,
                       const torch::lazy::BackendDevice& device);
@@ -805,9 +808,6 @@ XLATensorPtr sum(const XLATensorPtr& input, std::vector<int64_t> dimensions,
 std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> svd(
     const XLATensorPtr& input, bool some, bool compute_uv);
 
-std::tuple<XLATensorPtr, XLATensorPtr> symeig(const XLATensorPtr& input,
-                                              bool eigenvectors, bool upper);
-
 XLATensorPtr take(const XLATensorPtr& input, const XLATensorPtr& index);
 
 XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index a503a0576030..0236dcf46a1c 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -197,6 +197,7 @@ supported:
   - leaky_relu_backward
   - lerp.Scalar
   - lerp.Tensor
+  - linalg_eigh
   - linspace
   - log
   - log1p
@@ -307,7 +308,6 @@ supported:
   - sum
   - sum.dim_IntList
   - svd
-  - symeig
   - t
   - t_
   - tanh_backward
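
Usage note (not part of the patch): torch.symeig was deprecated in favor of
torch.linalg.eigh, which always computes eigenvectors and selects the triangle
of the input to read via an UPLO string instead of a boolean `upper` flag.
The standalone sketch below shows the user-facing migration with LibTorch; it
assumes a LibTorch build that ships torch::linalg_eigh (the same API the test
above calls), and the file name and build command are illustrative only.

// eigh_example.cpp -- minimal migration sketch, not part of the patch.
// Illustrative build: g++ -std=c++17 eigh_example.cpp -ltorch -ltorch_cpu -lc10
#include <torch/torch.h>

#include <iostream>

int main() {
  // Build a symmetric input, as TestLinalgEigh does with a.mm(a.t()).
  torch::Tensor a = torch::rand({4, 4});
  torch::Tensor sym_a = a.mm(a.t());

  // Removed API:
  //   auto old = torch::symeig(sym_a, /*eigenvectors=*/true, /*upper=*/true);
  // Replacement: "U"/"L" picks which triangle of the input is read;
  // eigenvectors are always computed.
  auto result = torch::linalg_eigh(sym_a, /*uplo=*/"U");
  torch::Tensor eigenvalues = std::get<0>(result);
  torch::Tensor eigenvectors = std::get<1>(result);

  std::cout << eigenvalues << "\n" << eigenvectors << std::endl;
  return 0;
}

Eigenvectors of a symmetric matrix are unique only up to sign, which is why
the updated TestLinalgEigh compares the absolute values of the eigenvector
matrices rather than the matrices themselves.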