From ddf26816d3ca54ce7f3513f618fac93ce67d06e9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 25 Jan 2021 13:58:09 -0800 Subject: [PATCH] Make torch.svd return V, not V.conj() for complex inputs (#51012) Summary: **BC-breaking note:** torch.svd() added support for complex inputs in PyTorch 1.7, but was not documented as doing so. The complex "V" tensor returned was actually the complex conjugate of what's expected. This PR fixes the discrepancy. This will silently break all users of torch.svd() with complex inputs. **Original PR Summary:** This PR resolves https://github.com/pytorch/pytorch/issues/45821. The problem was that when introducing the support of complex inputs for `torch.svd` it was overlooked that LAPACK/MAGMA returns the conjugate transpose of V matrix, not just the transpose of V. So `torch.svd` was silently returning U, S, V.conj() instead of U, S, V. Behavior of `torch.linalg.pinv`, `torch.pinverse` and `torch.linalg.svd` (they depend on `torch.svd`) is not changed in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/51012 Reviewed By: bdhirsh Differential Revision: D26047593 Pulled By: albanD fbshipit-source-id: d1e08dbc3aab9ce1150a95806ef3b5da98b5d3ca --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 8 +-- aten/src/ATen/native/LinearAlgebra.cpp | 12 ++--- .../ATen/native/cuda/BatchLinearAlgebra.cu | 2 + .../ATen/native/cuda/BatchLinearAlgebraLib.cu | 4 -- test/test_linalg.py | 13 ++--- torch/_torch_docs.py | 50 +++++++++---------- torch/csrc/autograd/FunctionsManual.cpp | 11 ++-- .../_internal/common_methods_invocations.py | 16 ++++-- 8 files changed, 58 insertions(+), 58 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index ecc1089887cd..afca92b11961 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1566,6 +1566,8 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some VT_working_copy.zero_(); } // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + // Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V. + VT_working_copy = VT_working_copy.conj(); VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } @@ -1596,8 +1598,8 @@ std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, 1. the 2nd parameter is bool some=True, which if effectively the opposite of full_matrices=True - 2. svd returns V, while linalg.svd returns VT. To accommodate the - difference, we transpose() V upon return + 2. svd returns V, while linalg.svd returns VT = V^T (for real inputs) or VT = V^H (for complex inputs). + To accommodate the difference, we transpose() and conj() V upon return */ std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { @@ -1608,7 +1610,7 @@ std::tuple linalg_svd(const Tensor& self, bool full_matr Tensor U, S, V; std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv); if (compute_uv) { - Tensor VT = V.transpose(-2, -1); + Tensor VT = V.conj().transpose(-2, -1); return std::make_tuple(U, S, VT); } else { Tensor empty_U = at::empty({0}, self.options()); diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index a4cec8813bfe..674c91597ae9 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -147,16 +147,14 @@ Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) { // If not Hermitian use singular value decomposition, else use eigenvalue decomposition if (!hermitian) { - // until https://github.com/pytorch/pytorch/issues/45821 is resolved - // svd() returns conjugated V for complex-valued input - Tensor U, S, V_conj; + Tensor U, S, V; // TODO: replace input.svd with linalg_svd - std::tie(U, S, V_conj) = input.svd(); + // using linalg_svd breaks pytorch/xla, see https://github.com/pytorch/xla/issues/2755 + std::tie(U, S, V) = input.svd(); Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order Tensor S_pseudoinv = at::where(S > (rcond.unsqueeze(-1) * max_val), S.reciprocal(), at::zeros({}, S.options())).to(input.dtype()); - // computes V @ diag(S_pseudoinv) @ U.T.conj() - // TODO: replace V_conj.conj() -> V once https://github.com/pytorch/pytorch/issues/45821 is resolved - return at::matmul(V_conj.conj() * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1)); + // computes V @ diag(S_pseudoinv) @ U.conj().T + return at::matmul(V * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1)); } else { Tensor S, U; std::tie(S, U) = at::linalg_eigh(input); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 4d227cd23a81..8b6405a6cafe 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -2260,6 +2260,8 @@ std::tuple _svd_helper_cuda_legacy(const Tensor& self, b VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_(); } // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + // Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V. + VT_working_copy = VT_working_copy.conj(); VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu index 45597a085237..3431091661dd 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu @@ -204,8 +204,6 @@ inline static void apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdj", [&] { _apply_svd_lib_gesvdj(self_working_copy, U, S, VT, infos, compute_uv, some); }); - - VT = VT.conj(); } // call cusolver gesvdj batched function to calculate svd @@ -256,8 +254,6 @@ inline static void apply_svd_lib_gesvdjBatched(const Tensor& self, Tensor& U, Te AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdjBatched", [&] { _apply_svd_lib_gesvdjBatched(self_working_copy, U, S, VT, infos, compute_uv); }); - - VT = VT.conj(); } // entrance of calculations of `svd` using cusolver gesvdj and gesvdjBatched diff --git a/test/test_linalg.py b/test/test_linalg.py index 16d3039149ec..f9e7e03a1f50 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1872,17 +1872,18 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option actual_rank, size, batches = 2, (17, 4), () run_subtest(actual_rank, size, batches, device, jitted) - @onlyCPU + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(torch.cfloat) def test_svd_complex(self, device, dtype): + # this test verifies that torch.svd really returns V and not V.conj() + # see: https://github.com/pytorch/pytorch/issues/45821 t = torch.randn((10, 10), dtype=dtype, device=device) U, S, V = torch.svd(t, some=False) - # note: from the math point of view, it is weird that we need to use - # V.T instead of V.T.conj(): torch.svd has a buggy behavior for - # complex numbers and it's deprecated. You should use torch.linalg.svd - # instead. - t2 = U @ torch.diag(S).type(dtype) @ V.T + # verify that t ≈ t2 + # t2 = U @ diag(S) @ Vᴴ + # Vᴴ is the conjugate transpose of V + t2 = U @ torch.diag(S).type(dtype) @ V.conj().T self.assertEqual(t, t2) def _test_svd_helper(self, shape, some, col_maj, device, dtype): diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index c1063b960664..93b0d10393f0 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -8312,47 +8312,45 @@ def merge_dicts(*dicts): svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) Computes the singular value decomposition of either a matrix or batch of -matrices :attr:`input`." The singular value decomposition is represented as a -namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times -V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch -of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch -dimensions as :attr:`input`. +matrices :attr:`input`. The singular value decomposition is represented as a +namedtuple (`U,S,V`), such that +:attr:`input` = `U` diag(`S`) `Vᴴ`, +where `Vᴴ` is the transpose of `V` for the real-valued inputs, +or the conjugate transpose of `V` for the complex-valued inputs. +If :attr:`input` is a batch of tensors, then `U`, `S`, and `V` are also +batched with the same batch dimensions as :attr:`input`. If :attr:`some` is ``True`` (default), the method returns the reduced singular value decomposition i.e., if the last two dimensions of :attr:`input` are -``m`` and ``n``, then the returned `U` and `V` matrices will contain only -:math:`min(n, m)` orthonormal columns. +`m` and `n`, then the returned `U` and `V` matrices will contain only +min(`n, m`) orthonormal columns. If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be -zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)` +zero-filled matrices of shape `(m × m)` and `(n × n)` respectively, and the same device as :attr:`input`. The :attr:`some` -argument has no effect when :attr:`compute_uv` is False. +argument has no effect when :attr:`compute_uv` is ``False``. -The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will +Supports input of float, double, cfloat and cdouble data types. +The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will always be real-valued, even if :attr:`input` is complex. -.. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.`` - :func:`~torch.linalg.svd` instead, which is similar to NumPy's +.. warning:: :func:`torch.svd` is deprecated. Please use + :func:`torch.linalg.svd` instead, which is similar to NumPy's ``numpy.linalg.svd``. -.. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`: +.. note:: Differences with :func:`torch.linalg.svd`: - * :attr:`some` is the opposite of ``torch.linalg.`` - :func:`~torch.linalg.svd`'s :attr:`full_matricies`. Note that + * :attr:`some` is the opposite of + :func:`torch.linalg.svd`'s :attr:`full_matricies`. Note that default value for both is ``True``, so the default behavior is effectively the opposite. - * it returns ``V``, whereas ``torch.linalg.`` - :func:`~torch.linalg.svd` returns ``Vh``. The result is that - when using ``svd`` you need to manually transpose - ``V`` in order to reconstruct the original matrix. + * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns `Vᴴ`. - * If :attr:`compute_uv=False`, it returns zero-filled tensors for - ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns + * If :attr:`compute_uv=False`, :func:`torch.svd` returns zero-filled tensors for + ``U`` and ``Vh``, whereas :func:`torch.linalg.svd` returns empty tensors. -Supports real-valued and complex-valued input. - .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, then the singular values of each matrix in the batch is returned in descending order. @@ -8382,10 +8380,10 @@ def merge_dicts(*dicts): `U` and `V` tensors. Args: - input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more - batch dimensions consisting of :math:`m \times n` matrices. + input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more + batch dimensions consisting of `(m × n)` matrices. some (bool, optional): controls whether to compute the reduced or full decomposition, and - consequently the shape of returned ``U`` and ``V``. Defaults to True. + consequently the shape of returned `U` and `V`. Defaults to True. compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True. Keyword args: diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 27dd4ccce649..38248ca93d4e 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1949,21 +1949,16 @@ Tensor svd_backward(const std::vector &grads, const T auto gsigma = grads[1]; auto u = raw_u; - // Currently torch.svd for complex dtypes returns the conjugate of V, - // while the backward formula is derived with just V (without the conjugation) - // therefore here we need to conjugate the V output of SVD and grads[2]. - // Once https://github.com/pytorch/pytorch/issues/45821 is resolved - // extra .conj(), that are marked below in the code, shall be removed. - auto v = raw_v.conj(); // TODO: remove .conj() + auto v = raw_v; auto gu = grads[0]; - auto gv = grads[2].conj(); // TODO: remove .conj() + auto gv = grads[2]; if (!some) { // We ignore the free subspace here because possible base vectors cancel // each other, e.g., both -v and +v are valid base for a dimension. // Don't assume behavior of any particular implementation of svd. u = raw_u.narrow(-1, 0, k); - v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj() + v = raw_v.narrow(-1, 0, k); if (gu.defined()) { gu = gu.narrow(-1, 0, k); } diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b2e50a50b43f..52bb6b759de9 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -858,16 +858,26 @@ def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_sv """ from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - # svd and linalg.svd returns V and V.T, respectively. So we need to slice + # svd and linalg.svd returns V and V.conj().T, respectively. So we need to slice # along different dimensions when needed (this is used by # test_cases2:wide_all and wide_all_batched below) if is_linalg_svd: def slice_V(v): return v[..., :(S - 2), :] + + def uv_loss(usv): + u00 = usv[0][0, 0] + v00_conj = usv[2][0, 0] + return u00 * v00_conj else: def slice_V(v): return v[..., :, :(S - 2)] + def uv_loss(usv): + u00 = usv[0][0, 0] + v00_conj = usv[2][0, 0].conj() + return u00 * v00_conj + test_cases1 = ( # some=True (default) # loss functions for complex-valued svd have to be "gauge invariant", # i.e. loss functions shouldn't change when sigh of the singular vectors change. @@ -878,12 +888,10 @@ def slice_V(v): lambda usv: abs(usv[0])), # 'check_grad_u' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), lambda usv: abs(usv[2])), # 'check_grad_v' - # TODO: replace lambda usv: usv[0][0, 0] * usv[2][0, 0] with lambda usv: usv[0][0, 0] * usv[2][0, 0].conj() - # once https://github.com/pytorch/pytorch/issues/45821 is resolved # this test is important as it checks the additional term that is non-zero only for complex-valued inputs # and when the loss function depends both on 'u' and 'v' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), - lambda usv: usv[0][0, 0] * usv[2][0, 0]), # 'check_grad_uv' + uv_loss), # 'check_grad_uv' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)], lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)],