From 3923d4eab777fa1afeddf6e95a8e839e2e980cac Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 28 Sep 2020 11:05:02 -0500 Subject: [PATCH 01/63] linalg.svd, step 1: rename the old svd into linalg_svd, and reimplement svd as a tiny wrapper around it --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 25 +++++++++++++------ .../ATen/native/cuda/BatchLinearAlgebra.cu | 2 +- aten/src/ATen/native/native_functions.yaml | 24 ++++++++++++------ tools/autograd/derivatives.yaml | 4 +-- 4 files changed, 38 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index f4babb2a14a3..aa9f20aeadb4 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -935,7 +935,7 @@ std::tuple symeig_out(Tensor& vals, Tensor& vecs, const Tensor return std::tuple(vals, vecs); } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, @@ -1004,7 +1004,7 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, #endif } -std::tuple _svd_helper_cpu(const Tensor& self, bool some, bool compute_uv) { +std::tuple _linalg_svd_helper_cpu(const Tensor& self, bool some, bool compute_uv) { std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); @@ -1042,24 +1042,35 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } -std::tuple svd(const Tensor& self, bool some, bool compute_uv) { +std::tuple linalg_svd(const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - return at::_svd_helper(self, some, compute_uv); + return at::_linalg_svd_helper(self, some, compute_uv); } -std::tuple svd_out(Tensor& U, Tensor& S, Tensor& VT, - const Tensor& self, bool some, bool compute_uv) { +std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT, + const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); Tensor U_tmp, S_tmp, VT_tmp; - std::tie(U_tmp, S_tmp, VT_tmp) = at::_svd_helper(self, some, compute_uv); + std::tie(U_tmp, S_tmp, VT_tmp) = at::_linalg_svd_helper(self, some, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); VT.resize_as_(VT_tmp).copy_(VT_tmp); return std::tuple(U, S, VT); } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +std::tuple svd(const Tensor& self, bool some, bool compute_uv) { + return at::native::linalg_svd(self, some, compute_uv); +} + +std::tuple svd_out(Tensor& U, Tensor& S, Tensor& VT, + const Tensor& self, bool some, bool compute_uv) { + return at::native::linalg_svd_out(U, S, VT, self, some, compute_uv); +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index c86f355a67c2..e75df4141ca8 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1371,7 +1371,7 @@ AT_ERROR("svd: MAGMA library not found in " #endif } -std::tuple 
_svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) { +std::tuple _linalg_svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) { std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f5bbb263ed9c..e56b5da7240e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5547,13 +5547,6 @@ use_c10_dispatcher: full variants: method, function -- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) - use_c10_dispatcher: full - variants: function - dispatch: - CPU: _svd_helper_cpu - CUDA: _svd_helper_cuda - - func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) - func: cholesky(Tensor self, bool upper=False) -> Tensor @@ -8019,6 +8012,23 @@ python_module: linalg variants: function +- func: linalg_svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) + python_module: linalg + +- func: linalg_svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + python_module: linalg + use_c10_dispatcher: full + variants: method, function + +- func: _linalg_svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + CPU: _linalg_svd_helper_cpu + CUDA: _linalg_svd_helper_cuda + + ## Functions that are only for testing # It is undocumented and should not be used outside of tests. - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 70ddaee5226f..656b9222db58 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1011,8 +1011,8 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) -- name: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) - self: svd_backward(grads, self, some, compute_uv, U, S, V) +- name: linalg_svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + self: svd_backward(grads, self, some, compute_uv, U, S, V) # XXX rename? 
- name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return) From 130cd04df88a84fb97ada6844874074bffa62025 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 28 Sep 2020 11:10:35 -0500 Subject: [PATCH 02/63] rename svd_backward into linalg_svd_backward, for consistency --- tools/autograd/derivatives.yaml | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 10 +++++----- torch/csrc/autograd/FunctionsManual.h | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 656b9222db58..c268acd60803 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1012,7 +1012,7 @@ self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) - name: linalg_svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) - self: svd_backward(grads, self, some, compute_uv, U, S, V) # XXX rename? + self: linalg_svd_backward(grads, self, some, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 1e73ebac2a2a..007009c9a653 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1655,10 +1655,10 @@ std::tuple prelu_double_backward( // https://j-towns.github.io/papers/svd-derivative.pdf // // This makes no assumption on the signs of sigma. -Tensor svd_backward(const std::vector &grads, const Tensor& self, +Tensor linalg_svd_backward(const std::vector &grads, const Tensor& self, bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { TORCH_CHECK(compute_uv, - "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ", + "linalg_svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ", "and hence we cannot compute backward. 
Please use torch.svd(compute_uv=True)");

   auto m = self.size(-2);
@@ -1952,7 +1952,7 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det)
     Tensor u, sigma, v;
     std::tie(u, sigma, v) = self.svd();
     auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1));
-    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
+    return linalg_svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
   };
 
   auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor {
@@ -2002,7 +2002,7 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& lo
     std::tie(u, sigma, v) = self.svd();
     // logdet = \sum log(sigma)
     auto gsigma = grad.unsqueeze(-1).div(sigma);
-    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
+    return linalg_svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
   };
 
   auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor {
@@ -2054,7 +2054,7 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet,
     // so logabsdet = \sum log(abs(sigma))
     // but det = 0, so backward logabsdet = \sum log(sigma)
     auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma);
-    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
+    return linalg_svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
   };
 
   auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 8fd0e9b08cc4..41e774c2e05c 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -119,7 +119,7 @@ at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Te
 at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad);
 at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity);
-Tensor svd_backward(const std::vector &grads, const Tensor& self,
+Tensor linalg_svd_backward(const std::vector &grads, const Tensor& self,
                     bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v);
 Tensor symeig_backward(const std::vector &grads, const Tensor& self,
                        bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v);
 std::tuple triangular_solve_backward(

From f39dc152f8dc66df1a2911198b466810c03694d7 Mon Sep 17 00:00:00 2001
From: Antonio Cuni
Date: Tue, 29 Sep 2020 04:09:06 -0500
Subject: [PATCH 03/63] change the signature of linalg_svd: now it takes full_matrices=true, which is the OPPOSITE (and with a different default) than the old some=true.
The signature of at::svd is unchanged --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 17 ++++++++++------- aten/src/ATen/native/cuda/BatchLinearAlgebra.cu | 3 ++- aten/src/ATen/native/native_functions.yaml | 6 +++--- tools/autograd/derivatives.yaml | 4 ++-- torch/csrc/autograd/FunctionsManual.cpp | 9 +++++---- torch/csrc/autograd/FunctionsManual.h | 2 +- 6 files changed, 23 insertions(+), 18 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index aa9f20aeadb4..2e6948681bad 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1004,7 +1004,8 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, #endif } -std::tuple _linalg_svd_helper_cpu(const Tensor& self, bool some, bool compute_uv) { +std::tuple _linalg_svd_helper_cpu(const Tensor& self, bool full_matrices, bool compute_uv) { + bool some = not full_matrices; std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); @@ -1042,18 +1043,18 @@ std::tuple _linalg_svd_helper_cpu(const Tensor& self, bo return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } -std::tuple linalg_svd(const Tensor& self, bool some, bool compute_uv) { +std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - return at::_linalg_svd_helper(self, some, compute_uv); + return at::_linalg_svd_helper(self, full_matrices, compute_uv); } std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT, - const Tensor& self, bool some, bool compute_uv) { + const Tensor& self, bool full_matrices, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); Tensor U_tmp, S_tmp, VT_tmp; - std::tie(U_tmp, S_tmp, VT_tmp) = at::_linalg_svd_helper(self, some, compute_uv); + std::tie(U_tmp, S_tmp, VT_tmp) = at::_linalg_svd_helper(self, full_matrices, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); VT.resize_as_(VT_tmp).copy_(VT_tmp); @@ -1063,12 +1064,14 @@ std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tenso // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ std::tuple svd(const Tensor& self, bool some, bool compute_uv) { - return at::native::linalg_svd(self, some, compute_uv); + bool full_matrices = !some; + return at::linalg_svd(self, full_matrices, compute_uv); } std::tuple svd_out(Tensor& U, Tensor& S, Tensor& VT, const Tensor& self, bool some, bool compute_uv) { - return at::native::linalg_svd_out(U, S, VT, self, some, compute_uv); + bool full_matrices = !some; + return at::linalg_svd_out(U, S, VT, self, full_matrices, compute_uv); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index e75df4141ca8..5cf918a15d4f 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1371,7 +1371,8 @@ AT_ERROR("svd: MAGMA library not found in " #endif } -std::tuple _linalg_svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) { +std::tuple _linalg_svd_helper_cuda(const Tensor& self, bool full_matrices, bool compute_uv) { + bool some = !full_matrices; std::vector infos(batchCount(self), 0); 
int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e56b5da7240e..f0210c474d34 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8012,15 +8012,15 @@ python_module: linalg variants: function -- func: linalg_svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) +- func: linalg_svd.U(Tensor self, bool full_matrices=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) python_module: linalg -- func: linalg_svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) +- func: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) python_module: linalg use_c10_dispatcher: full variants: method, function -- func: _linalg_svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) +- func: _linalg_svd_helper(Tensor self, bool full_matrices, bool compute_uv) -> (Tensor, Tensor, Tensor) python_module: linalg use_c10_dispatcher: full variants: function diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index c268acd60803..7f6e142ab95d 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1011,8 +1011,8 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) -- name: linalg_svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) - self: linalg_svd_backward(grads, self, some, compute_uv, U, S, V) +- name: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + self: linalg_svd_backward(grads, self, full_matrices, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 007009c9a653..1948e67d2f6a 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1656,11 +1656,12 @@ std::tuple prelu_double_backward( // // This makes no assumption on the signs of sigma. Tensor linalg_svd_backward(const std::vector &grads, const Tensor& self, - bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { + bool full_matrices, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { TORCH_CHECK(compute_uv, "linalg_svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ", "and hence we cannot compute backward. 
Please use torch.svd(compute_uv=True)"); + bool some = !full_matrices; auto m = self.size(-2); auto n = self.size(-1); auto k = sigma.size(-1); @@ -1952,7 +1953,7 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) Tensor u, sigma, v; std::tie(u, sigma, v) = self.svd(); auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1)); - return linalg_svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); + return linalg_svd_backward({{}, gsigma, {}}, self, false, true, u, sigma, v); }; auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { @@ -2002,7 +2003,7 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& lo std::tie(u, sigma, v) = self.svd(); // logdet = \sum log(sigma) auto gsigma = grad.unsqueeze(-1).div(sigma); - return linalg_svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); + return linalg_svd_backward({{}, gsigma, {}}, self, false, true, u, sigma, v); }; auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor { @@ -2054,7 +2055,7 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet, // so logabsdet = \sum log(abs(sigma)) // but det = 0, so backward logabsdet = \sum log(sigma) auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma); - return linalg_svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); + return linalg_svd_backward({{}, gsigma, {}}, self, false, true, u, sigma, v); }; auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 41e774c2e05c..6926448143f1 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -120,7 +120,7 @@ at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, co at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity); Tensor linalg_svd_backward(const std::vector &grads, const Tensor& self, - bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v); + bool full_matrices, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v); Tensor symeig_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v); std::tuple triangular_solve_backward( From e16fde1498d687b7985b26cede0c509ebe76f816 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Tue, 29 Sep 2020 04:42:35 -0500 Subject: [PATCH 04/63] add a test for torch.linalg.svd, and write a numpy-compatible wrapper for it --- test/test_linalg.py | 14 ++++++++++++++ torch/linalg/__init__.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index c81b4dc37582..0b5fd200b310 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -568,6 +568,20 @@ def test_norm_fastpaths(self, device): expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0) self.assertEqual(result, expected) + # Tests torch.linalg.svd, vs. 
NumPy
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    @dtypes(torch.double)
+    def test_svd(self, device, dtype):
+        t = torch.randn((10, 10), device=device, dtype=dtype)
+        np_t = t.cpu().numpy()
+        for full_matrices in (True, False):
+            for compute_uv in (True, False):
+                expected = np.linalg.svd(np_t, full_matrices, compute_uv)
+                actual = torch.linalg.svd(t, full_matrices, compute_uv)
+                self.assertEqual(actual, expected)
+
 instantiate_device_type_tests(TestLinalg, globals())
 
 if __name__ == '__main__':
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 5e2b59c45c80..6a97e449662e 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -140,3 +140,17 @@
 >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
     (tensor(3.7417), tensor(11.2250))
 """)
+
+def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
+    """
+    XXX write docstring
+    """
+    assert not hermitian # XXX implement me
+    USV = _linalg.linalg_svd(a, full_matrices, compute_uv)
+    if not compute_uv:
+        # our C++ API always returns a full 3-tuple, but in this case numpy
+        # returns only S
+        return USV.S
+
+    USV.V.t_() # pytorch returns the transposed Vt, while numpy returns V
+    return USV

From a14c2fbdcca14a12e7ee2fd4538bfa252108a58e Mon Sep 17 00:00:00 2001
From: Antonio Cuni
Date: Tue, 29 Sep 2020 17:10:48 -0500
Subject: [PATCH 05/63] WIP: the comment inside _create_U_S_VT was simply wrong: lapack and magma expect VT to be in column-major order, so what we were passing around was effectively V, not VT.

- Fix the comment and actually create a column-major version of VT, which we can return directly from linalg_svd

- kill the .t() from python's version of linalg.svd(), since we are directly getting VT from C++ now

- call VT.transpose_() inside the legacy version of svd, to keep backwards compatibility

This is WIP because autograd is still broken
---
 aten/src/ATen/native/BatchLinearAlgebra.cpp | 20 ++++++++++++++++---
 aten/src/ATen/native/LinearAlgebraUtils.h | 9 ++++++---
 .../ATen/native/cuda/BatchLinearAlgebra.cu | 2 +-
 torch/linalg/__init__.py | 2 --
 4 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 2e6948681bad..37cee8ab881d 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1030,7 +1030,7 @@ std::tuple _linalg_svd_helper_cpu(const Tensor& self, bo
 
   if (compute_uv) {
     if (some) {
-      VT_working_copy = VT_working_copy.narrow(-1, 0, k);
+      VT_working_copy = VT_working_copy.narrow(-2, 0, k);
     }
   } else {
     VT_working_copy.zero_();
@@ -1062,16 +1062,30 @@ std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tenso
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/* the legacy version of torch.svd, implemented in terms of linalg_svd: note
+   the two main differences:
+
+   1. the 2nd parameter is bool some=True, which if effectively the opposite
+      of full_matrices=True
+
+   2. linalg_svd returns VT, while svd returns V.
To accommodate the + difference, we transpose_() VT upon return +*/ std::tuple svd(const Tensor& self, bool some, bool compute_uv) { bool full_matrices = !some; - return at::linalg_svd(self, full_matrices, compute_uv); + Tensor U, S, VT; + std::tie(U, S, VT) = at::linalg_svd(self, full_matrices, compute_uv); + VT.transpose_(-2, -1); + return std::make_tuple(U, S, VT); } std::tuple svd_out(Tensor& U, Tensor& S, Tensor& VT, const Tensor& self, bool some, bool compute_uv) { bool full_matrices = !some; - return at::linalg_svd_out(U, S, VT, self, full_matrices, compute_uv); + auto result = at::linalg_svd_out(U, S, VT, self, full_matrices, compute_uv); + VT.transpose_(-2, -1); + return result; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 5c07700f1e85..b5a0974ed439 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -241,18 +241,21 @@ static inline std::tuple _create_U_S_VT(const Tensor& in U_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); } + // VT should be a column-major or a batch of column-major matrices sizes[input.dim() - 2] = n; sizes[input.dim() - 1] = n; - // VT should be a row-major or a batch of row-major matrices + strides = at::detail::defaultStrides(sizes); + strides[input.dim() - 1] = n; + strides[input.dim() - 2] = 1; Tensor VT_empty; if (!input.is_cuda()) { - VT_empty = at::empty(sizes, input.options()); + VT_empty = at::empty_strided(sizes, strides, input.options()); } else { // NB: VT_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd // (which is the driver routine for the divide and conquer SVD operation) // takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that // moves the inputs between devices internally. 
- VT_empty = at::empty(sizes, input.options().device(at::kCPU)); + VT_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); } sizes.pop_back(); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 5cf918a15d4f..7f432b36cf95 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1412,7 +1412,7 @@ std::tuple _linalg_svd_helper_cuda(const Tensor& self, b if (compute_uv) { if (some) { - VT_working_copy = VT_working_copy.narrow(-1, 0, k); + VT_working_copy = VT_working_copy.narrow(-2, 0, k); } } else { VT_working_copy.zero_(); diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 6a97e449662e..39de6f0827e8 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -151,6 +151,4 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False): # our C++ API always returns a full 3-tuple, but in this case numpy # returns only S return USV.S - - USV.V.t_() # pytorch returns the transposed Vt, while numpy returns V return USV From 675e12244c0b6c536b39f43051a9f9509f3a453a Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Wed, 30 Sep 2020 03:42:30 -0500 Subject: [PATCH 06/63] we can't use transpose_(), else autograd complains that the results have been modified --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 37cee8ab881d..bc3b42b6a9d2 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1049,6 +1049,8 @@ std::tuple linalg_svd(const Tensor& self, bool full_matr return at::_linalg_svd_helper(self, full_matrices, compute_uv); } +// Question for reviewers: should this function even exist? Do we want to +// support out versions of torch.linalg.*? 
std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT, const Tensor& self, bool full_matrices, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, @@ -1076,16 +1078,18 @@ std::tuple svd(const Tensor& self, bool some, bool compu bool full_matrices = !some; Tensor U, S, VT; std::tie(U, S, VT) = at::linalg_svd(self, full_matrices, compute_uv); - VT.transpose_(-2, -1); - return std::make_tuple(U, S, VT); + Tensor V = VT.transpose(-2, -1); + return std::make_tuple(U, S, V); } -std::tuple svd_out(Tensor& U, Tensor& S, Tensor& VT, +std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, const Tensor& self, bool some, bool compute_uv) { - bool full_matrices = !some; - auto result = at::linalg_svd_out(U, S, VT, self, full_matrices, compute_uv); - VT.transpose_(-2, -1); - return result; + Tensor U_tmp, S_tmp, V_tmp; + std::tie(U_tmp, S_tmp, V_tmp) = at::svd(self, some, compute_uv); + U.resize_as_(U_tmp).copy_(U_tmp); + S.resize_as_(S_tmp).copy_(S_tmp); + V.resize_as_(V_tmp).copy_(V_tmp); + return std::tuple(U, S, V); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 91ec980acd0afdc451106c8e4534e6f582c59c01 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Wed, 30 Sep 2020 04:30:36 -0500 Subject: [PATCH 07/63] add a TODO --- torch/csrc/autograd/FunctionsManual.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 1948e67d2f6a..cdd9c463ebbc 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1655,6 +1655,8 @@ std::tuple prelu_double_backward( // https://j-towns.github.io/papers/svd-derivative.pdf // // This makes no assumption on the signs of sigma. + +// XXX TODO: this is wrong! It expects to receive "V" but linalg_svd now returns VT Tensor linalg_svd_backward(const std::vector &grads, const Tensor& self, bool full_matrices, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { TORCH_CHECK(compute_uv, From 5886d5802898cc51e7989abb65e9d67a82f13557 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Wed, 30 Sep 2020 04:39:39 -0500 Subject: [PATCH 08/63] use ! 
instead of not --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index bc3b42b6a9d2..09598d96505a 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1005,7 +1005,7 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, } std::tuple _linalg_svd_helper_cpu(const Tensor& self, bool full_matrices, bool compute_uv) { - bool some = not full_matrices; + bool some = !full_matrices; std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); From f763c8d1ab2fa84446a690d857bdfd2cd7278562 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 9 Oct 2020 05:22:56 -0500 Subject: [PATCH 09/63] partially undo commit 3923d4eab7: keep at::svd as the main function and implement linalg_svd on top of it --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 56 ++++++++++--------- .../ATen/native/cuda/BatchLinearAlgebra.cu | 5 +- aten/src/ATen/native/native_functions.yaml | 15 +++-- tools/autograd/derivatives.yaml | 4 +- torch/csrc/autograd/FunctionsManual.cpp | 15 ++--- torch/csrc/autograd/FunctionsManual.h | 4 +- 6 files changed, 49 insertions(+), 50 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 09598d96505a..6141ba628f5d 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -935,7 +935,7 @@ std::tuple symeig_out(Tensor& vals, Tensor& vecs, const Tensor return std::tuple(vals, vecs); } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, @@ -1004,8 +1004,7 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, #endif } -std::tuple _linalg_svd_helper_cpu(const Tensor& self, bool full_matrices, bool compute_uv) { - bool some = !full_matrices; +std::tuple _svd_helper_cpu(const Tensor& self, bool some, bool compute_uv) { std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); @@ -1040,52 +1039,55 @@ std::tuple _linalg_svd_helper_cpu(const Tensor& self, bo U_working_copy.zero_(); VT_working_copy.zero_(); } + // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } -std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { +std::tuple svd(const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - return at::_linalg_svd_helper(self, full_matrices, compute_uv); + return at::_svd_helper(self, some, compute_uv); } -// Question for reviewers: should this function even exist? Do we want to -// support out versions of torch.linalg.*? 
-std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT, - const Tensor& self, bool full_matrices, bool compute_uv) { +std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, + const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - Tensor U_tmp, S_tmp, VT_tmp; - std::tie(U_tmp, S_tmp, VT_tmp) = at::_linalg_svd_helper(self, full_matrices, compute_uv); + Tensor U_tmp, S_tmp, V_tmp; + std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); - VT.resize_as_(VT_tmp).copy_(VT_tmp); - return std::tuple(U, S, VT); + V.resize_as_(V_tmp).copy_(V_tmp); + return std::tuple(U, S, V); } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -/* the legacy version of torch.svd, implemented in terms of linalg_svd: note - the two main differences: +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/* torch.linalg.svd, implemented in terms of torch.svd. There are two main + differences: 1. the 2nd parameter is bool some=True, which if effectively the opposite of full_matrices=True - 2. linalg_svd returns VT, while svd returns V. To accommodate the - difference, we transpose_() VT upon return + 2. svd returns V, while linalg.svd returns VT. To accommodate the + difference, we transpose() V upon return */ -std::tuple svd(const Tensor& self, bool some, bool compute_uv) { - bool full_matrices = !some; - Tensor U, S, VT; - std::tie(U, S, VT) = at::linalg_svd(self, full_matrices, compute_uv); - Tensor V = VT.transpose(-2, -1); - return std::make_tuple(U, S, V); +std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { + bool some = !full_matrices; + Tensor U, S, V; + std::tie(U, S, V) = at::svd(self, some, compute_uv); + Tensor VT = V.transpose(-2, -1); + return std::make_tuple(U, S, VT); } -std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, - const Tensor& self, bool some, bool compute_uv) { +// Question for reviewers: should this function even exist? Do we want to +// support out versions of torch.linalg.*? +std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& V, + const Tensor& self, bool full_matrices, bool compute_uv) { Tensor U_tmp, S_tmp, V_tmp; - std::tie(U_tmp, S_tmp, V_tmp) = at::svd(self, some, compute_uv); + std::tie(U_tmp, S_tmp, V_tmp) = at::svd(self, full_matrices, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); V.resize_as_(V_tmp).copy_(V_tmp); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 7f432b36cf95..23cf1222b5c8 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1371,8 +1371,7 @@ AT_ERROR("svd: MAGMA library not found in " #endif } -std::tuple _linalg_svd_helper_cuda(const Tensor& self, bool full_matrices, bool compute_uv) { - bool some = !full_matrices; +std::tuple _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) { std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); @@ -1423,6 +1422,8 @@ std::tuple _linalg_svd_helper_cuda(const Tensor& self, b S_working_copy = same_stride_to(S_working_copy, self.options()); VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_(); } + // so far we have computed VT, but torch.svd returns V instead. 
Adjust accordingly. + VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f0210c474d34..d6e26f375cf3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5547,6 +5547,13 @@ use_c10_dispatcher: full variants: method, function +- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: full + variants: function + dispatch: + CPU: _svd_helper_cpu + CUDA: _svd_helper_cuda + - func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) - func: cholesky(Tensor self, bool upper=False) -> Tensor @@ -8020,14 +8027,6 @@ use_c10_dispatcher: full variants: method, function -- func: _linalg_svd_helper(Tensor self, bool full_matrices, bool compute_uv) -> (Tensor, Tensor, Tensor) - python_module: linalg - use_c10_dispatcher: full - variants: function - dispatch: - CPU: _linalg_svd_helper_cpu - CUDA: _linalg_svd_helper_cuda - ## Functions that are only for testing # It is undocumented and should not be used outside of tests. diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 7f6e142ab95d..70ddaee5226f 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1011,8 +1011,8 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) -- name: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) - self: linalg_svd_backward(grads, self, full_matrices, compute_uv, U, S, V) +- name: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + self: svd_backward(grads, self, some, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index cdd9c463ebbc..1e73ebac2a2a 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1655,15 +1655,12 @@ std::tuple prelu_double_backward( // https://j-towns.github.io/papers/svd-derivative.pdf // // This makes no assumption on the signs of sigma. - -// XXX TODO: this is wrong! It expects to receive "V" but linalg_svd now returns VT -Tensor linalg_svd_backward(const std::vector &grads, const Tensor& self, - bool full_matrices, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { +Tensor svd_backward(const std::vector &grads, const Tensor& self, + bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { TORCH_CHECK(compute_uv, - "linalg_svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ", + "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ", "and hence we cannot compute backward. 
Please use torch.svd(compute_uv=True)"); - bool some = !full_matrices; auto m = self.size(-2); auto n = self.size(-1); auto k = sigma.size(-1); @@ -1955,7 +1952,7 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) Tensor u, sigma, v; std::tie(u, sigma, v) = self.svd(); auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1)); - return linalg_svd_backward({{}, gsigma, {}}, self, false, true, u, sigma, v); + return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); }; auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { @@ -2005,7 +2002,7 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& lo std::tie(u, sigma, v) = self.svd(); // logdet = \sum log(sigma) auto gsigma = grad.unsqueeze(-1).div(sigma); - return linalg_svd_backward({{}, gsigma, {}}, self, false, true, u, sigma, v); + return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); }; auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor { @@ -2057,7 +2054,7 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet, // so logabsdet = \sum log(abs(sigma)) // but det = 0, so backward logabsdet = \sum log(sigma) auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma); - return linalg_svd_backward({{}, gsigma, {}}, self, false, true, u, sigma, v); + return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); }; auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 6926448143f1..8fd0e9b08cc4 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -119,8 +119,8 @@ at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Te at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad); at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity); -Tensor linalg_svd_backward(const std::vector &grads, const Tensor& self, - bool full_matrices, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v); +Tensor svd_backward(const std::vector &grads, const Tensor& self, + bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v); Tensor symeig_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v); std::tuple triangular_solve_backward( From b319768a29514fa6e51f689731aa302568df1341 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 9 Oct 2020 11:30:29 -0500 Subject: [PATCH 10/63] change the return type of linalg_svd(..., compute_uv=False): we return empty tensors in C++ and None in Python --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 10 ++++++++-- test/cpp/api/functional.cpp | 13 +++++++++++++ test/test_linalg.py | 15 +++++++++++---- torch/linalg/__init__.py | 9 ++++++--- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 6141ba628f5d..47ee8a35a31d 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1078,8 +1078,14 @@ std::tuple linalg_svd(const Tensor& self, bool full_matr bool some = !full_matrices; Tensor U, S, V; std::tie(U, S, V) = at::svd(self, 
some, compute_uv); - Tensor VT = V.transpose(-2, -1); - return std::make_tuple(U, S, VT); + if (compute_uv) { + Tensor VT = V.transpose(-2, -1); + return std::make_tuple(U, S, VT); + } else { + Tensor empty_U = at::empty({0}, self.options()); + Tensor empty_VT = at::empty({0}, self.options()); + return std::make_tuple(empty_U, S, empty_VT); + } } // Question for reviewers: should this function even exist? Do we want to diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp index 4efdb122efc8..e5d1628bedad 100644 --- a/test/cpp/api/functional.cpp +++ b/test/cpp/api/functional.cpp @@ -2796,3 +2796,16 @@ TEST_F(FunctionalTest, BCEWithLogitsLoss) { ASSERT_TRUE(torch::isfinite(out2).all().item()); } } + +TEST_F(FunctionalTest, linalg_svd) { + // NOTE: this is only a partial test: it tests that when we pass + // compute_uv=False, the returned U and VT are empty tensors. We need to + // write a C++ test because in Python it has a slightly different behavior + // and it returns (None, S, None) instead. The full logic for svd is + // tested thoughtfully in Python. + const auto input = torch::rand({7, 3}); + torch::Tensor U, S, VT; + std::tie(U, S, VT) = at::linalg_svd(input, true, false); + ASSERT_EQ(U.numel(), 0) << "U is not empty"; + ASSERT_EQ(VT.numel(), 0) << "VT is not empty"; +} diff --git a/test/test_linalg.py b/test/test_linalg.py index 0b5fd200b310..87c8db7f884e 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -576,11 +576,18 @@ def test_norm_fastpaths(self, device): def test_svd(self, device, dtype): t = torch.randn((10, 10), device=device, dtype=dtype) np_t = t.cpu().numpy() + for full_matrices in (True, False): - for compute_uv in (True, False): - expected = np.linalg.svd(np_t, full_matrices, compute_uv) - actual = torch.linalg.svd(t, full_matrices, compute_uv) - self.assertEqual(actual, expected) + expected = np.linalg.svd(np_t, full_matrices, compute_uv=True) + actual = torch.linalg.svd(t, full_matrices, compute_uv=True) + self.assertEqual(actual, expected) + + for full_matrices in (True, False): + np_s = np.linalg.svd(np_t, full_matrices, compute_uv=False) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False) + assert USV.U is None + self.assertEqual(USV.S, np_s) + assert USV.V is None instantiate_device_type_tests(TestLinalg, globals()) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 39de6f0827e8..b4e553eee67d 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -148,7 +148,10 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False): assert not hermitian # XXX implement me USV = _linalg.linalg_svd(a, full_matrices, compute_uv) if not compute_uv: - # our C++ API always returns a full 3-tuple, but in this case numpy - # returns only S - return USV.S + # we want to return a value of type torch.return_types.linalg_svd + # (which is a PyStructSequence). However, this type is not directly + # exposed by pytorch, so we get a reference to it by calling type() on + # USV. 
+ USV_TYPE = type(USV) + return USV_TYPE((None, USV.S, None)) return USV From 1817211d9f5b33af4c3467e41efefbc6b741e1f6 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 9 Oct 2020 15:23:14 -0500 Subject: [PATCH 11/63] fix flake8 --- torch/linalg/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index b4e553eee67d..126884724c6e 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -145,7 +145,7 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False): """ XXX write docstring """ - assert not hermitian # XXX implement me + assert not hermitian # XXX implement me USV = _linalg.linalg_svd(a, full_matrices, compute_uv) if not compute_uv: # we want to return a value of type torch.return_types.linalg_svd From 97662da54aba366dd547e46d557edc2b86ff86dc Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 12 Oct 2020 10:34:06 -0500 Subject: [PATCH 12/63] fix the docstring of svd, according to the discussion in issue #45821 --- torch/_torch_docs.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 32806259df35..3c5dd8ecafbb 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -7150,7 +7150,18 @@ def merge_dicts(*dicts): This function returns a namedtuple ``(U, S, V)`` which is the singular value decomposition of a input real matrix or batches of real matrices :attr:`input` such that -:math:`input = U \times diag(S) \times V^T`. +:math:`input = U \times diag(S) \times V^H`, where :math:`V^H` is the conjugate transpose +of ``V``. + +In Python, :math:`V^H` is computed by ``v.T.conj()``. Note that for +non-complex types, ``.conj()`` is a no-op and can be omittted. To summarize, +the original Tensor can be reconstructed by:: + + U @ diag(S) @ V.T.conj() # for real and complex numbers + U @ diag(S) @ V.T # only for real numbers + +The dtype of ``U`` and ``V`` is the same as the ``input`` matrix. The dtype of +``S`` is always real numbers, even if ``input`` is complex. If :attr:`some` is ``True`` (default), the method returns the reduced singular value decomposition i.e., if the last two dimensions of :attr:`input` are ``m`` and ``n``, then the returned From 25caa20d57c9b1325b8d7f8c1a96393d13066f2f Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 15 Oct 2020 09:49:51 +0000 Subject: [PATCH 13/63] write the docstring for linalg.svd --- docs/source/linalg.rst | 1 + torch/linalg/__init__.py | 105 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 102 insertions(+), 4 deletions(-) diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 834b6a60ac93..aa0392007ddb 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -14,3 +14,4 @@ Functions .. autofunction:: det .. autofunction:: norm +.. 
autofunction:: svd diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 126884724c6e..9374b6cb3687 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -141,11 +141,108 @@ (tensor(3.7417), tensor(11.2250)) """) -def svd(a, full_matrices=True, compute_uv=True, hermitian=False): - """ - XXX write docstring + +def svd(a, full_matrices=True, compute_uv=True): + r""" +linalg.svd(input, full_matrices=True, compute_uv=True, out=None) -> (Tensor, Tensor, Tensor) + +This function returns a namedtuple ``(U, S, Vh)`` which is the singular value +decomposition of a input real matrix or batches of real matrices :attr:`input` such that +:math:`input = U \times diag(S) \times V^H` (where :math:`V^H` is ``Vh``). + +.. warning:: This function is similar to :meth:`~torch.svd`, but has the following important + differences to make it more compatible with ``numpy``: + + * :attr:`full_matrices` is effectively the opposite of + ``torch.svd(some=...)``. And the default is the opposite as well. + + * it returns ``Vh``, whereas :meth:`~torch.svd` returns + ``V``. The result is that when using :meth:`~torch.svd` you + need to manually transpose and conjugate ``V`` in order to + reconstruct the original matrix. + + * If :attr:`compute_uv=False`, it returns ``None`` for ``U`` and + ``V``, whereas :meth:`~torch.svd` returns zero-filled tensors. + + This function has also a difference w.r.t. numpy: + + * if :attr:`compute_uv=False` it returns ``(None, S, None)``, + whereas numpy returns ``S``. + + +The dtype of ``U`` and ``V`` is the same as the ``input`` matrix. The dtype of +``S`` is always real numbers, even if ``input`` is complex. + +If :attr:`full_matrices` is ``False``, the method returns the reduced singular value decomposition +i.e., if the last two dimensions of :attr:`input` are ``m`` and ``n``, then the returned +`U` and `V` matrices will contain only :math:`min(n, m)` orthonormal columns. + +If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be None.:attr:`full_matrices` will +be ignored here. + +.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, + then the singular values of each matrix in the batch is returned in descending order. + +.. note:: The implementation of SVD on CPU uses the LAPACK routine `?gesdd` (a divide-and-conquer + algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine + `gesdd` as well. + +.. note:: Irrespective of the original strides, the returned matrix `U` + will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()` + +.. note:: Extra care needs to be taken when backward through `U` and `V` + outputs. Such operation is really only stable when :attr:`input` is + full rank with all distinct singular values. Otherwise, ``NaN`` can + appear as the gradients are not properly defined. Also, notice that + double backward will usually do an additional backward through `U` and + `V` even if the original backward is only on `S`. + +.. note:: When :attr:`full_matrices` = ``False``, the gradients on :code:`U[..., :, min(m, n):]` + and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors + can be arbitrary bases of the subspaces. + +.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V` + from the forward pass is required for the backward operation. 
+ +Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of :math:`m \times n` matrices. + full_matrices (bool, optional): controls the shape of returned `U` and `V` + compute_uv (bool, optional): option whether to compute `U` and `V` or not + out (tuple, optional): the output tuple of tensors + +Example:: + + >>> import torch + >>> a = torch.randn(5, 3) + >>> a + tensor([[-0.3357, -0.2987, -1.1096], + [ 1.4894, 1.0016, -0.4572], + [-1.9401, 0.7437, 2.0968], + [ 0.1515, 1.3812, 1.5491], + [-1.8489, -0.5907, -2.5673]]) + >>> + >>> # reconstruction in the full_matrices=False case + >>> u, s, vh = torch.linalg.svd(a, full_matrices=False) + >>> u.shape, s.shape, vh.shape + (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(a, u @ torch.diag(s) @ vh) + tensor(1.0486e-06) + >>> + >>> # reconstruction in the full_matrices=True case + >>> u, s, vh = torch.linalg.svd(a) + >>> u.shape, s.shape, vh.shape + (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh) + >>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh) + tensor(1.0486e-06) + >>> + >>> # extra dimensions + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, vh = torch.linalg.svd(a_big, full_matrices=False) + >>> torch.dist(a_big, u @ torch.diag_embed(s) @ vh) + tensor(3.0957e-06) """ - assert not hermitian # XXX implement me USV = _linalg.linalg_svd(a, full_matrices, compute_uv) if not compute_uv: # we want to return a value of type torch.return_types.linalg_svd From f056875064b0fe9312adaa983207a20850a640fe Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 15 Oct 2020 09:55:19 +0000 Subject: [PATCH 14/63] fix for the complex case: torch.svd should return V but lapack computes Vh, so we need to transpose AND conjugate. OTOH, linalg.svd returns Vh, so nothing is needed in that case --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 24 +++++++++++++++---- .../ATen/native/cuda/BatchLinearAlgebra.cu | 5 +++- aten/src/ATen/native/native_functions.yaml | 2 +- test/test_linalg.py | 4 ++-- test/test_torch.py | 9 +++++++ tools/autograd/derivatives.yaml | 4 ++-- 6 files changed, 37 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 47ee8a35a31d..8481a260f406 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1004,7 +1004,8 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, #endif } -std::tuple _svd_helper_cpu(const Tensor& self, bool some, bool compute_uv) { +std::tuple _svd_helper_cpu(const Tensor& self, bool some, + bool compute_uv, bool apply_conj) { std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); @@ -1041,21 +1042,25 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some } // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. 
VT_working_copy.transpose_(-2, -1); + if (VT_working_copy.is_complex() && apply_conj) + VT_working_copy = VT_working_copy.conj(); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } std::tuple svd(const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - return at::_svd_helper(self, some, compute_uv); + bool apply_conj = true; + return at::_svd_helper(self, some, compute_uv, apply_conj); } std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + bool apply_conj = true; Tensor U_tmp, S_tmp, V_tmp; - std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv); + std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv, apply_conj); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); V.resize_as_(V_tmp).copy_(V_tmp); @@ -1075,9 +1080,13 @@ std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, */ std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { + TORCH_CHECK(self.dim() >= 2, + "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + bool some = !full_matrices; + bool apply_conj = false; Tensor U, S, V; - std::tie(U, S, V) = at::svd(self, some, compute_uv); + std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv, apply_conj); if (compute_uv) { Tensor VT = V.transpose(-2, -1); return std::make_tuple(U, S, VT); @@ -1092,8 +1101,13 @@ std::tuple linalg_svd(const Tensor& self, bool full_matr // support out versions of torch.linalg.*? std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& V, const Tensor& self, bool full_matrices, bool compute_uv) { + TORCH_CHECK(self.dim() >= 2, + "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + + bool some = !full_matrices; + bool apply_conj = false; Tensor U_tmp, S_tmp, V_tmp; - std::tie(U_tmp, S_tmp, V_tmp) = at::svd(self, full_matrices, compute_uv); + std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, full_matrices, compute_uv, apply_conj); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); V.resize_as_(V_tmp).copy_(V_tmp); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 23cf1222b5c8..ae875786071f 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1371,7 +1371,8 @@ AT_ERROR("svd: MAGMA library not found in " #endif } -std::tuple _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) { +std::tuple _svd_helper_cuda(const Tensor& self, bool some, + bool compute_uv, bool apply_conj) { std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); @@ -1424,6 +1425,8 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som } // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. 
VT_working_copy.transpose_(-2, -1); + if (VT_working_copy.is_complex() && apply_conj) + VT_working_copy = VT_working_copy.conj(); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index d6e26f375cf3..d7b4edce0716 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5547,7 +5547,7 @@ use_c10_dispatcher: full variants: method, function -- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) +- func: _svd_helper(Tensor self, bool some, bool compute_uv, bool apply_conj) -> (Tensor U, Tensor S, Tensor V) use_c10_dispatcher: full variants: function dispatch: diff --git a/test/test_linalg.py b/test/test_linalg.py index 87c8db7f884e..de167a50a985 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -572,9 +572,9 @@ def test_norm_fastpaths(self, device): @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) + @dtypes(torch.float, torch.double, torch.cfloat) def test_svd(self, device, dtype): - t = torch.randn((10, 10), device=device, dtype=dtype) + t = torch.randn((10, 11), device=device, dtype=dtype) np_t = t.cpu().numpy() for full_matrices in (True, False): diff --git a/test/test_torch.py b/test/test_torch.py index ee27c8dd65cf..a12e6fb5718a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9786,6 +9786,15 @@ def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **op guess_rank, actual_rank, size, batches = 2, 2, (17, 4), () run_subtest(guess_rank, actual_rank, size, batches, device, jitted) + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.cfloat) + def test_svd_complex(self, device, dtype): + t = torch.randn((10, 10), dtype=dtype, device=device) + U, S, V = torch.svd(t, some=False) + t2 = U @ torch.diag(S).type(dtype) @ V.T.conj() + self.assertEqual(t, t2) + def test_lerp(self, device): start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)] for shapes in product(start_end_shapes, start_end_shapes): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 70ddaee5226f..6fa4279a43ae 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1011,8 +1011,8 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) -- name: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) - self: svd_backward(grads, self, some, compute_uv, U, S, V) +- name: _svd_helper(Tensor self, bool some, bool compute_uv, bool apply_conj) -> (Tensor U, Tensor S, Tensor V) + self: svd_backward(grads, self, some, compute_uv, U, S, V) # XXX apply_conj - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return) From edca9dfb231fcacd62adf5351e3ce1023b9dd357 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 16 Oct 2020 08:43:27 +0000 Subject: [PATCH 15/63] improve the docstring --- torch/linalg/__init__.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 9374b6cb3687..eea2395fc128 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -150,24 +150,25 @@ def svd(a, full_matrices=True, compute_uv=True): decomposition of a input real matrix or batches of real matrices :attr:`input` such that :math:`input = U \times diag(S) \times V^H` (where :math:`V^H` is ``Vh``). -.. warning:: This function is similar to :meth:`~torch.svd`, but has the following important - differences to make it more compatible with ``numpy``: +.. warning:: **Differences with** :meth:`~torch.svd`: - * :attr:`full_matrices` is effectively the opposite of - ``torch.svd(some=...)``. And the default is the opposite as well. + * :attr:`full_matrices` is the opposite of + :meth:`~torch.svd`'s :attr:`some`. Note that default value + for both is ``True``, so the default behavior is effectively + the opposite. - * it returns ``Vh``, whereas :meth:`~torch.svd` returns - ``V``. The result is that when using :meth:`~torch.svd` you - need to manually transpose and conjugate ``V`` in order to - reconstruct the original matrix. + * it returns ``Vh``, whereas :meth:`~torch.svd` returns + ``V``. The result is that when using :meth:`~torch.svd` you + need to manually transpose and conjugate ``V`` in order to + reconstruct the original matrix. - * If :attr:`compute_uv=False`, it returns ``None`` for ``U`` and - ``V``, whereas :meth:`~torch.svd` returns zero-filled tensors. + * If :attr:`compute_uv=False`, it returns ``None`` for ``U`` and + ``V``, whereas :meth:`~torch.svd` returns zero-filled tensors. - This function has also a difference w.r.t. numpy: + **Differences with** ``numpy.linalg.svd``: - * if :attr:`compute_uv=False` it returns ``(None, S, None)``, - whereas numpy returns ``S``. + * if :attr:`compute_uv=False` it returns ``(None, S, None)``, + whereas numpy returns ``S``. The dtype of ``U`` and ``V`` is the same as the ``input`` matrix. 
The dtype of From 0462f7937de5810f0c81ffd683cf4f0e967f00fc Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 16 Oct 2020 09:49:50 +0000 Subject: [PATCH 16/63] add comments to make sure we don't forget about this when we add support for complex autograd --- tools/autograd/derivatives.yaml | 5 ++++- torch/csrc/autograd/FunctionsManual.cpp | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 6fa4279a43ae..c9f2f6576467 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1011,8 +1011,11 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) +# NOTE: currently apply_conj is ignored, but it if used only when operating on +# complex, which are not supported by svd_backward so far. When we add support +# for complex, we need to remember to consider it. - name: _svd_helper(Tensor self, bool some, bool compute_uv, bool apply_conj) -> (Tensor U, Tensor S, Tensor V) - self: svd_backward(grads, self, some, compute_uv, U, S, V) # XXX apply_conj + self: svd_backward(grads, self, some, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 1e73ebac2a2a..8962d167bcb0 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1657,6 +1657,10 @@ std::tuple prelu_double_backward( // This makes no assumption on the signs of sigma. Tensor svd_backward(const std::vector &grads, const Tensor& self, bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { + /* NOTE: currently this function does not support complex numbers. When we + * add support for it, we need to consider "bool apply_conj", which + * currently is ignored. See the corresponding comment in derivatives.yaml. + */ TORCH_CHECK(compute_uv, "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ", "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)"); From e900ad28487a2ca714b834cb99a2e90889fb104c Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 16 Oct 2020 10:09:14 +0000 Subject: [PATCH 17/63] add a test for cdouble but skip it, because it segfaults. 
Need to fill a separate issue --- test/test_linalg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index de167a50a985..220db0260777 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -572,8 +572,11 @@ def test_norm_fastpaths(self, device): @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.float, torch.double, torch.cfloat) + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) def test_svd(self, device, dtype): + if dtype is torch.cdouble: + # this test segfaults + self.skipTest('Issue XXX') t = torch.randn((10, 11), device=device, dtype=dtype) np_t = t.cpu().numpy() From e8ce28204fb178602928fd54082353a56297319d Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 16 Oct 2020 10:17:32 +0000 Subject: [PATCH 18/63] attach a docstring also to the underlying C function --- torch/linalg/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index eea2395fc128..6654e23db2bc 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -142,6 +142,8 @@ """) +_add_docstr(_linalg.linalg_svd, "See ``linalg.svd``") + def svd(a, full_matrices=True, compute_uv=True): r""" linalg.svd(input, full_matrices=True, compute_uv=True, out=None) -> (Tensor, Tensor, Tensor) From c7bfc0780a5b3aa279687feddf334ba40028131d Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 26 Oct 2020 06:10:29 -0500 Subject: [PATCH 19/63] this seems to be needed, else I get 'derivative for svd not implemented', but I don't really understand why --- aten/src/ATen/native/native_functions.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 6f5e3611dcdd..06b0033ee17a 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6140,14 +6140,14 @@ CUDA: legacy::cuda::_th_eig - func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) - dispatch: - DefaultBackend: svd_out + # dispatch: + # DefaultBackend: svd_out - func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) use_c10_dispatcher: full variants: method, function - dispatch: - DefaultBackend: svd + # dispatch: + # DefaultBackend: svd - func: _svd_helper(Tensor self, bool some, bool compute_uv, bool apply_conj) -> (Tensor U, Tensor S, Tensor V) use_c10_dispatcher: full From 9bcf00b359add64d8a5f3ce003196ab90fe1df6e Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 26 Oct 2020 10:45:06 -0500 Subject: [PATCH 20/63] use dispatch: Math as per @mruberry suggestion --- aten/src/ATen/native/native_functions.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 06b0033ee17a..999e9fe8bd8e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6140,14 +6140,14 @@ CUDA: legacy::cuda::_th_eig - func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) 
V) - # dispatch: - # DefaultBackend: svd_out + dispatch: + Math: svd_out - func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) use_c10_dispatcher: full variants: method, function - # dispatch: - # DefaultBackend: svd + dispatch: + Math: svd - func: _svd_helper(Tensor self, bool some, bool compute_uv, bool apply_conj) -> (Tensor U, Tensor S, Tensor V) use_c10_dispatcher: full From 25752e6970489580d4b015f730c025eb339c6bf6 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 29 Oct 2020 10:40:34 +0000 Subject: [PATCH 21/63] implement the out= version of torch.linalg.svd --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 17 ++++-------- test/test_linalg.py | 17 ++++++++++++ torch/linalg/__init__.py | 29 ++++++++++++++------- 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 02418ca383c1..745d3f5b8d99 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1113,21 +1113,14 @@ std::tuple linalg_svd(const Tensor& self, bool full_matr } } -// Question for reviewers: should this function even exist? Do we want to -// support out versions of torch.linalg.*? -std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& V, +std::tuple linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT, const Tensor& self, bool full_matrices, bool compute_uv) { - TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - - bool some = !full_matrices; - bool apply_conj = false; - Tensor U_tmp, S_tmp, V_tmp; - std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, full_matrices, compute_uv, apply_conj); + Tensor U_tmp, S_tmp, VT_tmp; + std::tie(U_tmp, S_tmp, VT_tmp) = at::linalg_svd(self, full_matrices, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); - V.resize_as_(V_tmp).copy_(V_tmp); - return std::tuple(U, S, V); + VT.resize_as_(VT_tmp).copy_(VT_tmp); + return std::tuple(U, S, VT); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_linalg.py b/test/test_linalg.py index b8d5acaa7956..fa074edf05a4 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -927,16 +927,33 @@ def test_svd(self, device, dtype): np_t = t.cpu().numpy() for full_matrices in (True, False): + # check linalg.svd agains numpy expected = np.linalg.svd(np_t, full_matrices, compute_uv=True) actual = torch.linalg.svd(t, full_matrices, compute_uv=True) self.assertEqual(actual, expected) + # check linalg.svd agains linalg.svd(out=...) + out = (torch.empty_like(actual[0]), + torch.empty_like(actual[1]), + torch.empty_like(actual[2])) + out2 = torch.linalg.svd(t, full_matrices, compute_uv=True, out=out) + self.assertEqual(actual, out) + self.assertEqual(actual, out2) for full_matrices in (True, False): + # check linalg.svd agains numpy np_s = np.linalg.svd(np_t, full_matrices, compute_uv=False) USV = torch.linalg.svd(t, full_matrices, compute_uv=False) assert USV.U is None self.assertEqual(USV.S, np_s) assert USV.V is None + # check linalg.svd agains linalg.svd(out=...) 
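+            # with compute_uv=False the Python wrapper accepts a single tensor for
+            # out= (it is used for S) and returns (None, S, None)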
+ out_S = torch.empty_like(USV.S) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out_S) + assert USV.U is None + assert USV.S is out_S + self.assertEqual(USV.S, np_s) + assert USV.V is None + instantiate_device_type_tests(TestLinalg, globals()) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index cc244691ed20..30a7a5dd6273 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -143,7 +143,7 @@ _add_docstr(_linalg.linalg_svd, "See ``linalg.svd``") -def svd(a, full_matrices=True, compute_uv=True): +def svd(a, full_matrices=True, compute_uv=True, out=None): r""" linalg.svd(input, full_matrices=True, compute_uv=True, out=None) -> (Tensor, Tensor, Tensor) @@ -245,12 +245,21 @@ def svd(a, full_matrices=True, compute_uv=True): >>> torch.dist(a_big, u @ torch.diag_embed(s) @ vh) tensor(3.0957e-06) """ - USV = _linalg.linalg_svd(a, full_matrices, compute_uv) - if not compute_uv: - # we want to return a value of type torch.return_types.linalg_svd - # (which is a PyStructSequence). However, this type is not directly - # exposed by pytorch, so we get a reference to it by calling type() on - # USV. - USV_TYPE = type(USV) - return USV_TYPE((None, USV.S, None)) - return USV + if compute_uv: + # easy case, same semantics as the C++ version of linalg_svd + return _linalg.linalg_svd(a, full_matrices, compute_uv, out=out) + + # harder case: C++ returns (0-tensor, S, 0-tensor) but we want to return + # (None, S, None) instead + if out is not None: + if type(out) != torch.Tensor: + raise TypeError("linalg.svd: argument 'out' must be a Tensor when " + "compute_uv==False, not %s" % type(out).__name__) + out = (torch.Tensor(), out, torch.Tensor()) + USV = _linalg.linalg_svd(a, full_matrices, compute_uv, out=out) + # we want to return a value of type torch.return_types.linalg_svd + # (which is a PyStructSequence). However, this type is not directly + # exposed by pytorch, so we get a reference to it by calling type() on + # USV. + USV_TYPE = type(USV) + return USV_TYPE((None, USV.S, None)) From 10a66d7076e06c09c06d8c28833ab0acc15944ec Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 29 Oct 2020 10:42:30 +0000 Subject: [PATCH 22/63] this no longer segfaults --- test/test_linalg.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index fa074edf05a4..763a2f7419ed 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -920,9 +920,6 @@ def test_nuclear_norm_exceptions_old(self, device): @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) def test_svd(self, device, dtype): - if dtype is torch.cdouble: - # this test segfaults - self.skipTest('Issue XXX') t = torch.randn((10, 11), device=device, dtype=dtype) np_t = t.cpu().numpy() From 6f480c29a70b9720362fff4f6f667d1b423130c6 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Tue, 10 Nov 2020 09:41:26 +0000 Subject: [PATCH 23/63] no longer needed --- test/test_linalg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 419b40a8cb17..16f94fc4d824 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1021,7 +1021,6 @@ def test_nuclear_norm_exceptions_old(self, device): # Tests torch.linalg.svd, vs. 
NumPy @skipCUDAIfNoMagma @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) def test_svd(self, device, dtype): t = torch.randn((10, 11), device=device, dtype=dtype) From 4038acef10fbdb6713f020450f8b8b7cc8c6321a Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Tue, 10 Nov 2020 10:02:49 +0000 Subject: [PATCH 24/63] remove merge leftover --- test/test_linalg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 16f94fc4d824..65f434495830 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1243,7 +1243,6 @@ def test_vdot_invalid_args(self, device): def test_dot_invalid_args(self, device): self._test_dot_vdot_invalid_args(device, torch.dot) self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True) ->>>>>>> upstream/master instantiate_device_type_tests(TestLinalg, globals()) From a224dbd7875fb271ea54cd754b945991bacc22ed Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Tue, 10 Nov 2020 10:28:47 +0000 Subject: [PATCH 25/63] change the semantics of the out= param if compute_uv==False --- test/test_linalg.py | 3 ++- torch/linalg/__init__.py | 22 +++++++++------------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 65f434495830..326f4b2dbd08 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1048,7 +1048,8 @@ def test_svd(self, device, dtype): assert USV.V is None # check linalg.svd agains linalg.svd(out=...) out_S = torch.empty_like(USV.S) - USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out_S) + out = (torch.Tensor(), out_S, torch.Tensor()) + USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out) assert USV.U is None assert USV.S is out_S self.assertEqual(USV.S, np_s) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 68249a0ea1c9..a78830bb95ae 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -210,7 +210,9 @@ def svd(a, full_matrices=True, compute_uv=True, out=None): batch dimensions consisting of :math:`m \times n` matrices. full_matrices (bool, optional): controls the shape of returned `U` and `V` compute_uv (bool, optional): option whether to compute `U` and `V` or not - out (tuple, optional): the output tuple of tensors + out (tuple, optional): the output tuple of tensors. If compute_uv==False, only the 2nd item is used. + The 1st and 3rd argument must be tensors, but they are ignored. E.g. you can + pass `(torch.Tensor(), out_S, torch.Tensor())` Example:: @@ -249,19 +251,13 @@ def svd(a, full_matrices=True, compute_uv=True, out=None): return _linalg.linalg_svd(a, full_matrices, compute_uv, out=out) # harder case: C++ returns (0-tensor, S, 0-tensor) but we want to return - # (None, S, None) instead - if out is not None: - if type(out) != torch.Tensor: - raise TypeError("linalg.svd: argument 'out' must be a Tensor when " - "compute_uv==False, not %s" % type(out).__name__) - out = (torch.Tensor(), out, torch.Tensor()) + # (None, S, None) instead. Moreover, we want to return a value of type + # torch.return_types.linalg_svd (which is a PyStructSequence). However, + # this type is not directly exposed by pytorch, so we get a reference to + # it by calling type() on USV. USV = _linalg.linalg_svd(a, full_matrices, compute_uv, out=out) - # we want to return a value of type torch.return_types.linalg_svd - # (which is a PyStructSequence). 
However, this type is not directly - # exposed by pytorch, so we get a reference to it by calling type() on - # USV. - USV_TYPE = type(USV) - return USV_TYPE((None, USV.S, None)) + USV_type = type(USV) + return USV_type((None, USV.S, None)) tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" linalg.tensorsolve(input, other, dims=None, *, out=None) -> Tensor From abf9baffc200ac6a459e78c70ce5e5c05594608e Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Tue, 10 Nov 2020 10:57:09 +0000 Subject: [PATCH 26/63] as discussed on the PR, remove the apply_conj feature: the risk of breaking existing code is too high. Instead, document the current behavior more accurately --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 14 ++++---------- .../src/ATen/native/cuda/BatchLinearAlgebra.cu | 5 +---- aten/src/ATen/native/native_functions.yaml | 2 +- test/test_torch.py | 6 +++++- tools/autograd/derivatives.yaml | 5 +---- torch/_torch_docs.py | 18 ++++++++++++------ torch/csrc/autograd/FunctionsManual.cpp | 4 ---- 7 files changed, 24 insertions(+), 30 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 91f9d41ebbf4..70cc15729ef3 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1021,8 +1021,7 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT, #endif } -std::tuple _svd_helper_cpu(const Tensor& self, bool some, - bool compute_uv, bool apply_conj) { +std::tuple _svd_helper_cpu(const Tensor& self, bool some, bool compute_uv) { std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); @@ -1059,25 +1058,21 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some } // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. 
VT_working_copy.transpose_(-2, -1); - if (VT_working_copy.is_complex() && apply_conj) - VT_working_copy = VT_working_copy.conj(); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } std::tuple svd(const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - bool apply_conj = true; - return at::_svd_helper(self, some, compute_uv, apply_conj); + return at::_svd_helper(self, some, compute_uv); } std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); - bool apply_conj = true; Tensor U_tmp, S_tmp, V_tmp; - std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv, apply_conj); + std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); S.resize_as_(S_tmp).copy_(S_tmp); V.resize_as_(V_tmp).copy_(V_tmp); @@ -1101,9 +1096,8 @@ std::tuple linalg_svd(const Tensor& self, bool full_matr "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); bool some = !full_matrices; - bool apply_conj = false; Tensor U, S, V; - std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv, apply_conj); + std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv); if (compute_uv) { Tensor VT = V.transpose(-2, -1); return std::make_tuple(U, S, VT); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 85eeeafc5c91..b03da5a5ef7b 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1654,8 +1654,7 @@ AT_ERROR("svd: MAGMA library not found in " #endif } -std::tuple _svd_helper_cuda(const Tensor& self, bool some, - bool compute_uv, bool apply_conj) { +std::tuple _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) { std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); @@ -1708,8 +1707,6 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som } // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. 
VT_working_copy.transpose_(-2, -1); - if (VT_working_copy.is_complex() && apply_conj) - VT_working_copy = VT_working_copy.conj(); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 7dae66b44db9..b0a472775bb4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6204,7 +6204,7 @@ dispatch: Math: svd -- func: _svd_helper(Tensor self, bool some, bool compute_uv, bool apply_conj) -> (Tensor U, Tensor S, Tensor V) +- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) use_c10_dispatcher: full variants: function dispatch: diff --git a/test/test_torch.py b/test/test_torch.py index c9e1e2e33b0a..782f8b7ec6bb 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10128,7 +10128,11 @@ def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **op def test_svd_complex(self, device, dtype): t = torch.randn((10, 10), dtype=dtype, device=device) U, S, V = torch.svd(t, some=False) - t2 = U @ torch.diag(S).type(dtype) @ V.T.conj() + # note: from the math point of view, it is weird that we need to use + # V.T instead of V.T.conj(): torch.svd has a buggy behavior for + # complex numbers and it's deprecated. You should use torch.linalg.svd + # instead. + t2 = U @ torch.diag(S).type(dtype) @ V.T self.assertEqual(t, t2) def test_lerp(self, device): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 6b9e69503ce9..f103c842a280 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1031,10 +1031,7 @@ - name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) -# NOTE: currently apply_conj is ignored, but it if used only when operating on -# complex, which are not supported by svd_backward so far. When we add support -# for complex, we need to remember to consider it. -- name: _svd_helper(Tensor self, bool some, bool compute_uv, bool apply_conj) -> (Tensor U, Tensor S, Tensor V) +- name: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V) self: svd_backward(grads, self, some, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 6aec0e7d2859..c7afc997d7a7 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -7643,15 +7643,21 @@ def merge_dicts(*dicts): This function returns a namedtuple ``(U, S, V)`` which is the singular value decomposition of a input real matrix or batches of real matrices :attr:`input` such that -:math:`input = U \times diag(S) \times V^H`, where :math:`V^H` is the conjugate transpose +:math:`input = U \times diag(S) \times V^T`, where :math:`V^T` is the transpose of ``V``. -In Python, :math:`V^H` is computed by ``v.T.conj()``. Note that for -non-complex types, ``.conj()`` is a no-op and can be omittted. To summarize, -the original Tensor can be reconstructed by:: +The original tensor can be reconstructed by:: + + U @ diag(S) @ V.T + + +.. note:: It is worth noting that that the code above works unmodified even + for complex numbers, i.e. the returned matrix ``V`` is already + conjugated. 
This behavior is probably unexpected from the + mathematical point of view, but it is not possible to change it + without breaking existing code. New code is encouraged to use + ``torch.linalg.svd`` instead, which returns :math:`V^H` instead. - U @ diag(S) @ V.T.conj() # for real and complex numbers - U @ diag(S) @ V.T # only for real numbers The dtype of ``U`` and ``V`` is the same as the ``input`` matrix. The dtype of ``S`` is always real numbers, even if ``input`` is complex. diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 0c1ba452299a..901c61b79058 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1777,10 +1777,6 @@ std::tuple prelu_double_backward( // This makes no assumption on the signs of sigma. Tensor svd_backward(const std::vector &grads, const Tensor& self, bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { - /* NOTE: currently this function does not support complex numbers. When we - * add support for it, we need to consider "bool apply_conj", which - * currently is ignored. See the corresponding comment in derivatives.yaml. - */ TORCH_CHECK(compute_uv, "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ", "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)"); From 37ec294001181a77e1973d02e685f2f26214d6f0 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 12 Nov 2020 14:18:57 +0000 Subject: [PATCH 27/63] change again the semantics for compute_uv=False: we finally decided to return empty tensors, so there is no longer any need to use a python wrapper for adjust the return value --- test/test_linalg.py | 42 ++++++++++++++++++++++++++-------------- torch/linalg/__init__.py | 30 ++++++++-------------------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index fd2b634224c5..ba0ac4a39960 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1019,20 +1019,22 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) - # Tests torch.linalg.svd, vs. NumPy @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) - def test_svd(self, device, dtype): + def test_svd_compute_uv(self, device, dtype): + """ + Test the default case, compute_uv=True. Here we have the very same behavior as + numpy + """ t = torch.randn((10, 11), device=device, dtype=dtype) np_t = t.cpu().numpy() - for full_matrices in (True, False): - # check linalg.svd agains numpy + # check linalg.svd vs numpy expected = np.linalg.svd(np_t, full_matrices, compute_uv=True) actual = torch.linalg.svd(t, full_matrices, compute_uv=True) self.assertEqual(actual, expected) - # check linalg.svd agains linalg.svd(out=...) + # check linalg.svd vs linalg.svd(out=...) out = (torch.empty_like(actual[0]), torch.empty_like(actual[1]), torch.empty_like(actual[2])) @@ -1040,21 +1042,33 @@ def test_svd(self, device, dtype): self.assertEqual(actual, out) self.assertEqual(actual, out2) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_svd_no_compute_uv(self, device, dtype): + """ + Test the compute_uv=False case. 
Here we have a different return type than + numpy: numpy returns S, we return (empty, S, empty) + """ + t = torch.randn((10, 11), device=device, dtype=dtype) + np_t = t.cpu().numpy() + def is_empty(x): + return x.numel() == 0 and x.dtype == t.dtype and x.device == t.device + for full_matrices in (True, False): - # check linalg.svd agains numpy + # check linalg.svd vs numpy np_s = np.linalg.svd(np_t, full_matrices, compute_uv=False) USV = torch.linalg.svd(t, full_matrices, compute_uv=False) - assert USV.U is None + assert is_empty(USV.U) self.assertEqual(USV.S, np_s) - assert USV.V is None - # check linalg.svd agains linalg.svd(out=...) - out_S = torch.empty_like(USV.S) - out = (torch.Tensor(), out_S, torch.Tensor()) + assert is_empty(USV.V) + # check linalg.svd vs linalg.svd(out=...) + out = (torch.Tensor(), torch.empty_like(USV.S), torch.Tensor()) USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out) - assert USV.U is None - assert USV.S is out_S + assert USV.U is out[0] + assert USV.S is out[1] + assert USV.V is out[2] self.assertEqual(USV.S, np_s) - assert USV.V is None @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index a78830bb95ae..e5fe105436a1 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -140,10 +140,7 @@ (tensor(3.7417), tensor(11.2250)) """) -_add_docstr(_linalg.linalg_svd, "See ``linalg.svd``") - -def svd(a, full_matrices=True, compute_uv=True, out=None): - r""" +svd = _add_docstr(_linalg.linalg_svd, r""" linalg.svd(input, full_matrices=True, compute_uv=True, out=None) -> (Tensor, Tensor, Tensor) This function returns a namedtuple ``(U, S, Vh)`` which is the singular value @@ -162,12 +159,13 @@ def svd(a, full_matrices=True, compute_uv=True, out=None): need to manually transpose and conjugate ``V`` in order to reconstruct the original matrix. - * If :attr:`compute_uv=False`, it returns ``None`` for ``U`` and - ``V``, whereas :meth:`~torch.svd` returns zero-filled tensors. + * If :attr:`compute_uv=False`, it returns empty tensors (i.e., + with 0 elements) for ``U`` and ``V``, whereas + :meth:`~torch.svd` returns zero-filled tensors. **Differences with** ``numpy.linalg.svd``: - * if :attr:`compute_uv=False` it returns ``(None, S, None)``, + * if :attr:`compute_uv=False` it returns ``(empty_tensor, S, empty_tensor)``, whereas numpy returns ``S``. @@ -210,8 +208,8 @@ def svd(a, full_matrices=True, compute_uv=True, out=None): batch dimensions consisting of :math:`m \times n` matrices. full_matrices (bool, optional): controls the shape of returned `U` and `V` compute_uv (bool, optional): option whether to compute `U` and `V` or not - out (tuple, optional): the output tuple of tensors. If compute_uv==False, only the 2nd item is used. - The 1st and 3rd argument must be tensors, but they are ignored. E.g. you can + out (tuple, optional): the output tuple of tensors. If compute_uv=False, tThe 1st and 3rd + argument must be tensors, but they are ignored. E.g. 
you can pass `(torch.Tensor(), out_S, torch.Tensor())` Example:: @@ -245,19 +243,7 @@ def svd(a, full_matrices=True, compute_uv=True, out=None): >>> u, s, vh = torch.linalg.svd(a_big, full_matrices=False) >>> torch.dist(a_big, u @ torch.diag_embed(s) @ vh) tensor(3.0957e-06) - """ - if compute_uv: - # easy case, same semantics as the C++ version of linalg_svd - return _linalg.linalg_svd(a, full_matrices, compute_uv, out=out) - - # harder case: C++ returns (0-tensor, S, 0-tensor) but we want to return - # (None, S, None) instead. Moreover, we want to return a value of type - # torch.return_types.linalg_svd (which is a PyStructSequence). However, - # this type is not directly exposed by pytorch, so we get a reference to - # it by calling type() on USV. - USV = _linalg.linalg_svd(a, full_matrices, compute_uv, out=out) - USV_type = type(USV) - return USV_type((None, USV.S, None)) +""") tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" linalg.tensorsolve(input, other, dims=None, *, out=None) -> Tensor From 60e463ff63d4254269b0cfccf552b3e67c81c5af Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 12 Nov 2020 14:37:59 +0000 Subject: [PATCH 28/63] fix flake8 --- test/test_linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index ba0ac4a39960..41774ee267ac 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1052,6 +1052,7 @@ def test_svd_no_compute_uv(self, device, dtype): """ t = torch.randn((10, 11), device=device, dtype=dtype) np_t = t.cpu().numpy() + def is_empty(x): return x.numel() == 0 and x.dtype == t.dtype and x.device == t.device From 243339931f3196d8129ff87d82943883c498d403 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 26 Nov 2020 09:59:11 +0000 Subject: [PATCH 29/63] s/self/input in error messages --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 70cc15729ef3..37a05a8e4059 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1063,14 +1063,14 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some std::tuple svd(const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); return at::_svd_helper(self, some, compute_uv); } std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, const Tensor& self, bool some, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); Tensor U_tmp, S_tmp, V_tmp; std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv); U.resize_as_(U_tmp).copy_(U_tmp); @@ -1093,7 +1093,7 @@ std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { TORCH_CHECK(self.dim() >= 2, - "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); + "input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); bool some = !full_matrices; Tensor U, S, V; From 4b4076dceabf7b827dd28f4496725ae0a55d827b Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 26 Nov 2020 10:01:00 +0000 
Subject: [PATCH 30/63] kill this test, now the behavior is tested directly in python --- test/cpp/api/functional.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp index 14b29d78e7df..707c1bfd7ac0 100644 --- a/test/cpp/api/functional.cpp +++ b/test/cpp/api/functional.cpp @@ -2858,16 +2858,3 @@ TEST_F(FunctionalTest, BCEWithLogitsLoss) { ASSERT_TRUE(torch::isfinite(out2).all().item()); } } - -TEST_F(FunctionalTest, linalg_svd) { - // NOTE: this is only a partial test: it tests that when we pass - // compute_uv=False, the returned U and VT are empty tensors. We need to - // write a C++ test because in Python it has a slightly different behavior - // and it returns (None, S, None) instead. The full logic for svd is - // tested thoughtfully in Python. - const auto input = torch::rand({7, 3}); - torch::Tensor U, S, VT; - std::tie(U, S, VT) = at::linalg_svd(input, true, false); - ASSERT_EQ(U.numel(), 0) << "U is not empty"; - ASSERT_EQ(VT.numel(), 0) << "VT is not empty"; -} From 5fa4efe3290f72efa4029373ff31fa3f89996944 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 26 Nov 2020 10:38:05 +0000 Subject: [PATCH 31/63] move the svd tests from test_torch.py to test_linalg.py --- test/test_linalg.py | 215 +++++++++++++++++++++++++++++++++++++++++++- test/test_torch.py | 206 ------------------------------------------ 2 files changed, 213 insertions(+), 208 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 41774ee267ac..bbfd54b23156 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -4,6 +4,7 @@ import warnings from math import inf, nan, isnan from random import randrange +from itertools import product from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor) @@ -1019,10 +1020,220 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) + # ~~~ tests for torch.svd ~~~ + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_svd(self, device, dtype): + def run_test(dims, some, compute_uv): + x = torch.randn(*dims, dtype=dtype, device=device) + outu = torch.tensor((), dtype=dtype, device=device) + outs = torch.tensor((), dtype=dtype, device=device) + outv = torch.tensor((), dtype=dtype, device=device) + torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) + + if compute_uv: + if some: + x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = outu[..., :min(*dims[-2:])] + narrow_v = outv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, outs, msg='Singular values mismatch') + self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') + self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') + + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') + 
self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') + self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') + + # test non-contiguous + x = torch.randn(*dims, dtype=dtype, device=device) + n_dim = len(dims) + # Reverse the batch dimensions and the matrix dimensions and then concat them + x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) + assert not x.is_contiguous(), "x is intentionally non-contiguous" + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + if compute_uv: + if some: + x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = resu[..., :min(*dims[-2:])] + narrow_v = resv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, ress, msg='Singular values mismatch') + self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') + self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') + + shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices + (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices + (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices + for dims, some, compute_uv in product(shapes, [True, False], [True, False]): + run_test(dims, some, compute_uv) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_svd_no_singularvectors(self, device): + for size in [(5, 5), (5, 20), (20, 5)]: + a = torch.randn(*size, device=device) + u, s_expect, v = torch.svd(a) + u, s_actual, v = torch.svd(a, compute_uv=False) + self.assertEqual(s_expect, s_actual, msg="Singular values don't match") + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_svd_lowrank(self, device): + import torch + from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix + + dtype = torch.double + + def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): + density = options.pop('density', 1) + if isinstance(matrix_size, int): + rows = columns = matrix_size + else: + rows, columns = matrix_size + if density == 1: + a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) + a = a_input + else: + assert batches == () + a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) + a = a_input.to_dense() + + q = min(*size) + u, s, v = svd_lowrank(a_input, q=q, **options) + + # check if u, s, v is a SVD + u, s, v = u[..., :q], s[..., :q], v[..., :q] + A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) + self.assertEqual(A, a) + + # check if svd_lowrank produces same singular values as torch.svd + U, S, V = torch.svd(a) + self.assertEqual(s.shape, S.shape) + self.assertEqual(u.shape, U.shape) + self.assertEqual(v.shape, V.shape) + self.assertEqual(s, S) + + if density == 1: + # actual_rank is known only for dense inputs + # + # check if pairs (u, U) and (v, V) span the same + # subspaces, respectively + u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] + U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] + self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + 
self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) + + all_batches = [(), (1,), (3,), (2, 3)] + for actual_rank, size, all_batches in [ + (2, (17, 4), all_batches), + (4, (17, 4), all_batches), + (4, (17, 17), all_batches), + (10, (100, 40), all_batches), + (7, (1000, 1000), [()]), + ]: + # dense input + for batches in all_batches: + run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) + if size != size[::-1]: + run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) + + # sparse input + for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: + for density in [0.005, 0.1]: + run_subtest(None, size, (), device, torch.svd_lowrank, density=density) + + # jitting support + jitted = torch.jit.script(torch.svd_lowrank) + actual_rank, size, batches = 2, (17, 4), () + run_subtest(actual_rank, size, batches, device, jitted) + + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.cfloat) + def test_svd_complex(self, device, dtype): + t = torch.randn((10, 10), dtype=dtype, device=device) + U, S, V = torch.svd(t, some=False) + # note: from the math point of view, it is weird that we need to use + # V.T instead of V.T.conj(): torch.svd has a buggy behavior for + # complex numbers and it's deprecated. You should use torch.linalg.svd + # instead. + t2 = U @ torch.diag(S).type(dtype) @ V.T + self.assertEqual(t, t2) + + def _test_svd_helper(self, shape, some, col_maj, device, dtype): + cpu_tensor = torch.randn(shape, device='cpu').to(dtype) + device_tensor = cpu_tensor.to(device=device) + if col_maj: + cpu_tensor = cpu_tensor.t() + device_tensor = device_tensor.t() + cpu_result = torch.svd(cpu_tensor, some=some) + device_result = torch.svd(device_tensor, some=some) + m = min(cpu_tensor.shape[-2:]) + # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). + # - When some==False, U[..., m:] can be arbitrary. + # - When some==True, U shape: [..., m], V shape: [m, m] + # - Signs are not deterministic. If the sign of a column of U is changed + # then the corresponding column of the V has to be changed. + # Thus here we only compare result[..., :m].abs() from CPU and device. 
+ for x, y in zip(cpu_result, device_result): + self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_svd_square(self, device, dtype): + self._test_svd_helper((10, 10), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double) + def test_svd_square_col_maj(self, device, dtype): + self._test_svd_helper((10, 10), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double) + def test_svd_tall_some(self, device, dtype): + self._test_svd_helper((20, 5), True, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double) + def test_svd_tall_all(self, device, dtype): + self._test_svd_helper((20, 5), False, False, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double) + def test_svd_tall_some_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), True, True, device, dtype) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double) + def test_svd_tall_all_col_maj(self, device, dtype): + self._test_svd_helper((5, 20), False, True, device, dtype) + + # ~~~ tests for torch.linalg.svd ~~~ + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) - def test_svd_compute_uv(self, device, dtype): + def test_linalg_svd_compute_uv(self, device, dtype): """ Test the default case, compute_uv=True. Here we have the very same behavior as numpy @@ -1045,7 +1256,7 @@ def test_svd_compute_uv(self, device, dtype): @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) - def test_svd_no_compute_uv(self, device, dtype): + def test_linalg_svd_no_compute_uv(self, device, dtype): """ Test the compute_uv=False case. 
Here we have a different return type than numpy: numpy returns S, we return (empty, S, empty) diff --git a/test/test_torch.py b/test/test_torch.py index 747b4331ba30..25134d04b530 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -9813,145 +9813,6 @@ def test_symeig_complex_xfailed(self, device, dtype): x_recon = torch.matmul(torch.matmul(outv, torch.diag_embed(oute.to(dtype))), outv.transpose(-2, -1).conj()) self.assertEqual(x, x_recon, atol=1e-8, rtol=0) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_svd(self, device, dtype): - def run_test(dims, some, compute_uv): - x = torch.randn(*dims, dtype=dtype, device=device) - outu = torch.tensor((), dtype=dtype, device=device) - outs = torch.tensor((), dtype=dtype, device=device) - outv = torch.tensor((), dtype=dtype, device=device) - torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) - - if compute_uv: - if some: - x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = outu[..., :min(*dims[-2:])] - narrow_v = outv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, outs, msg='Singular values mismatch') - self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') - self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') - - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') - self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') - self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') - - # test non-contiguous - x = torch.randn(*dims, dtype=dtype, device=device) - n_dim = len(dims) - # Reverse the batch dimensions and the matrix dimensions and then concat them - x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) - assert not x.is_contiguous(), "x is intentionally non-contiguous" - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - if compute_uv: - if some: - x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = resu[..., :min(*dims[-2:])] - narrow_v = resv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, ress, msg='Singular values mismatch') - self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') - self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') - - shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices - (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices - (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices - for dims, some, compute_uv in product(shapes, [True, False], [True, False]): - run_test(dims, some, compute_uv) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_no_singularvectors(self, device): - for size in [(5, 5), (5, 20), (20, 5)]: - 
a = torch.randn(*size, device=device) - u, s_expect, v = torch.svd(a) - u, s_actual, v = torch.svd(a, compute_uv=False) - self.assertEqual(s_expect, s_actual, msg="Singular values don't match") - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_lowrank(self, device): - import torch - from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix - - dtype = torch.double - - def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): - density = options.pop('density', 1) - if isinstance(matrix_size, int): - rows = columns = matrix_size - else: - rows, columns = matrix_size - if density == 1: - a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) - a = a_input - else: - assert batches == () - a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) - a = a_input.to_dense() - - q = min(*size) - u, s, v = svd_lowrank(a_input, q=q, **options) - - # check if u, s, v is a SVD - u, s, v = u[..., :q], s[..., :q], v[..., :q] - A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) - self.assertEqual(A, a) - - # check if svd_lowrank produces same singular values as torch.svd - U, S, V = torch.svd(a) - self.assertEqual(s.shape, S.shape) - self.assertEqual(u.shape, U.shape) - self.assertEqual(v.shape, V.shape) - self.assertEqual(s, S) - - if density == 1: - # actual_rank is known only for dense inputs - # - # check if pairs (u, U) and (v, V) span the same - # subspaces, respectively - u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] - U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] - self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - - all_batches = [(), (1,), (3,), (2, 3)] - for actual_rank, size, all_batches in [ - (2, (17, 4), all_batches), - (4, (17, 4), all_batches), - (4, (17, 17), all_batches), - (10, (100, 40), all_batches), - (7, (1000, 1000), [()]), - ]: - # dense input - for batches in all_batches: - run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) - if size != size[::-1]: - run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) - - # sparse input - for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]: - for density in [0.005, 0.1]: - run_subtest(None, size, (), device, torch.svd_lowrank, density=density) - - # jitting support - jitted = torch.jit.script(torch.svd_lowrank) - actual_rank, size, batches = 2, (17, 4), () - run_subtest(actual_rank, size, batches, device, jitted) - @skipCUDAIfNoMagma @skipCPUIfNoLapack def test_pca_lowrank(self, device): @@ -10023,19 +9884,6 @@ def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **op guess_rank, actual_rank, size, batches = 2, 2, (17, 4), () run_subtest(guess_rank, actual_rank, size, batches, device, jitted) - @onlyCPU - @skipCPUIfNoLapack - @dtypes(torch.cfloat) - def test_svd_complex(self, device, dtype): - t = torch.randn((10, 10), dtype=dtype, device=device) - U, S, V = torch.svd(t, some=False) - # note: from the math point of view, it is weird that we need to use - # V.T instead of V.T.conj(): torch.svd has a buggy behavior for - # complex numbers and it's deprecated. You should use torch.linalg.svd - # instead. 
- t2 = U @ torch.diag(S).type(dtype) @ V.T - self.assertEqual(t, t2) - def test_lerp(self, device): start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)] for shapes in product(start_end_shapes, start_end_shapes): @@ -21135,60 +20983,6 @@ def test(self, device, dtype): class TestTensorDeviceOps(TestCase): exact_dtype = True - def _test_svd_helper(self, shape, some, col_maj, device, dtype): - cpu_tensor = torch.randn(shape, device='cpu').to(dtype) - device_tensor = cpu_tensor.to(device=device) - if col_maj: - cpu_tensor = cpu_tensor.t() - device_tensor = device_tensor.t() - cpu_result = torch.svd(cpu_tensor, some=some) - device_result = torch.svd(device_tensor, some=some) - m = min(cpu_tensor.shape[-2:]) - # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). - # - When some==False, U[..., m:] can be arbitrary. - # - When some==True, U shape: [..., m], V shape: [m, m] - # - Signs are not deterministic. If the sign of a column of U is changed - # then the corresponding column of the V has to be changed. - # Thus here we only compare result[..., :m].abs() from CPU and device. - for x, y in zip(cpu_result, device_result): - self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*(_float_types_no_half + _complex_types)) - def test_svd_square(self, device, dtype): - self._test_svd_helper((10, 10), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_square_col_maj(self, device, dtype): - self._test_svd_helper((10, 10), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_tall_some(self, device, dtype): - self._test_svd_helper((20, 5), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_tall_all(self, device, dtype): - self._test_svd_helper((20, 5), False, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_tall_some_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*_float_types_no_half) - def test_svd_tall_all_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), False, True, device, dtype) - class TestTorchMathOps(TestCase): exact_dtype = True From f16f6e75c559b4df39ef61cfce31e3b0793636b5 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 26 Nov 2020 10:43:18 +0000 Subject: [PATCH 32/63] improve the docs of torch.svd --- torch/_torch_docs.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 4ae3f4470557..89cdebb84fa7 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -7669,17 +7669,8 @@ def merge_dicts(*dicts): U @ diag(S) @ V.T - -.. note:: It is worth noting that that the code above works unmodified even - for complex numbers, i.e. the returned matrix ``V`` is already - conjugated. This behavior is probably unexpected from the - mathematical point of view, but it is not possible to change it - without breaking existing code. New code is encouraged to use - ``torch.linalg.svd`` instead, which returns :math:`V^H` instead. - - -The dtype of ``U`` and ``V`` is the same as the ``input`` matrix. The dtype of -``S`` is always real numbers, even if ``input`` is complex. +The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. 
``S`` will +always be real-valued, even if :attr:`input` is complex. If :attr:`some` is ``True`` (default), the method returns the reduced singular value decomposition i.e., if the last two dimensions of From a008c530f8591f6763b847dfd8d85cbbf44242ba Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 26 Nov 2020 17:07:19 +0000 Subject: [PATCH 33/63] try to improve the docs --- torch/linalg/__init__.py | 76 ++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index e5fe105436a1..a8350dc1bb53 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -141,11 +141,26 @@ """) svd = _add_docstr(_linalg.linalg_svd, r""" -linalg.svd(input, full_matrices=True, compute_uv=True, out=None) -> (Tensor, Tensor, Tensor) +linalg.svd(input, full_matrices=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) -This function returns a namedtuple ``(U, S, Vh)`` which is the singular value -decomposition of a input real matrix or batches of real matrices :attr:`input` such that -:math:`input = U \times diag(S) \times V^H` (where :math:`V^H` is ``Vh``). + +Compute the singular value decomposition of either a matrix or batch of +matrices :attr:`input`." The singular value decomposition is represented as a +namedtuple ``(U, S, Vh)``, such that :math:`input = U \times diag(S) \times +Vh`. If the inputs are batches, then returns batched outputs for all of ``U``, +``S`` and ``Vh``. + +If :attr:`full_matrices` is ``False``, the method returns the reduced singular +value decomposition i.e., if the last two dimensions of :attr:`input` are +``m`` and ``n``, then the returned `U` and `V` matrices will contain only +:math:`min(n, m)` orthonormal columns. + +If :attr:`compute_uv` is ``False``, the returned `U` and `Vh` will be empy +tensors with no elements and the same device as :attr:`input`. The +:attr:`full_matrices` argument has no effect when :attr:`compute_uv` is False. + +The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will +always be real-valued, even if :attr:`input` is complex. .. warning:: **Differences with** :meth:`~torch.svd`: @@ -156,28 +171,15 @@ * it returns ``Vh``, whereas :meth:`~torch.svd` returns ``V``. The result is that when using :meth:`~torch.svd` you - need to manually transpose and conjugate ``V`` in order to - reconstruct the original matrix. - - * If :attr:`compute_uv=False`, it returns empty tensors (i.e., - with 0 elements) for ``U`` and ``V``, whereas - :meth:`~torch.svd` returns zero-filled tensors. - - **Differences with** ``numpy.linalg.svd``: + need to manually transpose ``V`` in order to reconstruct the + original matrix. - * if :attr:`compute_uv=False` it returns ``(empty_tensor, S, empty_tensor)``, - whereas numpy returns ``S``. + * If :attr:`compute_uv=False`, it returns empty tensors for ``U`` + and ``Vh``, whereas :meth:`~torch.svd` returns zero-filled + tensors. - -The dtype of ``U`` and ``V`` is the same as the ``input`` matrix. The dtype of -``S`` is always real numbers, even if ``input`` is complex. - -If :attr:`full_matrices` is ``False``, the method returns the reduced singular value decomposition -i.e., if the last two dimensions of :attr:`input` are ``m`` and ``n``, then the returned -`U` and `V` matrices will contain only :math:`min(n, m)` orthonormal columns. - -If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be None.:attr:`full_matrices` will -be ignored here. +.. 
note:: Unlike NumPy's ``linalg.svd``, this always returns a namedtuple of + three tensors, even when :attr:`compute_uv=False`. .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, then the singular values of each matrix in the batch is returned in descending order. @@ -186,30 +188,28 @@ algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine `gesdd` as well. -.. note:: Irrespective of the original strides, the returned matrix `U` - will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()` +.. note:: The returned matrix `U` will be transposed, i.e. with strides + :code:`U.contiguous().transpose(-2, -1).stride()`. -.. note:: Extra care needs to be taken when backward through `U` and `V` - outputs. Such operation is really only stable when :attr:`input` is - full rank with all distinct singular values. Otherwise, ``NaN`` can - appear as the gradients are not properly defined. Also, notice that - double backward will usually do an additional backward through `U` and - `V` even if the original backward is only on `S`. +.. note:: Gradients computed using `U` and `Vh` may be unstable if + :attr:`input` is not full rank or has non-unique singular values. .. note:: When :attr:`full_matrices` = ``False``, the gradients on :code:`U[..., :, min(m, n):]` and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors can be arbitrary bases of the subspaces. -.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V` - from the forward pass is required for the backward operation. +.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. + + Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of :math:`m \times n` matrices. - full_matrices (bool, optional): controls the shape of returned `U` and `V` - compute_uv (bool, optional): option whether to compute `U` and `V` or not - out (tuple, optional): the output tuple of tensors. If compute_uv=False, tThe 1st and 3rd - argument must be tensors, but they are ignored. E.g. you can + full_matrices (bool, optional): controls whether to compute the full or reduced decomposition, and + consequently the shape of returned ``U`` and ``V``. Defaults to True. + compute_uv (bool, optional): whether to compute `U` and `V` or not. Defaults to True. + out (tuple, optional): the output tuple of tensors. If compute_uv=False, the 1st and 3rd + arguments must be tensors, but they are ignored. E.g. you can pass `(torch.Tensor(), out_S, torch.Tensor())` Example:: From 989505f085da82312cae5d2b68c231cf4ae4953b Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 26 Nov 2020 17:18:42 +0000 Subject: [PATCH 34/63] rephrase --- torch/linalg/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index a8350dc1bb53..e86629776289 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -208,8 +208,8 @@ full_matrices (bool, optional): controls whether to compute the full or reduced decomposition, and consequently the shape of returned ``U`` and ``V``. Defaults to True. compute_uv (bool, optional): whether to compute `U` and `V` or not. Defaults to True. - out (tuple, optional): the output tuple of tensors. If compute_uv=False, the 1st and 3rd - arguments must be tensors, but they are ignored. E.g. 
you can + out (tuple, optional): a tuple of three tensors to use for the outputs. If compute_uv=False, + the 1st and 3rd arguments must be tensors, but they are ignored. E.g. you can pass `(torch.Tensor(), out_S, torch.Tensor())` Example:: From 80e18d20eee07f20129a5c5e8db0b72a4844220c Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 27 Nov 2020 17:29:25 +0000 Subject: [PATCH 35/63] fix --- third_party/fbgemm | 2 +- third_party/nccl/nccl | 2 +- torch/linalg/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/fbgemm b/third_party/fbgemm index 92c5f37b430a..8eb6dcb23eee 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 92c5f37b430a66905bd03514c510ee236aca7cc0 +Subproject commit 8eb6dcb23eee3b21c0e093a27810cb4a62dd3e27 diff --git a/third_party/nccl/nccl b/third_party/nccl/nccl index 033d799524fb..31b5bb6f6447 160000 --- a/third_party/nccl/nccl +++ b/third_party/nccl/nccl @@ -1 +1 @@ -Subproject commit 033d799524fb97629af5ac2f609de367472b2696 +Subproject commit 31b5bb6f6447da98b9110c605465f9c09621074e diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index e86629776289..f1efb46d2709 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -194,7 +194,7 @@ .. note:: Gradients computed using `U` and `Vh` may be unstable if :attr:`input` is not full rank or has non-unique singular values. -.. note:: When :attr:`full_matrices` = ``False``, the gradients on :code:`U[..., :, min(m, n):]` +.. note:: When :attr:`full_matrices` = ``True``, the gradients on :code:`U[..., :, min(m, n):]` and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors can be arbitrary bases of the subspaces. From c9c60c2b55d5402723775adfc965ddb3f7c8aae2 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 7 Dec 2020 17:05:28 +0000 Subject: [PATCH 36/63] improve the docs of torch.svd in the same way we did for torch.linalg.svd. Add a warning saying that torch.svd is deprecated, and move the box which highlights between the two to torch.svd --- torch/_torch_docs.py | 64 +++++++++++++++++++++++++--------------- torch/linalg/__init__.py | 24 ++++----------- 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 10887f74c986..c5e3c5071b1e 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -7999,26 +7999,47 @@ def merge_dicts(*dicts): r""" svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) -This function returns a namedtuple ``(U, S, V)`` which is the singular value -decomposition of a input real matrix or batches of real matrices :attr:`input` such that -:math:`input = U \times diag(S) \times V^T`, where :math:`V^T` is the transpose -of ``V``. +Compute the singular value decomposition of either a matrix or batch of +matrices :attr:`input`." The singular value decomposition is represented as a +namedtuple ``(U, S, V)``, such that :math:`input = U \times diag(S) \times +V^T`, where :math:`V^T` is the transpose of ``V``. If the inputs are batches, then +returns batched outputs for all of ``U``, `S`` and ``V``. The original tensor can be reconstructed by:: U @ diag(S) @ V.T +If :attr:`some` is ``True`` (default), the method returns the reduced singular +value decomposition i.e., if the last two dimensions of :attr:`input` are +``m`` and ``n``, then the returned `U` and `V` matrices will contain only +:math:`min(n, m)` orthonormal columns. 
+ +If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be +zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)` +respectively, and the same device as :attr:`input`. The :attr:`some` +argument has no effect when :attr:`compute_uv` is False. + The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will always be real-valued, even if :attr:`input` is complex. -If :attr:`some` is ``True`` (default), the method returns the reduced -singular value decomposition i.e., if the last two dimensions of -:attr:`input` are ``m`` and ``n``, then the returned `U` matrix will -contain only :math:`min(n, m)` orthonormal columns and the size of `V` -will be :math:`(*, n, n)`. +.. warning:: ``torch.svd`` is deprecated. Please use :meth:`~torch.linalg.svd` + instead, which provides a better compatibility with + ``numpy.linalg.svd``. + +.. note:: **Differences with** :meth:`~torch.linalg.svd`: -If :attr:`compute_uv` is ``False``, the returned `U` and `V` matrices will be zero matrices -of shape :math:`(m \times m)` and :math:`(n \times n)` respectively. :attr:`some` will be ignored here. + * :attr:`some` is the opposite of :meth:`~torch.linalg.svd`'s + :attr:`full_matricies`. Note that default value for both is + ``True``, so the default behavior is effectively the opposite. + + * it returns ``V``, whereas :meth:`~torch.linalg.svd` returns + ``Vh``. The result is that when using :meth:`~torch.svd` you + need to manually transpose ``V`` in order to reconstruct the + original matrix. + + * If :attr:`compute_uv=False`, it returns zero-filled tensors for + ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns + empty tensors. .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, then the singular values of each matrix in the batch is returned in descending order. @@ -8027,28 +8048,25 @@ def merge_dicts(*dicts): algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine `gesdd` as well. -.. note:: Irrespective of the original strides, the returned matrix `U` - will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()` +.. note:: The returned matrix `U` will be transposed, i.e. with strides + :code:`U.contiguous().transpose(-2, -1).stride()`. -.. note:: Extra care needs to be taken when backward through `U` and `V` - outputs. Such operation is really only stable when :attr:`input` is - full rank with all distinct singular values. Otherwise, ``NaN`` can - appear as the gradients are not properly defined. Also, notice that - double backward will usually do an additional backward through `U` and - `V` even if the original backward is only on `S`. +.. note:: Gradients computed using `U` and `V` may be unstable if + :attr:`input` is not full rank or has non-unique singular values. .. note:: When :attr:`some` = ``False``, the gradients on :code:`U[..., :, min(m, n):]` and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors can be arbitrary bases of the subspaces. -.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V` - from the forward pass is required for the backward operation. +.. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. + Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of :math:`m \times n` matrices. 
- some (bool, optional): controls the shape of returned `U` and `V` - compute_uv (bool, optional): option whether to compute `U` and `V` or not + some (bool, optional): controls whether to compute the reduced or full decomposition, and + consequently the shape of returned ``U`` and ``V``. Defaults to True. + compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True. Keyword args: out (tuple, optional): the output tuple of tensors diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 455b7da09fae..5932edb67029 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -403,14 +403,17 @@ svd = _add_docstr(_linalg.linalg_svd, r""" linalg.svd(input, full_matrices=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) - Compute the singular value decomposition of either a matrix or batch of matrices :attr:`input`." The singular value decomposition is represented as a namedtuple ``(U, S, Vh)``, such that :math:`input = U \times diag(S) \times Vh`. If the inputs are batches, then returns batched outputs for all of ``U``, ``S`` and ``Vh``. -If :attr:`full_matrices` is ``False``, the method returns the reduced singular +The original tensor can be reconstructed by:: + + U @ diag(S) @ Vh + +If :attr:`full_matrices` is ``False`` (default), the method returns the reduced singular value decomposition i.e., if the last two dimensions of :attr:`input` are ``m`` and ``n``, then the returned `U` and `V` matrices will contain only :math:`min(n, m)` orthonormal columns. @@ -422,22 +425,6 @@ The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will always be real-valued, even if :attr:`input` is complex. -.. warning:: **Differences with** :meth:`~torch.svd`: - - * :attr:`full_matrices` is the opposite of - :meth:`~torch.svd`'s :attr:`some`. Note that default value - for both is ``True``, so the default behavior is effectively - the opposite. - - * it returns ``Vh``, whereas :meth:`~torch.svd` returns - ``V``. The result is that when using :meth:`~torch.svd` you - need to manually transpose ``V`` in order to reconstruct the - original matrix. - - * If :attr:`compute_uv=False`, it returns empty tensors for ``U`` - and ``Vh``, whereas :meth:`~torch.svd` returns zero-filled - tensors. - .. note:: Unlike NumPy's ``linalg.svd``, this always returns a namedtuple of three tensors, even when :attr:`compute_uv=False`. @@ -461,7 +448,6 @@ .. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. - Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions consisting of :math:`m \times n` matrices. 
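As a quick illustration of the two conventions documented above (a sketch only, not part of the patch itself; it assumes a build where the deprecated ``torch.svd`` is still available alongside ``torch.linalg.svd``, as on the branch these patches target)::

    import torch

    a = torch.randn(5, 3, dtype=torch.float64)

    # torch.svd returns V and uses some=True for the reduced decomposition,
    # so the reconstruction needs an explicit transpose of V.
    U, S, V = torch.svd(a, some=True)
    assert torch.allclose(a, U @ torch.diag(S) @ V.t())

    # torch.linalg.svd returns Vh directly, and its full_matrices flag is the
    # opposite of some: full_matrices=False gives the reduced decomposition.
    U2, S2, Vh = torch.linalg.svd(a, full_matrices=False)
    assert torch.allclose(a, U2 @ torch.diag(S2) @ Vh)

Both asserts check the same reconstruction identity; the only differences are whether V or its transpose is returned and how the reduced decomposition is requested.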
From 7b70521510013be9c60c3d797865425912745de1 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 7 Dec 2020 18:03:39 +0000 Subject: [PATCH 37/63] kill duplicate tests: what happened is that both upstream/master and this branch moved svd tests from test_torch to test_linalg and git merge didn't realzie, so we ended up having two copies of each of these --- test/test_linalg.py | 205 ++------------------------------------------ 1 file changed, 6 insertions(+), 199 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index b392192b769c..1e9831c4d1fb 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -2011,37 +2011,37 @@ def _test_svd_helper(self, shape, some, col_maj, device, dtype): @skipCUDAIfNoMagma @skipCPUIfNoLapack - @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @dtypes(*floating_and_complex_types()) def test_svd_square(self, device, dtype): self._test_svd_helper((10, 10), True, False, device, dtype) @skipCUDAIfNoMagma @skipCPUIfNoLapack - @dtypes(torch.float, torch.double) + @dtypes(*floating_types()) def test_svd_square_col_maj(self, device, dtype): self._test_svd_helper((10, 10), True, True, device, dtype) @skipCUDAIfNoMagma @skipCPUIfNoLapack - @dtypes(torch.float, torch.double) + @dtypes(*floating_types()) def test_svd_tall_some(self, device, dtype): self._test_svd_helper((20, 5), True, False, device, dtype) @skipCUDAIfNoMagma @skipCPUIfNoLapack - @dtypes(torch.float, torch.double) + @dtypes(*floating_types()) def test_svd_tall_all(self, device, dtype): self._test_svd_helper((20, 5), False, False, device, dtype) @skipCUDAIfNoMagma @skipCPUIfNoLapack - @dtypes(torch.float, torch.double) + @dtypes(*floating_types()) def test_svd_tall_some_col_maj(self, device, dtype): self._test_svd_helper((5, 20), True, True, device, dtype) @skipCUDAIfNoMagma @skipCPUIfNoLapack - @dtypes(torch.float, torch.double) + @dtypes(*floating_types()) def test_svd_tall_all_col_maj(self, device, dtype): self._test_svd_helper((5, 20), False, True, device, dtype) @@ -4525,60 +4525,6 @@ def test_solve_methods_arg_device(self, device): "Expected LU_pivots and LU_data to be on the same device"): torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int()) - def _test_svd_helper(self, shape, some, col_maj, device, dtype): - cpu_tensor = torch.randn(shape, device='cpu').to(dtype) - device_tensor = cpu_tensor.to(device=device) - if col_maj: - cpu_tensor = cpu_tensor.t() - device_tensor = device_tensor.t() - cpu_result = torch.svd(cpu_tensor, some=some) - device_result = torch.svd(device_tensor, some=some) - m = min(cpu_tensor.shape[-2:]) - # torch.svd returns torch.return_types.svd which is a tuple of (U, V, S). - # - When some==False, U[..., m:] can be arbitrary. - # - When some==True, U shape: [..., m], V shape: [m, m] - # - Signs are not deterministic. If the sign of a column of U is changed - # then the corresponding column of the V has to be changed. - # Thus here we only compare result[..., :m].abs() from CPU and device. 
- for x, y in zip(cpu_result, device_result): - self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_and_complex_types()) - def test_svd_square(self, device, dtype): - self._test_svd_helper((10, 10), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_square_col_maj(self, device, dtype): - self._test_svd_helper((10, 10), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_some(self, device, dtype): - self._test_svd_helper((20, 5), True, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_all(self, device, dtype): - self._test_svd_helper((20, 5), False, False, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_some_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), True, True, device, dtype) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(*floating_types()) - def test_svd_tall_all_col_maj(self, device, dtype): - self._test_svd_helper((5, 20), False, True, device, dtype) - @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3}) @skipCUDAIfNoMagma @skipCPUIfNoLapack @@ -5520,145 +5466,6 @@ def run_test(dims, eigenvectors, upper): for batch_dims, eigenvectors, upper in itertools.product(batch_dims_set, (True, False), (True, False)): run_test((5,) + batch_dims, eigenvectors, upper) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_svd(self, device, dtype): - def run_test(dims, some, compute_uv): - x = torch.randn(*dims, dtype=dtype, device=device) - outu = torch.tensor((), dtype=dtype, device=device) - outs = torch.tensor((), dtype=dtype, device=device) - outv = torch.tensor((), dtype=dtype, device=device) - torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) - - if compute_uv: - if some: - x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = outu[..., :min(*dims[-2:])] - narrow_v = outv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, outs, msg='Singular values mismatch') - self.assertEqual(outu, torch.zeros_like(outu), msg='U not zero') - self.assertEqual(outv, torch.zeros_like(outv), msg='V not zero') - - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - self.assertEqual(resu, outu, msg='outputs of svd and svd with out differ') - self.assertEqual(ress, outs, msg='outputs of svd and svd with out differ') - self.assertEqual(resv, outv, msg='outputs of svd and svd with out differ') - - # test non-contiguous - x = torch.randn(*dims, dtype=dtype, device=device) - n_dim = len(dims) - # Reverse the batch dimensions and the matrix dimensions and then concat them - x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) - assert not x.is_contiguous(), "x is intentionally non-contiguous" - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - if compute_uv: - if some: - x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), 
resv.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = resu[..., :min(*dims[-2:])] - narrow_v = resv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, ress, msg='Singular values mismatch') - self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero') - self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero') - - shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices - (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices - (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices - for dims, some, compute_uv in itertools.product(shapes, [True, False], [True, False]): - run_test(dims, some, compute_uv) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_no_singularvectors(self, device): - for size in [(5, 5), (5, 20), (20, 5)]: - a = torch.randn(*size, device=device) - u, s_expect, v = torch.svd(a) - u, s_actual, v = torch.svd(a, compute_uv=False) - self.assertEqual(s_expect, s_actual, msg="Singular values don't match") - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_svd_lowrank(self, device): - import torch - from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix - - dtype = torch.double - - def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): - density = options.pop('density', 1) - if isinstance(matrix_size, int): - rows = columns = matrix_size - else: - rows, columns = matrix_size - if density == 1: - a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype) - a = a_input - else: - assert batches == () - a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype) - a = a_input.to_dense() - - q = min(*size) - u, s, v = svd_lowrank(a_input, q=q, **options) - - # check if u, s, v is a SVD - u, s, v = u[..., :q], s[..., :q], v[..., :q] - A = u.matmul(s.diag_embed()).matmul(v.transpose(-2, -1)) - self.assertEqual(A, a) - - # check if svd_lowrank produces same singular values as torch.svd - U, S, V = torch.svd(a) - self.assertEqual(s.shape, S.shape) - self.assertEqual(u.shape, U.shape) - self.assertEqual(v.shape, V.shape) - self.assertEqual(s, S) - - if density == 1: - # actual_rank is known only for dense inputs - # - # check if pairs (u, U) and (v, V) span the same - # subspaces, respectively - u, s, v = u[..., :actual_rank], s[..., :actual_rank], v[..., :actual_rank] - U, S, V = U[..., :actual_rank], S[..., :actual_rank], V[..., :actual_rank] - self.assertEqual(u.transpose(-2, -1).matmul(U).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - self.assertEqual(v.transpose(-2, -1).matmul(V).det().abs(), torch.ones(batches, device=device, dtype=dtype)) - - all_batches = [(), (1,), (3,), (2, 3)] - for actual_rank, size, all_batches in [ - (2, (17, 4), all_batches), - (4, (17, 4), all_batches), - (4, (17, 17), all_batches), - (10, (100, 40), all_batches), - (7, (1000, 1000), [()]), - ]: - # dense input - for batches in all_batches: - run_subtest(actual_rank, size, batches, device, torch.svd_lowrank) - if size != size[::-1]: - run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank) - - # sparse input - for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 
1000)]: - for density in [0.005, 0.1]: - run_subtest(None, size, (), device, torch.svd_lowrank, density=density) - - # jitting support - jitted = torch.jit.script(torch.svd_lowrank) - actual_rank, size, batches = 2, (17, 4), () - run_subtest(actual_rank, size, batches, device, jitted) - @skipCUDAIfNoMagma @skipCPUIfNoLapack def test_pca_lowrank(self, device): From e25de982255537525925c1050378d7f583cee8aa Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Tue, 8 Dec 2020 14:55:23 +0000 Subject: [PATCH 38/63] refactor and fix test_namedtuple_return_api.py 1. add support for _svd_helper 2. refactor the logic: now it uses small helper functions to reduce a bit of the code duplication, and I think that the new code is much easier to read and understand 3. as a proof of (2), fix two bugs: the old code did not check the names of "ret1" in the "linalg_" case, and in the "op.hasout" case it checked "ret" instead of "ret1" --- test/test_namedtuple_return_api.py | 44 ++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 88a01e48b5f2..ed441b31358a 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -12,7 +12,7 @@ all_operators_with_namedtuple_return = { 'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', - 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh' + 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', '_svd_helper', 'linalg_svd', } @@ -56,7 +56,7 @@ def test_namedtuple_return(self): names=('values', 'indices'), hasout=True), op(operators=['kthvalue'], input=(1, 0), names=('values', 'indices'), hasout=True), - op(operators=['svd'], input=(), names=('U', 'S', 'V'), hasout=True), + op(operators=['svd', '_svd_helper', 'linalg_svd'], input=(), names=('U', 'S', 'V'), hasout=True), op(operators=['slogdet'], input=(), names=('sign', 'logabsdet'), hasout=False), op(operators=['qr'], input=(), names=('Q', 'R'), hasout=True), op(operators=['solve'], input=(a,), names=('solution', 'LU'), hasout=True), @@ -67,21 +67,35 @@ def test_namedtuple_return(self): op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True), ] + def get_func(f): + "Return either torch.f or torch.linalg.f, where 'f' is a string" + if f.startswith('linalg_'): + return getattr(torch.linalg, f[7:]) + return getattr(torch, f, None) + + def check_namedtuple(tup, names): + "Check that the namedtuple 'tup' has the given names" + for i, name in enumerate(names): + self.assertIs(getattr(tup, name), tup[i]) + for op in operators: for f in op.operators: - if 'linalg_' in f: - ret = getattr(torch.linalg, f[7:])(a, *op.input) - ret1 = getattr(torch.linalg, f[7:])(a, *op.input, out=tuple(ret)) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - else: - ret = getattr(a, f)(*op.input) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - if op.hasout: - ret1 = getattr(torch, f)(a, *op.input, out=tuple(ret)) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) + # 1. check the namedtuple returned by calling torch.f + func = get_func(f) + if func: + ret1 = func(a, *op.input) + check_namedtuple(ret1, op.names) + # + # 2. check the out= variant, if it exists + if func and op.hasout: + ret2 = func(a, *op.input, out=tuple(ret1)) + check_namedtuple(ret2, op.names) + # + # 3. 
check the Tensor.f method, if it exists + meth = getattr(a, f, None) + if meth: + ret3 = meth(*op.input) + check_namedtuple(ret3, op.names) all_covered_operators = set([x for y in operators for x in y.operators]) From 5dc359db1893633da1410eed49643832c492eaab Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Wed, 9 Dec 2020 10:17:52 +0000 Subject: [PATCH 39/63] mark the changes to _svd_helper as intentional --- test/backward_compatibility/check_backward_compatibility.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index ccb4a6457537..5891747a5f4f 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -187,6 +187,7 @@ ("aten::ifft", datetime.date(2021, 1, 31)), ("aten::irfft", datetime.date(2021, 1, 31)), ("aten::rfft", datetime.date(2021, 1, 31)), + ("aten::_svd_helper", datetime.date(2021, 1, 31)), ] def allow_listed(schema, allow_list): From 8b30bcbbd0c4ed5864afd16d009c79164516f019 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 10 Dec 2020 10:04:06 +0000 Subject: [PATCH 40/63] Improve the error message generated by test_overrides. Example of the OLD error message: E AssertionError: The following functions are not tested for __torch_function__ support, please ensure there is an entry in the dict returned by torch._overrides.get_testing_overrides for this function or if a __torch_function__ override does not make sense, add an entry to the tuple returned by torch._overrides.get_ignored_functions. E E [".svd", E ".svd", E ".svd", E ".linalg_svd"] E assert 4 == 0 E + where 4 = len([".svd", ".svd", ".linalg_svd"]) Example of the NEW error message: E AssertionError: The following functions are not tested for __torch_function__ E support, please ensure there is an entry in the dict returned by E torch._overrides.get_testing_overrides for this function or if a E __torch_function__ override does not make sense, add an entry to E the tuple returned by torch._overrides.get_ignored_functions. 
E E ['torch.svd', 'torch.svd', 'Tensor.svd', 'torch.linalg.linalg_svd'] E assert 4 == 0 E + where 4 = len(['torch.svd', 'torch.svd', 'Tensor.svd', 'torch.linalg.linalg_svd']) --- test/test_overrides.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_overrides.py b/test/test_overrides.py index 95f94504d84e..c4ff34d6d185 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -318,12 +318,12 @@ def generate_tensor_like_torch_implementations(): for namespace, funcs in get_overridable_functions().items(): for func in funcs: if func not in testing_overrides: - untested_funcs.append("{}.{}".format(namespace, func.__name__)) + untested_funcs.append("{}.{}".format(namespace.__name__, func.__name__)) msg = ( - "The following functions are not tested for __torch_function__ " - "support, please ensure there is an entry in the dict returned by " - "torch._overrides.get_testing_overrides for this function or if a " - "__torch_function__ override does not make sense, add an entry to " + "The following functions are not tested for __torch_function__ \n" + "support, please ensure there is an entry in the dict returned by \n" + "torch._overrides.get_testing_overrides for this function or if a \n" + "__torch_function__ override does not make sense, add an entry to \n" "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}" ) assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) From b87b765ff1a6e26b4e8fe7445c46ad911a246e53 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 10 Dec 2020 10:06:48 +0000 Subject: [PATCH 41/63] add the new torch.linalg.svd to test_overrides --- torch/overrides.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/overrides.py b/torch/overrides.py index 2af6e36ea914..64716d846689 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -795,6 +795,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.nansum: lambda input, dim=None: -1, torch.svd: lambda input, some=True, compute_uv=True, out=None: -1, torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1, + torch.linalg.svd: lambda input, full_matrices=True, compute_uv=True, out=None: -1, torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1, torch.swapaxes: lambda input, dim0, dim1: -1, torch.swapdims: lambda input, axis0, axis1: -1, From a349569f6c32e7502845cb5c3f180c6ff0e4875b Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 13:46:19 +0000 Subject: [PATCH 42/63] this is needed after e391dbc1b5 --- aten/src/ATen/native/native_functions.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 6cecd3c05743..9699e62cd8e0 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -10102,6 +10102,7 @@ variants: function - func: linalg_svd.U(Tensor self, bool full_matrices=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) 
V) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: linalg - func: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) From e30248eeb38b47a4a9a7b882ea57251e3023c124 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 13:46:57 +0000 Subject: [PATCH 43/63] this doesn't have to be a method --- aten/src/ATen/native/native_functions.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9699e62cd8e0..1119d11f70c8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -10108,7 +10108,7 @@ - func: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) python_module: linalg use_c10_dispatcher: full - variants: method, function + variants: function - func: linalg_cond(Tensor self, Scalar? p=None) -> Tensor use_c10_dispatcher: full From 3ddb6accfc3c6747195d79ba614b4db19e4dd4fb Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 13:49:20 +0000 Subject: [PATCH 44/63] use torch.empty instead of torch.tensor --- test/test_linalg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 5eeecb08e06f..ab100208d8e7 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1872,9 +1872,9 @@ def test_nuclear_norm_exceptions_old(self, device): def test_svd(self, device, dtype): def run_test(dims, some, compute_uv): x = torch.randn(*dims, dtype=dtype, device=device) - outu = torch.tensor((), dtype=dtype, device=device) - outs = torch.tensor((), dtype=dtype, device=device) - outv = torch.tensor((), dtype=dtype, device=device) + outu = torch.empty(0, dtype=dtype, device=device) + outs = torch.empty(0, dtype=dtype, device=device) + outv = torch.empty(0, dtype=dtype, device=device) torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) if compute_uv: From 823e6a841387a78cf9a3f2847611eda15ce38dfa Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 13:50:22 +0000 Subject: [PATCH 45/63] remove unnecessary import --- test/test_linalg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index ab100208d8e7..4f0c6edfc1b9 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1937,7 +1937,6 @@ def test_svd_no_singularvectors(self, device): @skipCUDAIfNoMagma @skipCPUIfNoLapack def test_svd_lowrank(self, device): - import torch from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix dtype = torch.double From 52aadbe0af937960b2802ef032f1723c98719be5 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 13:53:10 +0000 Subject: [PATCH 46/63] use the proper @dtypes decorator --- test/test_linalg.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 4f0c6edfc1b9..8995854728d6 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1927,20 +1927,20 @@ def run_test(dims, some, compute_uv): @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_svd_no_singularvectors(self, device): + @dtypes(torch.float) + def test_svd_no_singularvectors(self, device, dtype): for size in [(5, 5), (5, 20), (20, 5)]: - a = torch.randn(*size, device=device) + a = torch.randn(*size, device=device, dtype=dtype) u, s_expect, v = torch.svd(a) u, s_actual, v = 
torch.svd(a, compute_uv=False) self.assertEqual(s_expect, s_actual, msg="Singular values don't match") @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_svd_lowrank(self, device): + @dtypes(torch.double) + def test_svd_lowrank(self, device, dtype): from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix - dtype = torch.double - def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options): density = options.pop('density', 1) if isinstance(matrix_size, int): From ea0aca4b09822a651c3354ccf7b45d0d6c873729 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 14:29:05 +0000 Subject: [PATCH 47/63] don't use Tensor --- test/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 8995854728d6..50fc4c350560 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -2117,7 +2117,7 @@ def is_empty(x): self.assertEqual(USV.S, np_s) assert is_empty(USV.V) # check linalg.svd vs linalg.svd(out=...) - out = (torch.Tensor(), torch.empty_like(USV.S), torch.Tensor()) + out = (torch.empty_like(USV.U), torch.empty_like(USV.S), torch.empty_like(USV.V)) USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out) assert USV.U is out[0] assert USV.S is out[1] From 9bc46f6e16b47d8ee1ce4bda6eb340bd01654bd9 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 14:40:18 +0000 Subject: [PATCH 48/63] typo --- torch/_torch_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 03241cb3c7a3..fd9e701f6a02 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -8045,7 +8045,7 @@ def merge_dicts(*dicts): r""" svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) -Compute the singular value decomposition of either a matrix or batch of +Computes the singular value decomposition of either a matrix or batch of matrices :attr:`input`." The singular value decomposition is represented as a namedtuple ``(U, S, V)``, such that :math:`input = U \times diag(S) \times V^T`, where :math:`V^T` is the transpose of ``V``. If the inputs are batches, then From da1f2ae4b55e2477e87b39639dd190daf8802f2e Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 15:26:23 +0000 Subject: [PATCH 49/63] improve docs --- torch/_torch_docs.py | 28 +++++++++++++++------------- torch/linalg/__init__.py | 8 ++++---- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index fd9e701f6a02..067e9800f016 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -8048,10 +8048,11 @@ def merge_dicts(*dicts): Computes the singular value decomposition of either a matrix or batch of matrices :attr:`input`." The singular value decomposition is represented as a namedtuple ``(U, S, V)``, such that :math:`input = U \times diag(S) \times -V^T`, where :math:`V^T` is the transpose of ``V``. If the inputs are batches, then -returns batched outputs for all of ``U``, `S`` and ``V``. +V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch +of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch +dimensions as :attr:`input`. -The original tensor can be reconstructed by:: +When :attr:`input` is a tensor, it can be reconstructed by:: U @ diag(S) @ V.T @@ -8068,20 +8069,21 @@ def merge_dicts(*dicts): The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. 
``S`` will always be real-valued, even if :attr:`input` is complex. -.. warning:: ``torch.svd`` is deprecated. Please use :meth:`~torch.linalg.svd` - instead, which provides a better compatibility with +.. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.`` + :func:`~torch.linalg.svd` instead, which is similar to NumPy's ``numpy.linalg.svd``. -.. note:: **Differences with** :meth:`~torch.linalg.svd`: +.. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`: - * :attr:`some` is the opposite of :meth:`~torch.linalg.svd`'s - :attr:`full_matricies`. Note that default value for both is - ``True``, so the default behavior is effectively the opposite. + * :attr:`some` is the opposite of ``torch.linalg.`` + :func:`~torch.linalg.svd`'s :attr:`full_matricies`. Note that + default value for both is ``True``, so the default behavior is + effectively the opposite. - * it returns ``V``, whereas :meth:`~torch.linalg.svd` returns - ``Vh``. The result is that when using :meth:`~torch.svd` you - need to manually transpose ``V`` in order to reconstruct the - original matrix. + * it returns ``V``, whereas ``torch.linalg.`` + :func:`~torch.linalg.svd` returns ``Vh``. The result is that + when using ``svd`` you need to manually transpose + ``V`` in order to reconstruct the original matrix. * If :attr:`compute_uv=False`, it returns zero-filled tensors for ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index c22ce545406b..31340070a441 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -406,13 +406,13 @@ svd = _add_docstr(_linalg.linalg_svd, r""" linalg.svd(input, full_matrices=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) -Compute the singular value decomposition of either a matrix or batch of +Computes the singular value decomposition of either a matrix or batch of matrices :attr:`input`." The singular value decomposition is represented as a namedtuple ``(U, S, Vh)``, such that :math:`input = U \times diag(S) \times -Vh`. If the inputs are batches, then returns batched outputs for all of ``U``, -``S`` and ``Vh``. +Vh`. If :attr:`input` is a batch of tensors, then ``U``, ``S``, and ``Vh`` are +also batched with the same batch dimensions as :attr:`input`. 
-The original tensor can be reconstructed by:: +When :attr:`input` is a tensor, it can be reconstructed by:: U @ diag(S) @ Vh From 5445d3f725adeff909927d04770bcfa8064ae7fd Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 15:30:49 +0000 Subject: [PATCH 50/63] use the correct nccl version (hopefully) --- third_party/nccl/nccl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/nccl/nccl b/third_party/nccl/nccl index 31b5bb6f6447..033d799524fb 160000 --- a/third_party/nccl/nccl +++ b/third_party/nccl/nccl @@ -1 +1 @@ -Subproject commit 31b5bb6f6447da98b9110c605465f9c09621074e +Subproject commit 033d799524fb97629af5ac2f609de367472b2696 From 6d04a6e8763cefe2dafa91ba03e1bf039a973c7d Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 16:03:58 +0000 Subject: [PATCH 51/63] now the underlying op is _svd_helper, and svd is only a thin layer on top of it --- tools/autograd/gen_variable_type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 9c8800786fed..231af480cacc 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -78,7 +78,7 @@ 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv', - 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr', 'svd', + 'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr', '_svd_helper', '_fft_c2c', '_fft_r2c', } From a2e2781e0f17e8a38e796cfd9aa0b92dfb9b6d06 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 21 Dec 2020 16:26:14 +0000 Subject: [PATCH 52/63] Revert "Improve the error message generated by test_overrides." This reverts commit 8b30bcbbd0c4ed5864afd16d009c79164516f019. --- test/test_overrides.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_overrides.py b/test/test_overrides.py index c4ff34d6d185..95f94504d84e 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -318,12 +318,12 @@ def generate_tensor_like_torch_implementations(): for namespace, funcs in get_overridable_functions().items(): for func in funcs: if func not in testing_overrides: - untested_funcs.append("{}.{}".format(namespace.__name__, func.__name__)) + untested_funcs.append("{}.{}".format(namespace, func.__name__)) msg = ( - "The following functions are not tested for __torch_function__ \n" - "support, please ensure there is an entry in the dict returned by \n" - "torch._overrides.get_testing_overrides for this function or if a \n" - "__torch_function__ override does not make sense, add an entry to \n" + "The following functions are not tested for __torch_function__ " + "support, please ensure there is an entry in the dict returned by " + "torch._overrides.get_testing_overrides for this function or if a " + "__torch_function__ override does not make sense, add an entry to " "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}" ) assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) From bcf7461da874885cdcc027f0173e58bc65472766 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Tue, 22 Dec 2020 17:47:41 +0000 Subject: [PATCH 53/63] add an OpInfo for torch.linalg.svd, and adapt sample_inputs_svd to generate inputs which are suitable for both svd and linalg.svd. 
Thanks to @IvanYashchuk for the help :) --- .../_internal/common_methods_invocations.py | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b7c8ed9567a1..4f10a98e62e7 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -395,7 +395,7 @@ def sample_inputs(self, device, dtype, requires_grad=False): ] -def sample_inputs_svd(op_info, device, dtype, requires_grad=False): +def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_svd=False): """ This function generates input for torch.svd with distinct singular values so that autograd is always stable. Matrices of different size: @@ -408,6 +408,14 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): """ from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value + # svd and linalg.svd returns V and V.T, respectively. So we need to slice + # along different dimensions when needed (this is used by + # test_cases2:wide_all and wide_all_batched below) + if is_linalg_svd: + def slice_V(v): return v[..., :(S - 2), :] + else: + def slice_V(v): return v[..., :, :(S - 2)] + test_cases1 = ( # some=True (default) # loss functions for complex-valued svd have to be "gauge invariant", # i.e. loss functions shouldn't change when sigh of the singular vectors change. @@ -437,11 +445,11 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): ) test_cases2 = ( # some=False (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)], - lambda usv: (abs(usv[0]), usv[1], abs(usv[2][:, :(S - 2)]))), # 'wide_all' + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)], lambda usv: (abs(usv[0][:, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all' (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :(S - 2), :], - lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide_all_batched' + lambda usv: (abs(usv[0]), usv[1], abs(slice_V(usv[2])))), # 'wide_all_batched' (random_fullrank_matrix_distinct_singular_value(S, 2, dtype=dtype).to(device)[..., :, :(S - 2)], lambda usv: (abs(usv[0][..., :, :(S - 2)]), usv[1], abs(usv[2]))), # 'tall_all_batched' ) @@ -449,15 +457,27 @@ def sample_inputs_svd(op_info, device, dtype, requires_grad=False): out = [] for a, out_fn in test_cases1: a.requires_grad = requires_grad - out.append(SampleInput(a, output_process_fn_grad=out_fn)) + if is_linalg_svd: + kwargs = {'full_matrices': False} + else: + kwargs = {'some': True} + out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) for a, out_fn in test_cases2: a.requires_grad = requires_grad - kwargs = {'some': False} + if is_linalg_svd: + kwargs = {'full_matrices': True} + else: + kwargs = {'some': False} out.append(SampleInput(a, kwargs=kwargs, output_process_fn_grad=out_fn)) return out +def sample_inputs_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=False) + +def sample_inputs_linalg_svd(op_info, device, dtype, requires_grad=False): + return _sample_inputs_svd(op_info, device, dtype, requires_grad, is_linalg_svd=True) def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): """ @@ -965,6 
+985,19 @@ def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): # cuda gradchecks are very slow # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), + OpInfo('linalg_svd', + op=torch.linalg.svd, + dtypes=floating_and_complex_types(), + test_inplace_grad=False, + supports_tensor_out=False, + sample_inputs_func=sample_inputs_linalg_svd, + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + skips=( + # gradgrad checks are slow + SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), + # cuda gradchecks are very slow + # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), OpInfo('pinverse', op=torch.pinverse, dtypes=floating_and_complex_types(), From ba92c1a14a9507f391b34ae1ccf5aa66d800ba05 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Tue, 22 Dec 2020 20:51:49 +0000 Subject: [PATCH 54/63] typo --- torch/testing/_internal/common_methods_invocations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4f10a98e62e7..3fa8ae3fa1ad 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -985,7 +985,7 @@ def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): # cuda gradchecks are very slow # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), - OpInfo('linalg_svd', + OpInfo('linalg.svd', op=torch.linalg.svd, dtypes=floating_and_complex_types(), test_inplace_grad=False, From 9d9cd0370196b143449991a3dfdcff0d3acefb0f Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Wed, 23 Dec 2020 10:34:08 +0000 Subject: [PATCH 55/63] specify the aten_name --- torch/testing/_internal/common_methods_invocations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 633b5648f538..4180e52bc6f2 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1045,6 +1045,7 @@ def sample_inputs_pinverse(op_info, device, dtype, requires_grad=False): SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), OpInfo('linalg.svd', op=torch.linalg.svd, + aten_name='linalg_svd', dtypes=floating_and_complex_types(), test_inplace_grad=False, supports_tensor_out=False, From c3e4de6ab29f0ed4267bea20264e93e660a03400 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Wed, 23 Dec 2020 10:36:30 +0000 Subject: [PATCH 56/63] fix indent --- test/test_namedtuple_return_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 77b710ffc4c2..071ebb0ccdef 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -13,7 +13,7 @@ 'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual", - '_svd_helper', 'linalg_svd', + '_svd_helper', 'linalg_svd', } From ec87163d6914e99f6d7bbdf72b4612f27c26974b Mon Sep 17 00:00:00 

From ec87163d6914e99f6d7bbdf72b4612f27c26974b Mon Sep 17 00:00:00 2001
From: Antonio Cuni
Date: Wed, 23 Dec 2020 10:41:57 +0000
Subject: [PATCH 57/63] fix flake8

---
 torch/testing/_internal/common_methods_invocations.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4180e52bc6f2..844e6e5d2a38 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -462,9 +462,11 @@ def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_sv
     # along different dimensions when needed (this is used by
     # test_cases2:wide_all and wide_all_batched below)
     if is_linalg_svd:
-        def slice_V(v): return v[..., :(S - 2), :]
+        def slice_V(v):
+            return v[..., :(S - 2), :]
     else:
-        def slice_V(v): return v[..., :, :(S - 2)]
+        def slice_V(v):
+            return v[..., :, :(S - 2)]
 
     test_cases1 = (  # some=True (default)
         # loss functions for complex-valued svd have to be "gauge invariant",

From 958321ee3188aab1ce5d1740dc5c5d95993d315d Mon Sep 17 00:00:00 2001
From: Antonio Cuni
Date: Wed, 23 Dec 2020 13:53:47 +0000
Subject: [PATCH 58/63] fix test_namedtuple_return: with the new logic, the 'a' argument is passed automatically

---
 test/test_namedtuple_return_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py
index 071ebb0ccdef..fad505c79ab2 100644
--- a/test/test_namedtuple_return_api.py
+++ b/test/test_namedtuple_return_api.py
@@ -66,7 +66,7 @@ def test_namedtuple_return(self):
             op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True),
             op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True),
             op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True),
-            op(operators=['unpack_dual'], input=(a, 0), names=('primal', 'tangent'), hasout=False),
+            op(operators=['unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False),
         ]
 
         def get_func(f):

From 8dbbfe5df6d85d08bd6447598634af96aacab04b Mon Sep 17 00:00:00 2001
From: Antonio Cuni
Date: Tue, 5 Jan 2021 10:47:55 +0000
Subject: [PATCH 59/63] input/svd input

---
 aten/src/ATen/native/BatchLinearAlgebra.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 9c2dc13fc3c3..3774e31e3df8 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1428,14 +1428,14 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some
 
 std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compute_uv) {
   TORCH_CHECK(self.dim() >= 2,
-              "input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+              "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   return at::_svd_helper(self, some, compute_uv);
 }
 
 std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& V,
                                               const Tensor& self, bool some, bool compute_uv) {
   TORCH_CHECK(self.dim() >= 2,
-              "input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+              "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   Tensor U_tmp, S_tmp, V_tmp;
   std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv);
   U.resize_as_(U_tmp).copy_(U_tmp);
@@ -1458,7 +1458,7 @@ std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V,
 
 std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) {
   TORCH_CHECK(self.dim() >= 2,
-              "input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+              "svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
 
   bool some = !full_matrices;
   Tensor U, S, V;
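
A hedged sketch of the user-visible effect of the message change (an illustration, not part of the patch; the exact wording differs on other PyTorch versions)::

    import torch

    try:
        torch.svd(torch.randn(3))  # 1-D input is rejected
    except RuntimeError as err:
        # with this patch the text begins with "svd input should have at
        # least 2 dimensions", so the failing operator is named explicitly
        print(err)
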

From be456ab740e5829a69515e1eaa814f6f02e48164 Mon Sep 17 00:00:00 2001
From: Antonio Cuni
Date: Tue, 5 Jan 2021 11:12:43 +0000
Subject: [PATCH 60/63] remove redundant sentences

---
 torch/_torch_docs.py     | 6 +-----
 torch/linalg/__init__.py | 6 +-----
 2 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 28b1b451592b..d204afdb286e 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -8144,15 +8144,11 @@ def merge_dicts(*dicts):
 Computes the singular value decomposition of either a matrix or batch of
 matrices :attr:`input`." The singular value decomposition is represented as a
-namedtuple ``(U, S, V)``, such that :math:`input = U \times diag(S) \times
+namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times
 V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch
 of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch
 dimensions as :attr:`input`.
 
-When :attr:`input` is a tensor, it can be reconstructed by::
-
-    U @ diag(S) @ V.T
-
 If :attr:`some` is ``True`` (default), the method returns the reduced singular
 value decomposition i.e., if the last two dimensions of :attr:`input` are
 ``m`` and ``n``, then the returned `U` and `V` matrices will contain only
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 3e984b536627..4c724b0b7e4c 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -408,14 +408,10 @@
 Computes the singular value decomposition of either a matrix or batch of
 matrices :attr:`input`." The singular value decomposition is represented as a
-namedtuple ``(U, S, Vh)``, such that :math:`input = U \times diag(S) \times
+namedtuple ``(U, S, Vh)``, such that :math:`input = U \mathbin{@} diag(S) \times
 Vh`. If :attr:`input` is a batch of tensors, then ``U``, ``S``, and ``Vh`` are
 also batched with the same batch dimensions as :attr:`input`.
 
-When :attr:`input` is a tensor, it can be reconstructed by::
-
-    U @ diag(S) @ Vh
-
 If :attr:`full_matrices` is ``False`` (default), the method returns the reduced
 singular value decomposition i.e., if the last two dimensions of :attr:`input`
 are ``m`` and ``n``, then the returned `U` and `V` matrices will contain only
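
For reference, the factorization both docstrings describe is the usual one (restated here for convenience, not taken from the patch); torch.svd hands the last factor back as ``V`` so the product uses its transpose, while torch.linalg.svd hands it back already conjugate-transposed as ``Vh``::

    A = U \, \operatorname{diag}(S) \, V^{H},
    \qquad U^{H} U = I, \qquad V^{H} V = I
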

From e483647c3f44250c1727e22b301eb45a85c43461 Mon Sep 17 00:00:00 2001
From: Antonio Cuni
Date: Tue, 5 Jan 2021 15:08:50 +0000
Subject: [PATCH 61/63] check that the linalg.svd output tensors are on the correct device

---
 aten/src/ATen/native/BatchLinearAlgebra.cpp | 12 +++++++++---
 test/test_linalg.py                         | 11 +++++++++++
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 3774e31e3df8..f0b36d0fdbac 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1473,13 +1473,19 @@ std::tuple linalg_svd(const Tensor& self, bool full_matr
   }
 }
 
+static void svd_resize_and_copy(const char *name, const Tensor& src, Tensor &dst) {
+  TORCH_CHECK(src.device() == dst.device(), "svd output tensor ", name, " is on the wrong device: expected ", src.device(), " got ", dst.device());
+  at::native::resize_output(dst, src.sizes());
+  dst.copy_(src);
+}
+
 std::tuple<Tensor&, Tensor&, Tensor&> linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT,
                                                      const Tensor& self, bool full_matrices, bool compute_uv) {
   Tensor U_tmp, S_tmp, VT_tmp;
   std::tie(U_tmp, S_tmp, VT_tmp) = at::linalg_svd(self, full_matrices, compute_uv);
-  U.resize_as_(U_tmp).copy_(U_tmp);
-  S.resize_as_(S_tmp).copy_(S_tmp);
-  VT.resize_as_(VT_tmp).copy_(VT_tmp);
+  svd_resize_and_copy("U", U_tmp, U);
+  svd_resize_and_copy("S", S_tmp, S);
+  svd_resize_and_copy("V", VT_tmp, VT);
   return std::tuple<Tensor&, Tensor&, Tensor&>(U, S, VT);
 }
diff --git a/test/test_linalg.py b/test/test_linalg.py
index ca6234f13d38..d79ec5ef09b5 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -2124,6 +2124,17 @@ def is_empty(x):
             assert USV.V is out[2]
             self.assertEqual(USV.S, np_s)
 
+    @skipCPUIfNoLapack
+    @onlyCUDA
+    @dtypes(torch.float)
+    def test_linalg_svd_out_different_device(self, device, dtype):
+        t = torch.randn(5, 7, device=device, dtype=dtype) # this is on cuda
+        u = torch.empty((5, 5), device='cpu', dtype=dtype)
+        s = torch.empty((5,), device='cpu', dtype=dtype)
+        v = torch.empty((7, 7), device='cpu', dtype=dtype)
+        with self.assertRaisesRegex(RuntimeError, 'svd output tensor U is on the wrong device: expected cuda:.* got cpu'):
+            torch.linalg.svd(t, out=(u, s, v))
+
     def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
         from torch.testing._internal.common_utils import random_hermitian_pd_matrix
 

From 24c65065996c74e0b6855fcafe50858093db155f Mon Sep 17 00:00:00 2001
From: Antonio Cuni
Date: Tue, 5 Jan 2021 15:20:56 +0000
Subject: [PATCH 62/63] flake8

---
 test/test_linalg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_linalg.py b/test/test_linalg.py
index d79ec5ef09b5..ef1b22540e24 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -2128,7 +2128,7 @@ def is_empty(x):
     @onlyCUDA
     @dtypes(torch.float)
     def test_linalg_svd_out_different_device(self, device, dtype):
-        t = torch.randn(5, 7, device=device, dtype=dtype) # this is on cuda
+        t = torch.randn(5, 7, device=device, dtype=dtype)  # this is on cuda
         u = torch.empty((5, 5), device='cpu', dtype=dtype)
         s = torch.empty((5,), device='cpu', dtype=dtype)
         v = torch.empty((7, 7), device='cpu', dtype=dtype)
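
A hedged usage sketch of the behaviour pinned down by the new check (an illustration, not from the patch; it assumes a build where torch.linalg.svd accepts a tuple ``out=``): outputs on the same device as the input are resized and filled as before, while a device mismatch now fails with the error exercised by the test above::

    import torch

    t = torch.randn(5, 7)
    u, s, v = torch.empty(0), torch.empty(0), torch.empty(0)  # same (CPU) device as t
    torch.linalg.svd(t, full_matrices=False, out=(u, s, v))
    print(u.shape, s.shape, v.shape)  # (5, 5), (5,), (5, 7)
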

From c39f2ef902cead401f61fe95c2748ec8cb2ed4ff Mon Sep 17 00:00:00 2001
From: Antonio Cuni
Date: Tue, 5 Jan 2021 22:16:59 +0000
Subject: [PATCH 63/63] skip this test if we don't have magma

---
 test/test_linalg.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/test_linalg.py b/test/test_linalg.py
index ef1b22540e24..f7cea014458c 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -2124,6 +2124,7 @@ def is_empty(x):
             assert USV.V is out[2]
             self.assertEqual(USV.S, np_s)
 
+    @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @onlyCUDA
     @dtypes(torch.float)
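
Taken together, the series leaves torch.linalg.svd with the NumPy-style ``full_matrices`` switch. A short end-to-end sketch (hedged: later releases reworked the signature again, e.g. around ``compute_uv``)::

    import torch

    a = torch.randn(5, 3)
    U, S, Vh = torch.linalg.svd(a, full_matrices=True)       # U: (5, 5), Vh: (3, 3)
    Ur, Sr, Vhr = torch.linalg.svd(a, full_matrices=False)   # Ur: (5, 3), Vhr: (3, 3)
    print(torch.dist(a, Ur @ torch.diag(Sr) @ Vhr))          # ~0
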