Make torch.svd return V, not V.conj() for complex inputs (#51012)
Summary:
**BC-breaking note:**

torch.svd() added support for complex inputs in PyTorch 1.7, but was not documented as doing so. The complex "V" tensor returned was actually the complex conjugate of what's expected. This PR fixes the discrepancy.

This will silently break all users of torch.svd() with complex inputs.

**Original PR Summary:**

This PR resolves #45821.

The problem was that when support for complex inputs was introduced for `torch.svd`, it was overlooked that LAPACK/MAGMA returns the conjugate transpose of the V matrix, not just the transpose of V. As a result, `torch.svd` was silently returning U, S, V.conj() instead of U, S, V.

The behavior of `torch.linalg.pinv`, `torch.pinverse`, and `torch.linalg.svd` (which depend on `torch.svd`) is not changed by this PR.
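For illustration (not part of the original PR description), a minimal sketch of the post-change contract, assuming a build that includes this commit; the tensor `a` and the tolerance are arbitrary:

```python
import torch

a = torch.randn(5, 5, dtype=torch.cdouble)
U, S, V = torch.svd(a)

# With this change, V is returned as-is (not conjugated), so reconstructing
# the input uses the conjugate transpose of V.
a_reconstructed = U @ torch.diag(S).to(a.dtype) @ V.conj().t()
assert torch.allclose(a, a_reconstructed, atol=1e-10)

# Before this change, the same reconstruction needed plain V.t(),
# because the returned tensor was already the conjugate of V.
```

Existing code that compensated for the conjugated V by calling `.conj()` on the returned tensor will need to drop that call after updating.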

Pull Request resolved: #51012

Reviewed By: bdhirsh

Differential Revision: D26047593

Pulled By: albanD

fbshipit-source-id: d1e08dbc3aab9ce1150a95806ef3b5da98b5d3ca
IvanYashchuk authored and facebook-github-bot committed Jan 25, 2021
1 parent f8eefbd commit ddf2681
Showing 8 changed files with 58 additions and 58 deletions.
8 changes: 5 additions & 3 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1566,6 +1566,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
VT_working_copy.zero_();
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
+ // Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V.
+ VT_working_copy = VT_working_copy.conj();
VT_working_copy.transpose_(-2, -1);

@vadimkantorov (Contributor) commented on Jan 25, 2021:

Seems one more use case for conj_transpose / conj_t convenience methods :)

return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}
@@ -1596,8 +1598,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& V,
1. the 2nd parameter is bool some=True, which is effectively the opposite
of full_matrices=True
- 2. svd returns V, while linalg.svd returns VT. To accommodate the
-    difference, we transpose() V upon return
+ 2. svd returns V, while linalg.svd returns VT = V^T (for real inputs) or VT = V^H (for complex inputs).
+    To accommodate the difference, we transpose() and conj() V upon return
*/

std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) {
@@ -1608,7 +1610,7 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matr
Tensor U, S, V;
std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv);
if (compute_uv) {
- Tensor VT = V.transpose(-2, -1);
+ Tensor VT = V.conj().transpose(-2, -1);
return std::make_tuple(U, S, VT);
} else {
Tensor empty_U = at::empty({0}, self.options());
12 changes: 5 additions & 7 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -147,16 +147,14 @@ Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {

// If not Hermitian use singular value decomposition, else use eigenvalue decomposition
if (!hermitian) {
- // until https://github.com/pytorch/pytorch/issues/45821 is resolved
- // svd() returns conjugated V for complex-valued input
- Tensor U, S, V_conj;
+ Tensor U, S, V;
// TODO: replace input.svd with linalg_svd
- std::tie(U, S, V_conj) = input.svd();
+ // using linalg_svd breaks pytorch/xla, see https://github.com/pytorch/xla/issues/2755
+ std::tie(U, S, V) = input.svd();
Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
Tensor S_pseudoinv = at::where(S > (rcond.unsqueeze(-1) * max_val), S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
- // computes V @ diag(S_pseudoinv) @ U.T.conj()
- // TODO: replace V_conj.conj() -> V once https://github.com/pytorch/pytorch/issues/45821 is resolved
- return at::matmul(V_conj.conj() * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
+ // computes V @ diag(S_pseudoinv) @ U.conj().T
+ return at::matmul(V * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
} else {
Tensor S, U;
std::tie(S, U) = at::linalg_eigh(input);
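For reference, a rough Python sketch of the non-Hermitian branch above (illustrative only; it assumes the post-change `torch.svd` and mimics the default `rcond` of `torch.pinverse`):

```python
import torch

def pinv_via_svd(a, rcond=1e-15):
    # Mirrors the C++ code: V @ diag(S_pseudoinv) @ U.conj().T
    U, S, V = a.svd()
    max_val = S.narrow(-1, 0, 1)  # singular values are sorted in descending order
    S_pinv = torch.where(S > rcond * max_val, S.reciprocal(), torch.zeros((), dtype=S.dtype))
    return (V * S_pinv.to(a.dtype).unsqueeze(-2)) @ U.conj().transpose(-2, -1)

a = torch.randn(4, 6, dtype=torch.cdouble)
assert torch.allclose(pinv_via_svd(a), torch.pinverse(a), atol=1e-10)
```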
2 changes: 2 additions & 0 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2260,6 +2260,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_legacy(const Tensor& self, b
VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_();
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
+ // Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V.
+ VT_working_copy = VT_working_copy.conj();
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}
4 changes: 0 additions & 4 deletions aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu
@@ -204,8 +204,6 @@ inline static void apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdj", [&] {
_apply_svd_lib_gesvdj<scalar_t>(self_working_copy, U, S, VT, infos, compute_uv, some);
});

- VT = VT.conj();
}

// call cusolver gesvdj batched function to calculate svd
@@ -256,8 +254,6 @@ inline static void apply_svd_lib_gesvdjBatched(const Tensor& self, Tensor& U, Te
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdjBatched", [&] {
_apply_svd_lib_gesvdjBatched<scalar_t>(self_working_copy, U, S, VT, infos, compute_uv);
});

- VT = VT.conj();
}

// entrance of calculations of `svd` using cusolver gesvdj and gesvdjBatched
13 changes: 7 additions & 6 deletions test/test_linalg.py
@@ -1872,17 +1872,18 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option
actual_rank, size, batches = 2, (17, 4), ()
run_subtest(actual_rank, size, batches, device, jitted)

- @onlyCPU
+ @skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.cfloat)
def test_svd_complex(self, device, dtype):
+ # this test verifies that torch.svd really returns V and not V.conj()
+ # see: https://github.com/pytorch/pytorch/issues/45821
t = torch.randn((10, 10), dtype=dtype, device=device)
U, S, V = torch.svd(t, some=False)
- # note: from the math point of view, it is weird that we need to use
- # V.T instead of V.T.conj(): torch.svd has a buggy behavior for
- # complex numbers and it's deprecated. You should use torch.linalg.svd
- # instead.
- t2 = U @ torch.diag(S).type(dtype) @ V.T
+ # verify that t ≈ t2
+ # t2 = U @ diag(S) @ Vᴴ
+ # Vᴴ is the conjugate transpose of V
+ t2 = U @ torch.diag(S).type(dtype) @ V.conj().T
self.assertEqual(t, t2)

def _test_svd_helper(self, shape, some, col_maj, device, dtype):
50 changes: 24 additions & 26 deletions torch/_torch_docs.py
@@ -8312,47 +8312,45 @@ def merge_dicts(*dicts):
svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor)
Computes the singular value decomposition of either a matrix or batch of
- matrices :attr:`input`." The singular value decomposition is represented as a
- namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times
- V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch
- of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch
- dimensions as :attr:`input`.
+ matrices :attr:`input`. The singular value decomposition is represented as a
+ namedtuple (`U,S,V`), such that
+ :attr:`input` = `U` diag(`S`) `Vᴴ`,
+ where `Vᴴ` is the transpose of `V` for the real-valued inputs,
+ or the conjugate transpose of `V` for the complex-valued inputs.
+ If :attr:`input` is a batch of tensors, then `U`, `S`, and `V` are also
+ batched with the same batch dimensions as :attr:`input`.
If :attr:`some` is ``True`` (default), the method returns the reduced singular
value decomposition i.e., if the last two dimensions of :attr:`input` are
- ``m`` and ``n``, then the returned `U` and `V` matrices will contain only
- :math:`min(n, m)` orthonormal columns.
+ `m` and `n`, then the returned `U` and `V` matrices will contain only
+ min(`n, m`) orthonormal columns.
If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be
- zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)`
+ zero-filled matrices of shape `(m × m)` and `(n × n)`
respectively, and the same device as :attr:`input`. The :attr:`some`
- argument has no effect when :attr:`compute_uv` is False.
+ argument has no effect when :attr:`compute_uv` is ``False``.
- The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will
+ Supports input of float, double, cfloat and cdouble data types.
+ The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will
always be real-valued, even if :attr:`input` is complex.
- .. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.``
- :func:`~torch.linalg.svd` instead, which is similar to NumPy's
+ .. warning:: :func:`torch.svd` is deprecated. Please use
+ :func:`torch.linalg.svd` instead, which is similar to NumPy's
``numpy.linalg.svd``.
- .. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`:
+ .. note:: Differences with :func:`torch.linalg.svd`:
- * :attr:`some` is the opposite of ``torch.linalg.``
- :func:`~torch.linalg.svd`'s :attr:`full_matricies`. Note that
+ * :attr:`some` is the opposite of
+ :func:`torch.linalg.svd`'s :attr:`full_matricies`. Note that
default value for both is ``True``, so the default behavior is
effectively the opposite.
- * it returns ``V``, whereas ``torch.linalg.``
- :func:`~torch.linalg.svd` returns ``Vh``. The result is that
- when using ``svd`` you need to manually transpose
- ``V`` in order to reconstruct the original matrix.
+ * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns `Vᴴ`.
- * If :attr:`compute_uv=False`, it returns zero-filled tensors for
- ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns
+ * If :attr:`compute_uv=False`, :func:`torch.svd` returns zero-filled tensors for
+ ``U`` and ``Vh``, whereas :func:`torch.linalg.svd` returns
empty tensors.
- Supports real-valued and complex-valued input.
.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices,
then the singular values of each matrix in the batch is returned in descending order.
@@ -8382,10 +8380,10 @@ def merge_dicts(*dicts):
`U` and `V` tensors.
Args:
- input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
- batch dimensions consisting of :math:`m \times n` matrices.
+ input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more
+ batch dimensions consisting of `(m × n)` matrices.
some (bool, optional): controls whether to compute the reduced or full decomposition, and
- consequently the shape of returned ``U`` and ``V``. Defaults to True.
+ consequently the shape of returned `U` and `V`. Defaults to True.
compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True.
Keyword args:
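To make the documented relationship concrete, a small hedged example comparing the two APIs (assuming a build that includes this change and ships `torch.linalg.svd`; the matrix and tolerance are arbitrary):

```python
import torch

a = torch.randn(6, 4, dtype=torch.cdouble)

U1, S1, V = torch.svd(a)                               # some=True -> reduced SVD, returns V
U2, S2, Vh = torch.linalg.svd(a, full_matrices=False)  # returns Vh, the conjugate transpose of V

# Both factorizations reconstruct `a`; note which factor needs the conjugate transpose.
assert torch.allclose(a, U1 @ torch.diag(S1).to(a.dtype) @ V.conj().t(), atol=1e-10)
assert torch.allclose(a, U2 @ torch.diag(S2).to(a.dtype) @ Vh, atol=1e-10)
```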
11 changes: 3 additions & 8 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -1949,21 +1949,16 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
auto gsigma = grads[1];

auto u = raw_u;
- // Currently torch.svd for complex dtypes returns the conjugate of V,
- // while the backward formula is derived with just V (without the conjugation)
- // therefore here we need to conjugate the V output of SVD and grads[2].
- // Once https://github.com/pytorch/pytorch/issues/45821 is resolved
- // extra .conj(), that are marked below in the code, shall be removed.
- auto v = raw_v.conj(); // TODO: remove .conj()
+ auto v = raw_v;
auto gu = grads[0];
- auto gv = grads[2].conj(); // TODO: remove .conj()
+ auto gv = grads[2];

if (!some) {
// We ignore the free subspace here because possible base vectors cancel
// each other, e.g., both -v and +v are valid base for a dimension.
// Don't assume behavior of any particular implementation of svd.
u = raw_u.narrow(-1, 0, k);
- v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj()
+ v = raw_v.narrow(-1, 0, k);
if (gu.defined()) {
gu = gu.narrow(-1, 0, k);
}
16 changes: 12 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
@@ -858,16 +858,26 @@ def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_sv
"""
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

- # svd and linalg.svd returns V and V.T, respectively. So we need to slice
+ # svd and linalg.svd returns V and V.conj().T, respectively. So we need to slice
# along different dimensions when needed (this is used by
# test_cases2:wide_all and wide_all_batched below)
if is_linalg_svd:
def slice_V(v):
return v[..., :(S - 2), :]

+ def uv_loss(usv):
+ u00 = usv[0][0, 0]
+ v00_conj = usv[2][0, 0]
+ return u00 * v00_conj
else:
def slice_V(v):
return v[..., :, :(S - 2)]

+ def uv_loss(usv):
+ u00 = usv[0][0, 0]
+ v00_conj = usv[2][0, 0].conj()
+ return u00 * v00_conj

test_cases1 = ( # some=True (default)
# loss functions for complex-valued svd have to be "gauge invariant",
# i.e. loss functions shouldn't change when the sign of the singular vectors changes.
@@ -878,12 +888,10 @@ def slice_V(v):
lambda usv: abs(usv[0])), # 'check_grad_u'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
lambda usv: abs(usv[2])), # 'check_grad_v'
- # TODO: replace lambda usv: usv[0][0, 0] * usv[2][0, 0] with lambda usv: usv[0][0, 0] * usv[2][0, 0].conj()
- # once https://github.com/pytorch/pytorch/issues/45821 is resolved
# this test is important as it checks the additional term that is non-zero only for complex-valued inputs
# and when the loss function depends both on 'u' and 'v'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
- lambda usv: usv[0][0, 0] * usv[2][0, 0]), # 'check_grad_uv'
+ uv_loss), # 'check_grad_uv'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)],
lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)],
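As context for the gauge-invariance comment above, a hypothetical stand-alone check (a sketch only; the loss mirrors the `uv_loss` used for 'check_grad_uv', and the helper name, matrix construction, and shapes are illustrative, not part of the test suite):

```python
import torch
from torch.autograd import gradcheck

def uv_loss(a):
    # Depends on both U and V. Multiplying U[0, 0] by V[0, 0].conj() cancels the
    # arbitrary phase that can multiply a singular-vector pair, so the loss
    # (and hence its gradient) is well defined for complex inputs.
    U, S, V = torch.svd(a)
    return (U[0, 0] * V[0, 0].conj()).real

# Build a matrix with well-separated singular values so the numerical Jacobian
# used by gradcheck stays stable.
q1, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.cdouble))
q2, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.cdouble))
a = (q1 * torch.tensor([4., 3., 2., 1.], dtype=torch.cdouble)) @ q2.conj().t()
a = a.detach().requires_grad_(True)

assert gradcheck(uv_loss, (a,))
```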
