diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index ecc1089887cd..afca92b11961 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1566,6 +1566,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
     VT_working_copy.zero_();
   }
   // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
+  // Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V.
+  VT_working_copy = VT_working_copy.conj();
   VT_working_copy.transpose_(-2, -1);
   return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
 }
@@ -1596,8 +1598,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& V,
     1. the 2nd parameter is bool some=True, which is effectively the opposite
       of full_matrices=True
 
-    2. svd returns V, while linalg.svd returns VT. To accommodate the
-       difference, we transpose() V upon return
+    2. svd returns V, while linalg.svd returns VT = V^T (for real inputs) or VT = V^H (for complex inputs).
+       To accommodate the difference, we transpose() and conj() V upon return
 */
 
 std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) {
@@ -1608,7 +1610,7 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matr
   Tensor U, S, V;
   std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv);
   if (compute_uv) {
-    Tensor VT = V.transpose(-2, -1);
+    Tensor VT = V.conj().transpose(-2, -1);
     return std::make_tuple(U, S, VT);
   } else {
     Tensor empty_U = at::empty({0}, self.options());
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index a4cec8813bfe..674c91597ae9 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -147,16 +147,14 @@ Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
 
   // If not Hermitian use singular value decomposition, else use eigenvalue decomposition
   if (!hermitian) {
-    // until https://github.com/pytorch/pytorch/issues/45821 is resolved
-    // svd() returns conjugated V for complex-valued input
-    Tensor U, S, V_conj;
+    Tensor U, S, V;
     // TODO: replace input.svd with linalg_svd
-    std::tie(U, S, V_conj) = input.svd();
+    // using linalg_svd breaks pytorch/xla, see https://github.com/pytorch/xla/issues/2755
+    std::tie(U, S, V) = input.svd();
     Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);  // singular values are sorted in descending order
     Tensor S_pseudoinv = at::where(S > (rcond.unsqueeze(-1) * max_val), S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
-    // computes V @ diag(S_pseudoinv) @ U.T.conj()
-    // TODO: replace V_conj.conj() -> V once https://github.com/pytorch/pytorch/issues/45821 is resolved
-    return at::matmul(V_conj.conj() * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
+    // computes V @ diag(S_pseudoinv) @ U.conj().T
+    return at::matmul(V * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
   } else {
     Tensor S, U;
     std::tie(S, U) = at::linalg_eigh(input);
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index 2cc7563ff23b..2af1cbb29312 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2252,6 +2252,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_legacy(const Tensor& self, b
     VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_();
   }
   // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
+  // Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V.
+  VT_working_copy = VT_working_copy.conj();
   VT_working_copy.transpose_(-2, -1);
   return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
 }
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu
index 8c6751abd86e..4d19920762af 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu
@@ -204,8 +204,6 @@ inline static void apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdj", [&] {
     _apply_svd_lib_gesvdj<scalar_t>(self_working_copy, U, S, VT, infos, compute_uv, some);
   });
-
-  VT = VT.conj();
 }
 
 // call cusolver gesvdj batched function to calculate svd
@@ -256,8 +254,6 @@ inline static void apply_svd_lib_gesvdjBatched(const Tensor& self, Tensor& U, Te
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdjBatched", [&] {
     _apply_svd_lib_gesvdjBatched<scalar_t>(self_working_copy, U, S, VT, infos, compute_uv);
   });
-
-  VT = VT.conj();
 }
 
 // entrance of calculations of `svd` using cusolver gesvdj and gesvdjBatched
diff --git a/test/test_linalg.py b/test/test_linalg.py
index fa52626ea332..62fe94da6fe7 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -1873,17 +1873,18 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option
             actual_rank, size, batches = 2, (17, 4), ()
             run_subtest(actual_rank, size, batches, device, jitted)
 
-    @onlyCPU
+    @skipCUDAIfNoMagmaAndNoCusolver
     @skipCPUIfNoLapack
     @dtypes(torch.cfloat)
     def test_svd_complex(self, device, dtype):
+        # this test verifies that torch.svd really returns V and not V.conj()
+        # see: https://github.com/pytorch/pytorch/issues/45821
         t = torch.randn((10, 10), dtype=dtype, device=device)
         U, S, V = torch.svd(t, some=False)
-        # note: from the math point of view, it is weird that we need to use
-        # V.T instead of V.T.conj(): torch.svd has a buggy behavior for
-        # complex numbers and it's deprecated. You should use torch.linalg.svd
-        # instead.
-        t2 = U @ torch.diag(S).type(dtype) @ V.T
+        # verify that t ≈ t2
+        # t2 = U @ diag(S) @ Vᴴ
+        # Vᴴ is the conjugate transpose of V
+        t2 = U @ torch.diag(S).type(dtype) @ V.conj().T
         self.assertEqual(t, t2)
 
     def _test_svd_helper(self, shape, some, col_maj, device, dtype):
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index c1063b960664..93b0d10393f0 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -8312,47 +8312,45 @@ def merge_dicts(*dicts):
 svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor)
 
 Computes the singular value decomposition of either a matrix or batch of
-matrices :attr:`input`." The singular value decomposition is represented as a
-namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times
-V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch
-of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch
-dimensions as :attr:`input`.
+matrices :attr:`input`. The singular value decomposition is represented as a
+namedtuple (`U,S,V`), such that
+:attr:`input` = `U` diag(`S`) `Vᴴ`,
+where `Vᴴ` is the transpose of `V` for real-valued inputs,
+or the conjugate transpose of `V` for complex-valued inputs.
+If :attr:`input` is a batch of tensors, then `U`, `S`, and `V` are also
+batched with the same batch dimensions as :attr:`input`.
 
 If :attr:`some` is ``True`` (default), the method returns the reduced singular
 value decomposition i.e., if the last two dimensions of :attr:`input` are
-``m`` and ``n``, then the returned `U` and `V` matrices will contain only
-:math:`min(n, m)` orthonormal columns.
+`m` and `n`, then the returned `U` and `V` matrices will contain only
+min(`n, m`) orthonormal columns.
 
 If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be
-zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)`
+zero-filled matrices of shape `(m × m)` and `(n × n)`
 respectively, and the same device as :attr:`input`. The :attr:`some`
-argument has no effect when :attr:`compute_uv` is False.
+argument has no effect when :attr:`compute_uv` is ``False``.
 
-The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will
+Supports input of float, double, cfloat and cdouble data types.
+The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will
 always be real-valued, even if :attr:`input` is complex.
 
-.. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.``
-             :func:`~torch.linalg.svd` instead, which is similar to NumPy's
+.. warning:: :func:`torch.svd` is deprecated. Please use
+             :func:`torch.linalg.svd` instead, which is similar to NumPy's
              ``numpy.linalg.svd``.
 
-.. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`:
+.. note:: Differences with :func:`torch.linalg.svd`:
 
-             * :attr:`some` is the opposite of ``torch.linalg.``
-               :func:`~torch.linalg.svd`'s :attr:`full_matricies`. Note that
+             * :attr:`some` is the opposite of
+               :func:`torch.linalg.svd`'s :attr:`full_matrices`. Note that
               default value for both is ``True``, so the default behavior is
               effectively the opposite.
 
-             * it returns ``V``, whereas ``torch.linalg.``
-               :func:`~torch.linalg.svd` returns ``Vh``. The result is that
-               when using ``svd`` you need to manually transpose
-               ``V`` in order to reconstruct the original matrix.
+             * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns `Vᴴ`.
 
-             * If :attr:`compute_uv=False`, it returns zero-filled tensors for
-               ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns
+             * If :attr:`compute_uv=False`, :func:`torch.svd` returns zero-filled tensors for
+               ``U`` and ``V``, whereas :func:`torch.linalg.svd` returns
                empty tensors.
 
-Supports real-valued and complex-valued input.
-
 .. note:: The singular values are returned in descending order. If :attr:`input` is a batch
           of matrices, then the singular values of each matrix in the batch is returned in descending order.
@@ -8382,10 +8380,10 @@ def merge_dicts(*dicts):
 `U` and `V` tensors.
 
 Args:
-    input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
-                    batch dimensions consisting of :math:`m \times n` matrices.
+    input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more
+                    batch dimensions consisting of `(m × n)` matrices.
     some (bool, optional): controls whether to compute the reduced or full decomposition, and
-                           consequently the shape of returned ``U`` and ``V``. Defaults to True.
+                           consequently the shape of returned `U` and `V`. Defaults to True.
     compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True.
 
 Keyword args:
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 27dd4ccce649..38248ca93d4e 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1949,21 +1949,16 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
   auto gsigma = grads[1];
 
   auto u = raw_u;
-  // Currently torch.svd for complex dtypes returns the conjugate of V,
-  // while the backward formula is derived with just V (without the conjugation)
-  // therefore here we need to conjugate the V output of SVD and grads[2].
-  // Once https://github.com/pytorch/pytorch/issues/45821 is resolved
-  // extra .conj(), that are marked below in the code, shall be removed.
-  auto v = raw_v.conj(); // TODO: remove .conj()
+  auto v = raw_v;
   auto gu = grads[0];
-  auto gv = grads[2].conj(); // TODO: remove .conj()
+  auto gv = grads[2];
 
   if (!some) {
     // We ignore the free subspace here because possible base vectors cancel
     // each other, e.g., both -v and +v are valid base for a dimension.
     // Don't assume behavior of any particular implementation of svd.
     u = raw_u.narrow(-1, 0, k);
-    v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj()
+    v = raw_v.narrow(-1, 0, k);
     if (gu.defined()) {
       gu = gu.narrow(-1, 0, k);
     }
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index db19360a07f8..a0ccf559366c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -846,16 +846,26 @@ def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_sv
     """
     from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
 
-    # svd and linalg.svd returns V and V.T, respectively. So we need to slice
+    # svd and linalg.svd return V and V.conj().T, respectively. So we need to slice
     # along different dimensions when needed (this is used by
     # test_cases2:wide_all and wide_all_batched below)
     if is_linalg_svd:
         def slice_V(v):
             return v[..., :(S - 2), :]
+
+        def uv_loss(usv):
+            u00 = usv[0][0, 0]
+            v00_conj = usv[2][0, 0]
+            return u00 * v00_conj
     else:
         def slice_V(v):
             return v[..., :, :(S - 2)]
 
+        def uv_loss(usv):
+            u00 = usv[0][0, 0]
+            v00_conj = usv[2][0, 0].conj()
+            return u00 * v00_conj
+
     test_cases1 = (  # some=True (default)
         # loss functions for complex-valued svd have to be "gauge invariant",
         # i.e. loss functions shouldn't change when sign of the singular vectors change.
@@ -866,12 +876,10 @@ def slice_V(v):
          lambda usv: abs(usv[0])),  # 'check_grad_u'
         (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
          lambda usv: abs(usv[2])),  # 'check_grad_v'
-        # TODO: replace lambda usv: usv[0][0, 0] * usv[2][0, 0] with lambda usv: usv[0][0, 0] * usv[2][0, 0].conj()
-        # once https://github.com/pytorch/pytorch/issues/45821 is resolved
         # this test is important as it checks the additional term that is non-zero only for complex-valued inputs
         # and when the loss function depends both on 'u' and 'v'
         (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
-         lambda usv: usv[0][0, 0] * usv[2][0, 0]),  # 'check_grad_uv'
+         uv_loss),  # 'check_grad_uv'
         (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)],
          lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))),  # 'wide'
         (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)],
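
A minimal standalone sketch of the semantics this patch establishes, assuming a PyTorch build that already includes it: `torch.svd` returns `V` (so the conjugate transpose must be applied manually to reconstruct the input), while `torch.linalg.svd` returns `Vᴴ` directly. This mirrors what the updated `test_svd_complex` checks.

```python
import torch

t = torch.randn(10, 10, dtype=torch.cfloat)

# torch.svd returns (U, S, V); reconstruct with the conjugate transpose of V.
U, S, V = torch.svd(t, some=False)
t_svd = U @ torch.diag(S).to(t.dtype) @ V.conj().T

# torch.linalg.svd already returns Vᴴ (the conjugate transpose), so no extra conj().
U2, S2, Vh = torch.linalg.svd(t, full_matrices=False)
t_linalg = U2 @ torch.diag(S2).to(t.dtype) @ Vh

assert torch.allclose(t, t_svd, atol=1e-4)
assert torch.allclose(t, t_linalg, atol=1e-4)
```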
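The `linalg_pinv` change in LinearAlgebra.cpp follows the same convention. Below is a rough Python equivalent of the formula `V @ diag(S_pseudoinv) @ U.conj().T` used there; `pinv_via_svd` is an illustrative name, not an existing API, and the snippet again assumes a build where `torch.svd` returns `V` rather than `V.conj()`.

```python
import torch

def pinv_via_svd(a, rcond=1e-15):
    # Singular values below rcond * max(S) are zeroed instead of inverted,
    # matching the at::where(...) cutoff in linalg_pinv.
    u, s, v = torch.svd(a)  # reduced SVD; torch.svd returns V, not Vᴴ
    cutoff = rcond * s.max()
    s_pinv = torch.where(s > cutoff, s.reciprocal(), torch.zeros((), dtype=s.dtype))
    # A⁺ = V @ diag(S⁺) @ Uᴴ
    return v @ torch.diag(s_pinv).to(a.dtype) @ u.conj().T

a = torch.randn(5, 3, dtype=torch.cdouble)
assert torch.allclose(pinv_via_svd(a), torch.linalg.pinv(a))
```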
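The "gauge invariant" requirement mentioned in `_sample_inputs_svd` is why the `check_grad_uv` loss uses `u00 * v00.conj()` for `torch.svd` (and plain `u00 * v00` for `torch.linalg.svd`, whose third output is already `Vᴴ`). A small numerical sketch of that invariance, with illustrative variable names: multiplying a column of `U` and the matching column of `V` by the same phase leaves `U diag(S) Vᴴ` unchanged, and a valid loss must be unchanged as well.

```python
import torch

t = torch.randn(6, 6, dtype=torch.cdouble)
U, S, V = torch.svd(t)

# Apply the same arbitrary phase to matching columns of U and V. The pair
# (U2, V2) is still a valid SVD of t, because the phases cancel in U diag(S) Vᴴ.
phase = torch.exp(1j * torch.rand(()))
U2, V2 = U.clone(), V.clone()
U2[:, 0] *= phase
V2[:, 0] *= phase

# u00 * conj(v00) is gauge invariant; u00 * v00 would pick up phase**2.
loss_before = U[0, 0] * V[0, 0].conj()
loss_after = U2[0, 0] * V2[0, 0].conj()
assert torch.allclose(loss_before, loss_after)
```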