From 60eed6a71a3256aa7f6deba34f1ebbd184d06a6d Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Fri, 29 Jun 2018 18:13:57 -0400 Subject: [PATCH 01/13] Implement torch.pinv : Pseudo-inverse 1. Used SVD to compute. 2. Tests in test_cuda and test_torch 3. Doc strings in _torch_docs.py and _tensor_docs.py Closes #6187 --- aten/src/ATen/native/LinearAlgebra.cpp | 14 +++++++++ aten/src/ATen/native/native_functions.yaml | 2 ++ docs/source/tensors.rst | 1 + docs/source/torch.rst | 1 + test/test_cuda.py | 4 +++ test/test_torch.py | 19 +++++++++++++ torch/_tensor_docs.py | 7 +++++ torch/_torch_docs.py | 33 ++++++++++++++++++++++ 8 files changed, 81 insertions(+) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index ff7b20bde3c17..3e38ca96214ff 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -96,6 +96,20 @@ std::tuple slogdet(const Tensor& self) { return std::make_tuple(det.sign(), diag_U.abs_().log_().sum()); } +Tensor pinv(const Tensor& self) { + if (!at::isFloatingType(self.type().scalarType()) || + self.dim() != 2) { + std::ostringstream ss; + ss << "pinverse(" << self.type() << "{" << self.sizes() << "}): expected a " + << "2D tensor of floating types"; + throw std::runtime_error(ss.str()); + } + Tensor U, S, V; + std::tie(U, S, V) = self.svd(); + Tensor S_pseudoinv = at::where(S != 0.0, S.reciprocal(), at::zeros({}, self.type())); + return V.mm(S_pseudoinv.diag().mm(U.t())); +} + static void check_1d(const Tensor& t, const char* arg, const char* fn) { if (t.dim() != 1) { AT_ERROR(fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D"); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 452ac21f162cd..48b83393e4cd2 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -903,6 +903,8 @@ - func: pin_memory(Tensor self) -> Tensor +- func: pinv(Tensor self) -> 
Tensor + - func: rand(IntList size, *, TensorOptions options={}) -> Tensor variants: function diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 5ff9f441d6257..98d72ec306b49 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -306,6 +306,7 @@ view of a storage and defines numeric operations on it. .. automethod:: ormqr .. automethod:: permute .. automethod:: pin_memory + .. automethod:: pinv .. automethod:: potrf .. automethod:: potri .. automethod:: potrs diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 3c6e6aa367d89..8a9088ea078ac 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -288,6 +288,7 @@ BLAS and LAPACK Operations .. autofunction:: mv .. autofunction:: orgqr .. autofunction:: ormqr +.. autofunction:: pinv .. autofunction:: potrf .. autofunction:: potri .. autofunction:: potrs diff --git a/test/test_cuda.py b/test/test_cuda.py index fe2560dcf472c..20c4def304ee6 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1367,6 +1367,10 @@ def test_caching_pinned_memory_multi_gpu(self): def _select_broadcastable_dims(dims_full=None): return TestTorch._select_broadcastable_dims(dims_full) + @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") + def test_pinverse(self): + TestTorch._test_pinverse(self, lambda t: t.cuda()) + @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") def test_det_logdet_slogdet(self): TestTorch._test_det_logdet_slogdet(self, lambda t: t.cuda()) diff --git a/test/test_torch.py b/test/test_torch.py index b065e0ab4cc38..042261ca16504 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4038,6 +4038,25 @@ def test_inverse(self): self.assertFalse(MII.is_contiguous(), 'MII is contiguous') self.assertEqual(MII, MI, 0, 'inverse value in-place') + @staticmethod + def _test_pinverse(self, conv_fn): + def run_test(M, conv_fn): + MPI = torch.pinverse(M) + E = conv_fn(torch.eye(5)) + self.assertEqual(E, torch.mm(M, MPI), 1e-7, 'pseudo-inverse value') + 
self.assertEqual(E, torch.mm(MPI, M), 1e-7, 'pseudo-inverse value') + + # Square matrix + M = conv_fn(torch.randn(5, 5)) + run_test(M, conv_fn) + # Rectangular matrix + M = conv_fn(torch.randn(3, 4)) + run_test(M, conv_fn) + + @skipIfNoLapack + def test_pinverse(self): + self._test_pinverse(conv_fn=lambda x: x) + @staticmethod def _test_det_logdet_slogdet(self, conv_fn): def reference_det(M): diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 948b3bf1e91ae..ad613ad2f8670 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2520,3 +2520,10 @@ def callable(a, b) -> number See :func:`torch.slogdet` """) + +add_docstr_all('pinv', + r""" +pinv() -> Tensor + +See :func:`torch.pinv` +""") diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index b0db3911bb2ad..26680965c441e 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -5264,6 +5264,39 @@ def parse_kwargs(desc): (tensor(-1.), tensor(1.5731)) """) +add_docstr(torch.pinv, + r""" +pinv(input) -> Tensor + +Calculates the pseudo-inverse (also known as the Moore-Penrose inverse) of a 2D tensor. +Please look at `Moore-Penrose inverse`_ for more details + +.. note:: + This method is implemented using the Singular Value Decomposition. + +Arguments: + input (Tensor): The input 2D tensor of dimensions :math:`m \times n` + +Returns: + The pseudo-inverse of :attr:`input` of dimensions :math:`n \times m` + +Example:: + + >>> input = torch.randn(3, 5) + >>> input + tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], + [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], + [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) + >>> torch.pinv(input) + tensor([[ 0.0600, -0.1933, -0.2090], + [-0.0903, -0.0817, -0.4752], + [-0.7124, -0.1631, -0.2272], + [ 0.1356, 0.3933, -0.5023], + [-0.0308, -0.1725, -0.5216]]) + +.. 
_Moore-Penrose inverse: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse +""") + add_docstr(torch.fft, r""" fft(input, signal_ndim, normalized=False) -> Tensor From d985492c781bd3023dd11d43da56fd154ed8e2b7 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Fri, 29 Jun 2018 18:52:23 -0400 Subject: [PATCH 02/13] Fix nit in tests --- test/test_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_torch.py b/test/test_torch.py index 042261ca16504..5ac0468d7afb3 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4055,7 +4055,7 @@ def run_test(M, conv_fn): @skipIfNoLapack def test_pinverse(self): - self._test_pinverse(conv_fn=lambda x: x) + self._test_pinverse(self, conv_fn=lambda x: x) @staticmethod def _test_det_logdet_slogdet(self, conv_fn): From d4b601ac3ad83f60a7e6fe768040f268d9b4886c Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Fri, 29 Jun 2018 19:10:50 -0400 Subject: [PATCH 03/13] Fix one more nit in test --- test/test_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_torch.py b/test/test_torch.py index 5ac0468d7afb3..2a7a8678e26c1 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4041,7 +4041,7 @@ def test_inverse(self): @staticmethod def _test_pinverse(self, conv_fn): def run_test(M, conv_fn): - MPI = torch.pinverse(M) + MPI = torch.pinv(M) E = conv_fn(torch.eye(5)) self.assertEqual(E, torch.mm(M, MPI), 1e-7, 'pseudo-inverse value') self.assertEqual(E, torch.mm(MPI, M), 1e-7, 'pseudo-inverse value') From 332b0befcca72153b470b52a1d23c7a7a5f6533a Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Fri, 29 Jun 2018 19:30:14 -0400 Subject: [PATCH 04/13] Fix tests finally, sorry --- test/test_torch.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index 2a7a8678e26c1..8aef2bac8a50b 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4040,18 +4040,26 @@ def test_inverse(self): 
@staticmethod def _test_pinverse(self, conv_fn): - def run_test(M, conv_fn): + def run_test(M): + # Testing against identities for pseudo-inverses MPI = torch.pinv(M) - E = conv_fn(torch.eye(5)) - self.assertEqual(E, torch.mm(M, MPI), 1e-7, 'pseudo-inverse value') - self.assertEqual(E, torch.mm(MPI, M), 1e-7, 'pseudo-inverse value') + self.assertEqual(MPI, MPI.mm(MPI.t()).mm(M.t()), 1e-8, 'pseudo-inverse identity 1') + self.assertEqual(MPI, M.t().mm(MPI.t()).mm(MPI), 1e-8, 'pseudo-inverse identity 2') + self.assertEqual(M, MPI.t().mm(M.t()).mm(M), 1e-8, 'pseudo-inverse identity 3') + self.assertEqual(M, M.mm(M.t()).mm(MPI.t()), 1e-8, 'pseudo-inverse idenity 4') # Square matrix M = conv_fn(torch.randn(5, 5)) - run_test(M, conv_fn) + run_test(M) + # Rectangular matrix M = conv_fn(torch.randn(3, 4)) - run_test(M, conv_fn) + run_test(M) + + # Test inverse and pseudo-inverse for invertible matrix + M = torch.randn(5, 5) + M = conv_fn(M.mm(M.t())) + self.assertEqual(conv_fn(torch.eye(5)), M.pinv().mm(M), 1e-7, 'pseudo-inverse for invertible matrix') @skipIfNoLapack def test_pinverse(self): From 33c75a0a887ff856b721488b75dd8329bc4c08f4 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Fri, 29 Jun 2018 21:29:47 -0400 Subject: [PATCH 05/13] Add actual Moore-Penrose conditions in the tests, rename pinv to pinverse --- aten/src/ATen/native/LinearAlgebra.cpp | 2 +- aten/src/ATen/native/native_functions.yaml | 2 +- docs/source/tensors.rst | 2 +- docs/source/torch.rst | 2 +- test/test_torch.py | 14 +++++++------- torch/_tensor_docs.py | 6 +++--- torch/_torch_docs.py | 6 +++--- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 3e38ca96214ff..51cec7725bd5d 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -96,7 +96,7 @@ std::tuple slogdet(const Tensor& self) { return std::make_tuple(det.sign(), diag_U.abs_().log_().sum()); } -Tensor 
pinv(const Tensor& self) { +Tensor pinverse(const Tensor& self) { if (!at::isFloatingType(self.type().scalarType()) || self.dim() != 2) { std::ostringstream ss; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 48b83393e4cd2..080a7844be7ef 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -903,7 +903,7 @@ - func: pin_memory(Tensor self) -> Tensor -- func: pinv(Tensor self) -> Tensor +- func: pinverse(Tensor self) -> Tensor - func: rand(IntList size, *, TensorOptions options={}) -> Tensor variants: function diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 98d72ec306b49..135c15c8432e1 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -306,7 +306,7 @@ view of a storage and defines numeric operations on it. .. automethod:: ormqr .. automethod:: permute .. automethod:: pin_memory - .. automethod:: pinv + .. automethod:: pinverse .. automethod:: potrf .. automethod:: potri .. automethod:: potrs diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 8a9088ea078ac..e7cf3ef38a12f 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -288,7 +288,7 @@ BLAS and LAPACK Operations .. autofunction:: mv .. autofunction:: orgqr .. autofunction:: ormqr -.. autofunction:: pinv +.. autofunction:: pinverse .. autofunction:: potrf .. autofunction:: potri .. 
autofunction:: potrs diff --git a/test/test_torch.py b/test/test_torch.py index 8aef2bac8a50b..7b58148a21449 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4041,12 +4041,12 @@ def test_inverse(self): @staticmethod def _test_pinverse(self, conv_fn): def run_test(M): - # Testing against identities for pseudo-inverses - MPI = torch.pinv(M) - self.assertEqual(MPI, MPI.mm(MPI.t()).mm(M.t()), 1e-8, 'pseudo-inverse identity 1') - self.assertEqual(MPI, M.t().mm(MPI.t()).mm(MPI), 1e-8, 'pseudo-inverse identity 2') - self.assertEqual(M, MPI.t().mm(M.t()).mm(M), 1e-8, 'pseudo-inverse identity 3') - self.assertEqual(M, M.mm(M.t()).mm(MPI.t()), 1e-8, 'pseudo-inverse idenity 4') + # Testing against definition for pseudo-inverses + MPI = torch.pinverse(M) + self.assertEqual(M, M.mm(MPI).mm(M), 1e-8, 'pseudo-inverse condition 1') + self.assertEqual(MPI, MPI.mm(M).mm(MPI), 1e-8, 'pseudo-inverse condition 2') + self.assertEqual(M.mm(MPI), (M.mm(MPI)).t(), 1e-8, 'pseudo-inverse condition 3') + self.assertEqual(MPI.mm(M), (MPI.mm(M)).t(), 1e-8, 'pseudo-inverse condition 4') # Square matrix M = conv_fn(torch.randn(5, 5)) @@ -4059,7 +4059,7 @@ def run_test(M): # Test inverse and pseudo-inverse for invertible matrix M = torch.randn(5, 5) M = conv_fn(M.mm(M.t())) - self.assertEqual(conv_fn(torch.eye(5)), M.pinv().mm(M), 1e-7, 'pseudo-inverse for invertible matrix') + self.assertEqual(conv_fn(torch.eye(5)), M.pinverse().mm(M), 1e-7, 'pseudo-inverse for invertible matrix') @skipIfNoLapack def test_pinverse(self): diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index ad613ad2f8670..b7a9d64374ebc 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2521,9 +2521,9 @@ def callable(a, b) -> number See :func:`torch.slogdet` """) -add_docstr_all('pinv', +add_docstr_all('pinverse', r""" -pinv() -> Tensor +pinverse() -> Tensor -See :func:`torch.pinv` +See :func:`torch.pinverse` """) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 
26680965c441e..d897667838ce9 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -5264,9 +5264,9 @@ def parse_kwargs(desc): (tensor(-1.), tensor(1.5731)) """) -add_docstr(torch.pinv, +add_docstr(torch.pinverse, r""" -pinv(input) -> Tensor +pinverse(input) -> Tensor Calculates the pseudo-inverse (also known as the Moore-Penrose inverse) of a 2D tensor. Please look at `Moore-Penrose inverse`_ for more details @@ -5287,7 +5287,7 @@ def parse_kwargs(desc): tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) - >>> torch.pinv(input) + >>> torch.pinverse(input) tensor([[ 0.0600, -0.1933, -0.2090], [-0.0903, -0.0817, -0.4752], [-0.7124, -0.1631, -0.2272], From bb047cc4f6d93162f0e7565e747c409a06408b90 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Sat, 30 Jun 2018 08:31:31 -0400 Subject: [PATCH 06/13] Address comments 1. Use AT_CHECK 2. Use at::zeros({}, self.options); --- aten/src/ATen/native/LinearAlgebra.cpp | 56 +++++++++----------------- 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 51cec7725bd5d..74d7b91ec2c54 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -19,11 +19,7 @@ static inline std::tuple _lu_det_P_diag_U_info(const Tensor p.squeeze_(0); lu.squeeze_(0); int int_info = info.squeeze_().toCInt(); - if (int_info < 0) { - std::ostringstream ss; - ss << "LU factorization (getrf) failed with info = " << int_info; - throw std::runtime_error(ss.str()); - } + AT_CHECK(int_info < 0, "LU factorization (getrf) failed with info = ", int_info); auto n = self.size(0); auto num_exchanges = (at::arange(1, n + 1, p.type()) != p).nonzero().size(0); if (num_exchanges % 2 == 1) { @@ -34,13 +30,10 @@ static inline std::tuple _lu_det_P_diag_U_info(const Tensor } Tensor det(const Tensor& self) { - if 
(!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1)) { - std::ostringstream ss; - ss << "det(" << self.type() << "{" << self.sizes() << "}): expected a 2D " - << "square tensor of floating types"; - throw std::runtime_error(ss.str()); - } + AT_CHECK(!at::isFloatingType(self.type().scalarType()) || + self.dim() != 2 || self.size(0) != self.size(1), + "det(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor " + "of floating types"); double det_P; Tensor diag_U; int info; @@ -53,13 +46,10 @@ Tensor det(const Tensor& self) { } Tensor logdet(const Tensor& self) { - if (!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1)) { - std::ostringstream ss; - ss << "logdet(" << self.type() << "{" << self.sizes() << "}): expected a " - << "2D square tensor of floating types"; - throw std::runtime_error(ss.str()); - } + AT_CHECK(!at::isFloatingType(self.type().scalarType()) || + self.dim() != 2 || self.size(0) != self.size(1), + "logdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor " + "of floating types"); double det_P; Tensor diag_U, det; int info; @@ -77,13 +67,10 @@ Tensor logdet(const Tensor& self) { } std::tuple slogdet(const Tensor& self) { - if (!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1)) { - std::ostringstream ss; - ss << "slogdet(" << self.type() << "{" << self.sizes() << "}): expected a " - << "2D square tensor of floating types"; - throw std::runtime_error(ss.str()); - } + AT_CHECK(!at::isFloatingType(self.type().scalarType()) || + self.dim() != 2 || self.size(0) != self.size(1), + "slogdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor " + "of floating types"); double det_P; Tensor diag_U, det; int info; @@ -97,23 +84,18 @@ std::tuple slogdet(const Tensor& self) { } Tensor pinverse(const Tensor& self) { - if (!at::isFloatingType(self.type().scalarType()) || - 
self.dim() != 2) { - std::ostringstream ss; - ss << "pinverse(" << self.type() << "{" << self.sizes() << "}): expected a " - << "2D tensor of floating types"; - throw std::runtime_error(ss.str()); - } + AT_CHECK(!at::isFloatingType(self.type().scalarType()) || + self.dim() != 2 || self.size(0) != self.size(1), + "pinverse(", self.type(), "{", self.sizes(), "}): expected a 2D tensor " + "of floating types"); Tensor U, S, V; std::tie(U, S, V) = self.svd(); - Tensor S_pseudoinv = at::where(S != 0.0, S.reciprocal(), at::zeros({}, self.type())); + Tensor S_pseudoinv = at::where(S != 0.0, S.reciprocal(), at::zeros({}, self.options())); return V.mm(S_pseudoinv.diag().mm(U.t())); } static void check_1d(const Tensor& t, const char* arg, const char* fn) { - if (t.dim() != 1) { - AT_ERROR(fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D"); - } + AT_CHECK(t.dim() != 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D"); } Tensor ger(const Tensor& self, const Tensor& vec2) { From 1bb5614079c6685ce13dae37ec065fbad362b246 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Sat, 30 Jun 2018 08:37:50 -0400 Subject: [PATCH 07/13] Invert bool expressions --- aten/src/ATen/native/LinearAlgebra.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 74d7b91ec2c54..a36170c0ef344 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -19,7 +19,7 @@ static inline std::tuple _lu_det_P_diag_U_info(const Tensor p.squeeze_(0); lu.squeeze_(0); int int_info = info.squeeze_().toCInt(); - AT_CHECK(int_info < 0, "LU factorization (getrf) failed with info = ", int_info); + AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info); auto n = self.size(0); auto num_exchanges = (at::arange(1, n + 1, p.type()) != p).nonzero().size(0); if (num_exchanges % 2 == 1) { @@ -30,8 +30,8 @@ static 
inline std::tuple _lu_det_P_diag_U_info(const Tensor } Tensor det(const Tensor& self) { - AT_CHECK(!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1), + AT_CHECK(at::isFloatingType(self.type().scalarType()) && + self.dim() == 2 && self.size(0) == self.size(1), "det(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor " "of floating types"); double det_P; @@ -46,8 +46,8 @@ Tensor det(const Tensor& self) { } Tensor logdet(const Tensor& self) { - AT_CHECK(!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1), + AT_CHECK(at::isFloatingType(self.type().scalarType()) && + self.dim() == 2 && self.size(0) == self.size(1), "logdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor " "of floating types"); double det_P; @@ -67,8 +67,8 @@ Tensor logdet(const Tensor& self) { } std::tuple slogdet(const Tensor& self) { - AT_CHECK(!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1), + AT_CHECK(at::isFloatingType(self.type().scalarType()) && + self.dim() == 2 && self.size(0) == self.size(1), "slogdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor " "of floating types"); double det_P; @@ -84,8 +84,7 @@ std::tuple slogdet(const Tensor& self) { } Tensor pinverse(const Tensor& self) { - AT_CHECK(!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1), + AT_CHECK(at::isFloatingType(self.type().scalarType()) && self.dim() == 2, "pinverse(", self.type(), "{", self.sizes(), "}): expected a 2D tensor " "of floating types"); Tensor U, S, V; @@ -95,7 +94,7 @@ Tensor pinverse(const Tensor& self) { } static void check_1d(const Tensor& t, const char* arg, const char* fn) { - AT_CHECK(t.dim() != 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D"); + AT_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D"); } Tensor 
ger(const Tensor& self, const Tensor& vec2) { From 494c2a1de8d70bf49f1107fe6cd39dfc83fc1ace Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Sat, 30 Jun 2018 14:35:43 -0400 Subject: [PATCH 08/13] Use NumPy-based cutoff for removing smaller singular values --- aten/src/ATen/native/LinearAlgebra.cpp | 5 +++-- aten/src/ATen/native/native_functions.yaml | 2 +- torch/_torch_docs.py | 4 +++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index a36170c0ef344..ea87d42dfa58f 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -83,13 +83,14 @@ std::tuple slogdet(const Tensor& self) { return std::make_tuple(det.sign(), diag_U.abs_().log_().sum()); } -Tensor pinverse(const Tensor& self) { +Tensor pinverse(const Tensor& self, double rcond) { AT_CHECK(at::isFloatingType(self.type().scalarType()) && self.dim() == 2, "pinverse(", self.type(), "{", self.sizes(), "}): expected a 2D tensor " "of floating types"); Tensor U, S, V; std::tie(U, S, V) = self.svd(); - Tensor S_pseudoinv = at::where(S != 0.0, S.reciprocal(), at::zeros({}, self.options())); + double max_val = S[0].toCDouble(); + Tensor S_pseudoinv = at::where(S > rcond * max_val, S.reciprocal(), at::zeros({}, self.options())); return V.mm(S_pseudoinv.diag().mm(U.t())); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 080a7844be7ef..7428df616ba47 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -903,7 +903,7 @@ - func: pin_memory(Tensor self) -> Tensor -- func: pinverse(Tensor self) -> Tensor +- func: pinverse(Tensor self, double rcond=1e-15) -> Tensor - func: rand(IntList size, *, TensorOptions options={}) -> Tensor variants: function diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index d897667838ce9..748db8df8d7d4 100644 --- a/torch/_torch_docs.py +++ 
b/torch/_torch_docs.py @@ -5266,7 +5266,7 @@ def parse_kwargs(desc): add_docstr(torch.pinverse, r""" -pinverse(input) -> Tensor +pinverse(input, rcond=1e-15) -> Tensor Calculates the pseudo-inverse (also known as the Moore-Penrose inverse) of a 2D tensor. Please look at `Moore-Penrose inverse`_ for more details @@ -5276,6 +5276,8 @@ def parse_kwargs(desc): Arguments: input (Tensor): The input 2D tensor of dimensions :math:`m \times n` + rcond (float): A floating point value to determine the cutoff for small singular values. + Default: 1e-15 Returns: The pseudo-inverse of :attr:`input` of dimensions :math:`n \times m` From ac0e33c52389c84469c2b39a97ce43ef00a4675f Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Sat, 30 Jun 2018 15:53:33 -0400 Subject: [PATCH 09/13] Add tests in test_autograd --- test/test_autograd.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/test_autograd.py b/test/test_autograd.py index b74d7d0576e28..a887486df643b 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2486,6 +2486,11 @@ def random_fullrank_matrix_distinct_singular_value(l): s = torch.arange(1., l + 1).mul_(1.0 / (l + 1)) return u.mm(torch.diag(s)).mm(v.t()) +def random_matrix_large_singular_value(m, n): + U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n + V = torch.randn(n, m).qr()[0] # Orthogonal with dimensions n x m + S = torch.cat([torch.empty(m).uniform_(10., 20.), torch.zeros(n - m)], 0) + return U.mm(torch.diag(S)).mm(V) def uniform_scalar(offset=0, requires_grad=False): v = torch.rand(()) + offset @@ -2948,6 +2953,10 @@ class dont_convert(tuple): ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', [0]), ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', [0]), ('inverse', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), + ('pinverse', lambda: random_matrix_large_singular_value(S, M), NO_ARGS, + 'rectangular', NO_ARGS, [skipIfNoLapack]), + ('pinverse', lambda: 
random_matrix_large_singular_value(M, M), NO_ARGS, + 'square', NO_ARGS, [skipIfNoLapack]), ('det', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), ('det', (1, 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]), ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', NO_ARGS, [skipIfNoLapack]), From 61178015952387919935c80d5c40c61a0ef5d35e Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Sun, 1 Jul 2018 08:48:08 -0400 Subject: [PATCH 10/13] Modify test as per suggestion, add note in the docs --- test/test_autograd.py | 21 ++++++++++++--------- torch/_torch_docs.py | 11 +++++++++++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index a887486df643b..926f697680f0c 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2040,6 +2040,18 @@ def run_test(input_size, exponent): run_test((10, 10), torch.zeros(10, 10)) run_test((10,), 0) + def test_pinverse(self): + m, n = 5, 10 + U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n + V = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n + def func(x): + S = torch.cat([x, torch.zeros(n - m)], 0) + M = U.mm(torch.diag(S)).mm(V.t()) + return M.pinverse() + + gradcheck(func, [torch.rand(m) + 1]) + gradcheck(func, [torch.rand(m) + 10]) + def test_profiler(self): x = torch.randn(10, 10) @@ -2486,11 +2498,6 @@ def random_fullrank_matrix_distinct_singular_value(l): s = torch.arange(1., l + 1).mul_(1.0 / (l + 1)) return u.mm(torch.diag(s)).mm(v.t()) -def random_matrix_large_singular_value(m, n): - U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n - V = torch.randn(n, m).qr()[0] # Orthogonal with dimensions n x m - S = torch.cat([torch.empty(m).uniform_(10., 20.), torch.zeros(n - m)], 0) - return U.mm(torch.diag(S)).mm(V) def uniform_scalar(offset=0, requires_grad=False): v = torch.rand(()) + offset @@ -2953,10 +2960,6 @@ class dont_convert(tuple): ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 
2), 'scalar_input_dim', [0]), ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', [0]), ('inverse', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), - ('pinverse', lambda: random_matrix_large_singular_value(S, M), NO_ARGS, - 'rectangular', NO_ARGS, [skipIfNoLapack]), - ('pinverse', lambda: random_matrix_large_singular_value(M, M), NO_ARGS, - 'square', NO_ARGS, [skipIfNoLapack]), ('det', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), ('det', (1, 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]), ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', NO_ARGS, [skipIfNoLapack]), diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 748db8df8d7d4..6bbb9e44b0a66 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -5274,6 +5274,13 @@ def parse_kwargs(desc): .. note:: This method is implemented using the Singular Value Decomposition. +.. note:: + The pseudo-inverse is not necessarily a continuous function in the elements of the matrix `[1]`_. + Therefore, derivatives are not always existent, and exist for a constant rank only `[2]`_. + However, this method is backprop-able due to the implementation by using SVD results, and + could be unstable. Double-backward will also be unstable due to the usage of SVD internally. + See :meth:`~torch.svd` for more details. + Arguments: input (Tensor): The input 2D tensor of dimensions :math:`m \times n` rcond (float): A floating point value to determine the cutoff for small singular values. @@ -5297,6 +5304,10 @@ def parse_kwargs(desc): [-0.0308, -0.1725, -0.5216]]) .. _Moore-Penrose inverse: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse + +.. _[1]: https://epubs.siam.org/doi/10.1137/0117004 + +.. 
_[2]: https://www.jstor.org/stable/2156365 """) add_docstr(torch.fft, From ce3748d7c2368b1629682029f51c08923651e6ea Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Sun, 1 Jul 2018 08:56:16 -0400 Subject: [PATCH 11/13] Fix lint --- test/test_autograd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_autograd.py b/test/test_autograd.py index 926f697680f0c..f070646151aa7 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2044,6 +2044,7 @@ def test_pinverse(self): m, n = 5, 10 U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n V = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n + def func(x): S = torch.cat([x, torch.zeros(n - m)], 0) M = U.mm(torch.diag(S)).mm(V.t()) From 16720f31e269d9e968c2cc8ef07b2b07b6f575b1 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Mon, 2 Jul 2018 13:19:23 -0400 Subject: [PATCH 12/13] add gradgradcheck --- test/test_autograd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_autograd.py b/test/test_autograd.py index 3f215fa2116ed..a6de62212e45a 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2052,6 +2052,8 @@ def func(x): gradcheck(func, [torch.rand(m) + 1]) gradcheck(func, [torch.rand(m) + 10]) + gradgradcheck(func, [torch.rand(m) + 1]) + gradgradcheck(func, [torch.rand(m) + 10]) def test_profiler(self): x = torch.randn(10, 10) From 6bac86610365da9f0afe3a4d7f12275295a2918a Mon Sep 17 00:00:00 2001 From: Vishwak Srinivasan Date: Tue, 3 Jul 2018 14:46:24 -0400 Subject: [PATCH 13/13] Update test_autograd.py --- test/test_autograd.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_autograd.py b/test/test_autograd.py index a6de62212e45a..6022b75b62534 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2041,6 +2041,14 @@ def run_test(input_size, exponent): run_test((10,), 0) def test_pinverse(self): + # Why is pinverse tested this way, and not ordinarily as other linear algebra methods? + # 1. 
Pseudo-inverses are not generally continuous, which means that they are not differentiable + # 2. Derivatives for pseudo-inverses exist typically for constant rank (Golub et al, 1973) + # 3. This method creates two orthogonal matrices, and constructs a test case with large + # singular values (given by x to the function). + # 4. This will ensure that small perturbations don't affect the rank of the matrix, in which case + # a derivative exists. + # 5. This test exists since pinverse is implemented using SVD, and is hence a backpropable method m, n = 5, 10 U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n V = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n