Make torch.svd return V, not V.conj() for complex inputs #51012

Closed
wants to merge 10 commits
8 changes: 5 additions & 3 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1566,6 +1566,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
VT_working_copy.zero_();
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
// Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V.
VT_working_copy = VT_working_copy.conj();
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}
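A quick sanity check of the convention the comment above describes (a minimal sketch, assuming this patch is applied; the names `a`, `u`, `s`, `v` are illustrative). With torch.svd returning V rather than V.conj(), a complex input is reconstructed as U @ diag(S) @ Vᴴ, which is exactly what the new test in test_linalg.py verifies:

import torch

a = torch.randn(5, 5, dtype=torch.cfloat)
u, s, v = torch.svd(a)
# apply_svd produced V^H internally; the helper above converts it to V,
# so the reconstruction uses the conjugate transpose of the returned V.
recon = u @ torch.diag(s).to(a.dtype) @ v.conj().T
print(torch.allclose(recon, a, atol=1e-5))  # expected: True with this change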
@@ -1596,8 +1598,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& V,
1. the 2nd parameter is bool some=True, which is effectively the opposite
of full_matrices=True

2. svd returns V, while linalg.svd returns VT. To accommodate the
difference, we transpose() V upon return
2. svd returns V, while linalg.svd returns VT = V^T (for real inputs) or VT = V^H (for complex inputs).
To accommodate the difference, we transpose() and conj() V upon return
*/

std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) {
@@ -1608,7 +1610,7 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matr
Tensor U, S, V;
std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv);
if (compute_uv) {
Tensor VT = V.transpose(-2, -1);
Tensor VT = V.conj().transpose(-2, -1);
return std::make_tuple(U, S, VT);
} else {
Tensor empty_U = at::empty({0}, self.options());
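To illustrate the adjustment described in the comment above (a hedged sketch; it assumes a build that contains both this patch and torch.linalg.svd): the wrapper derives Vᴴ from the helper's V by conjugating and transposing, so the two Python APIs should relate as follows.

import torch

a = torch.randn(4, 6, dtype=torch.cdouble)
u, s, v = torch.svd(a, some=True)                       # third output is V
u2, s2, vh = torch.linalg.svd(a, full_matrices=False)   # third output is V^H
# Both entry points call the same _svd_helper, so Vh should match V.conj().T.
print(torch.allclose(vh, v.conj().transpose(-2, -1)))   # expected: True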
12 changes: 5 additions & 7 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -147,16 +147,14 @@ Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {

// If not Hermitian use singular value decomposition, else use eigenvalue decomposition
if (!hermitian) {
// until https://github.com/pytorch/pytorch/issues/45821 is resolved
// svd() returns conjugated V for complex-valued input
Tensor U, S, V_conj;
Tensor U, S, V;
// TODO: replace input.svd with linalg_svd
std::tie(U, S, V_conj) = input.svd();
// using linalg_svd breaks pytorch/xla, see https://github.com/pytorch/xla/issues/2755
std::tie(U, S, V) = input.svd();
Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
Tensor S_pseudoinv = at::where(S > (rcond.unsqueeze(-1) * max_val), S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
// computes V @ diag(S_pseudoinv) @ U.T.conj()
// TODO: replace V_conj.conj() -> V once https://github.com/pytorch/pytorch/issues/45821 is resolved
return at::matmul(V_conj.conj() * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
// computes V @ diag(S_pseudoinv) @ U.conj().T
return at::matmul(V * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
} else {
Tensor S, U;
std::tie(S, U) = at::linalg_eigh(input);
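A Python-level sketch of what the non-Hermitian branch above computes (the helper name `pinv_via_svd` is hypothetical, and the check assumes this patch so that torch.svd returns V): the pseudoinverse is V @ diag(S⁺) @ Uᴴ, with singular values below rcond · max(S) zeroed.

import torch

def pinv_via_svd(a, rcond=1e-15):
    # A+ = V @ diag(S+) @ U^H, zeroing singular values below rcond * max(S).
    u, s, v = torch.svd(a)  # V (not V.conj()) with this change
    max_val = s.narrow(-1, 0, 1)
    s_pinv = torch.where(s > rcond * max_val, s.reciprocal(),
                         torch.zeros((), dtype=s.dtype, device=s.device)).to(a.dtype)
    return (v * s_pinv.unsqueeze(-2)) @ u.conj().transpose(-2, -1)

a = torch.randn(3, 5, dtype=torch.cfloat)
print(torch.allclose(pinv_via_svd(a), torch.linalg.pinv(a), atol=1e-5))  # expected: True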
2 changes: 2 additions & 0 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2252,6 +2252,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_legacy(const Tensor& self, b
VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_();
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
// Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V.
VT_working_copy = VT_working_copy.conj();
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}
4 changes: 0 additions & 4 deletions aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu
@@ -204,8 +204,6 @@ inline static void apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdj", [&] {
_apply_svd_lib_gesvdj<scalar_t>(self_working_copy, U, S, VT, infos, compute_uv, some);
});

VT = VT.conj();
}

// call cusolver gesvdj batched function to calculate svd
@@ -256,8 +254,6 @@ inline static void apply_svd_lib_gesvdjBatched(const Tensor& self, Tensor& U, Te
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdjBatched", [&] {
_apply_svd_lib_gesvdjBatched<scalar_t>(self_working_copy, U, S, VT, infos, compute_uv);
});

VT = VT.conj();
}

// entrance of calculations of `svd` using cusolver gesvdj and gesvdjBatched
13 changes: 7 additions & 6 deletions test/test_linalg.py
@@ -1873,17 +1873,18 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option
actual_rank, size, batches = 2, (17, 4), ()
run_subtest(actual_rank, size, batches, device, jitted)

@onlyCPU
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.cfloat)
def test_svd_complex(self, device, dtype):
# this test verifies that torch.svd really returns V and not V.conj()
# see: https://github.com/pytorch/pytorch/issues/45821
t = torch.randn((10, 10), dtype=dtype, device=device)
U, S, V = torch.svd(t, some=False)
# note: from the math point of view, it is weird that we need to use
# V.T instead of V.T.conj(): torch.svd has a buggy behavior for
# complex numbers and it's deprecated. You should use torch.linalg.svd
# instead.
t2 = U @ torch.diag(S).type(dtype) @ V.T
# verify that t ≈ t2
# t2 = U @ diag(S) @ Vᴴ
# Vᴴ is the conjugate transpose of V
t2 = U @ torch.diag(S).type(dtype) @ V.conj().T
self.assertEqual(t, t2)

def _test_svd_helper(self, shape, some, col_maj, device, dtype):
50 changes: 24 additions & 26 deletions torch/_torch_docs.py
@@ -8312,47 +8312,45 @@ def merge_dicts(*dicts):
svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor)

Computes the singular value decomposition of either a matrix or batch of
matrices :attr:`input`. The singular value decomposition is represented as a
namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times
V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch
of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch
dimensions as :attr:`input`.
matrices :attr:`input`. The singular value decomposition is represented as a
namedtuple (`U,S,V`), such that
:attr:`input` = `U` diag(`S`) `Vᴴ`,
where `Vᴴ` is the transpose of `V` for real-valued inputs,
or the conjugate transpose of `V` for complex-valued inputs.
If :attr:`input` is a batch of tensors, then `U`, `S`, and `V` are also
batched with the same batch dimensions as :attr:`input`.

If :attr:`some` is ``True`` (default), the method returns the reduced singular
value decomposition i.e., if the last two dimensions of :attr:`input` are
``m`` and ``n``, then the returned `U` and `V` matrices will contain only
:math:`min(n, m)` orthonormal columns.
`m` and `n`, then the returned `U` and `V` matrices will contain only
min(`n, m`) orthonormal columns.

If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be
zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)`
zero-filled matrices of shape `(m × m)` and `(n × n)`
respectively, and the same device as :attr:`input`. The :attr:`some`
argument has no effect when :attr:`compute_uv` is False.
argument has no effect when :attr:`compute_uv` is ``False``.

The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will
Supports input of float, double, cfloat and cdouble data types.
The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will
always be real-valued, even if :attr:`input` is complex.

.. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.``
:func:`~torch.linalg.svd` instead, which is similar to NumPy's
.. warning:: :func:`torch.svd` is deprecated. Please use
:func:`torch.linalg.svd` instead, which is similar to NumPy's
``numpy.linalg.svd``.

.. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`:
.. note:: Differences with :func:`torch.linalg.svd`:

* :attr:`some` is the opposite of ``torch.linalg.``
:func:`~torch.linalg.svd`'s :attr:`full_matricies`. Note that
* :attr:`some` is the opposite of
:func:`torch.linalg.svd`'s :attr:`full_matrices`. Note that
the default value for both is ``True``, so the default behavior is
effectively the opposite.

* it returns ``V``, whereas ``torch.linalg.``
:func:`~torch.linalg.svd` returns ``Vh``. The result is that
when using ``svd`` you need to manually transpose
``V`` in order to reconstruct the original matrix.
* :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns `Vᴴ`.

* If :attr:`compute_uv=False`, it returns zero-filled tensors for
``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns
* If :attr:`compute_uv=False`, :func:`torch.svd` returns zero-filled tensors for
``U`` and ``V``, whereas :func:`torch.linalg.svd` returns
empty tensors.

Supports real-valued and complex-valued input.

.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices,
then the singular values of each matrix in the batch are returned in descending order.

@@ -8382,10 +8380,10 @@ def merge_dicts(*dicts):
`U` and `V` tensors.

Args:
input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
batch dimensions consisting of :math:`m \times n` matrices.
input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more
batch dimensions consisting of `(m × n)` matrices.
some (bool, optional): controls whether to compute the reduced or full decomposition, and
consequently the shape of returned ``U`` and ``V``. Defaults to True.
consequently the shape of returned `U` and `V`. Defaults to True.
compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True.

Keyword args:
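A short usage example of the shapes the updated docstring describes (a sketch; the "expected" values assume the documented behavior for `some` and `compute_uv`):

import torch

a = torch.randn(5, 3, dtype=torch.cfloat)
# some=True (default): reduced decomposition with min(m, n) columns.
u, s, v = torch.svd(a)
print(u.shape, s.shape, v.shape)    # expected: (5, 3), (3,), (3, 3)
# compute_uv=False: zero-filled (m x m) and (n x n) tensors, unlike
# torch.linalg.svd, which returns empty tensors in this case.
u0, s0, v0 = torch.svd(a, compute_uv=False)
print(u0.shape, v0.shape, float(u0.abs().sum()))  # expected: (5, 5), (3, 3), 0.0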
11 changes: 3 additions & 8 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -1949,21 +1949,16 @@ Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const T
auto gsigma = grads[1];

auto u = raw_u;
// Currently torch.svd for complex dtypes returns the conjugate of V,
// while the backward formula is derived with just V (without the conjugation)
// therefore here we need to conjugate the V output of SVD and grads[2].
// Once https://github.com/pytorch/pytorch/issues/45821 is resolved
// extra .conj(), that are marked below in the code, shall be removed.
auto v = raw_v.conj(); // TODO: remove .conj()
auto v = raw_v;
auto gu = grads[0];
auto gv = grads[2].conj(); // TODO: remove .conj()
auto gv = grads[2];

if (!some) {
// We ignore the free subspace here because possible base vectors cancel
// each other, e.g., both -v and +v are valid base for a dimension.
// Don't assume behavior of any particular implementation of svd.
u = raw_u.narrow(-1, 0, k);
v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj()
v = raw_v.narrow(-1, 0, k);
if (gu.defined()) {
gu = gu.narrow(-1, 0, k);
}
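To connect the backward change above to a runnable check (a sketch only; it assumes complex autograd support for svd in the installed build, which this PR targets, and an input with distinct singular values, as the OpInfo samples below arrange; `gauge_invariant_loss` is an illustrative name):

import torch

a = torch.randn(4, 4, dtype=torch.cdouble, requires_grad=True)

def gauge_invariant_loss(x):
    # u[0, 0] * conj(v[0, 0]) is unchanged when a singular-vector pair picks up
    # a common phase, so it is a well-defined loss for complex SVD gradients.
    u, s, v = torch.svd(x)
    return (u[0, 0] * v[0, 0].conj()).abs()

print(torch.autograd.gradcheck(gauge_invariant_loss, a))  # expected: True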
16 changes: 12 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
@@ -846,16 +846,26 @@ def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_sv
"""
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

# svd and linalg.svd returns V and V.T, respectively. So we need to slice
# svd and linalg.svd return V and V.conj().T, respectively. So we need to slice
# along different dimensions when needed (this is used by
# test_cases2:wide_all and wide_all_batched below)
if is_linalg_svd:
def slice_V(v):
return v[..., :(S - 2), :]

def uv_loss(usv):
u00 = usv[0][0, 0]
v00_conj = usv[2][0, 0]
return u00 * v00_conj
else:
def slice_V(v):
return v[..., :, :(S - 2)]

def uv_loss(usv):
u00 = usv[0][0, 0]
v00_conj = usv[2][0, 0].conj()
return u00 * v00_conj

test_cases1 = ( # some=True (default)
# loss functions for complex-valued svd have to be "gauge invariant",
# i.e. loss functions shouldn't change when the sign of the singular vectors changes; see the sketch after this hunk.
@@ -866,12 +876,10 @@ def slice_V(v):
lambda usv: abs(usv[0])), # 'check_grad_u'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
lambda usv: abs(usv[2])), # 'check_grad_v'
# TODO: replace lambda usv: usv[0][0, 0] * usv[2][0, 0] with lambda usv: usv[0][0, 0] * usv[2][0, 0].conj()
# once https://github.com/pytorch/pytorch/issues/45821 is resolved
# this test is important as it checks the additional term that is non-zero only for complex-valued inputs
# and when the loss function depends both on 'u' and 'v'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device),
lambda usv: usv[0][0, 0] * usv[2][0, 0]), # 'check_grad_uv'
uv_loss), # 'check_grad_uv'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)],
lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide'
(random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)],
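A numeric illustration of the gauge-invariance requirement stated in the comment above (a sketch; names are illustrative): multiplying a singular-vector pair by a common unit-modulus phase leaves A = Σᵢ sᵢ uᵢ vᵢᴴ unchanged, and it also leaves u00 * conj(v00) unchanged, which is why the 'check_grad_uv' loss is well defined for complex inputs.

import torch

a = torch.randn(5, 5, dtype=torch.cdouble)
u, s, v = torch.svd(a)
theta = torch.rand((), dtype=torch.float64) * 6.283
phase = torch.complex(torch.cos(theta), torch.sin(theta))  # unit-modulus scalar
u2, v2 = u.clone(), v.clone()
u2[:, 0] *= phase
v2[:, 0] *= phase  # re-phasing a (u_i, v_i) pair keeps u_i v_i^H fixed
print(torch.allclose(u[0, 0] * v[0, 0].conj(),
                     u2[0, 0] * v2[0, 0].conj()))          # expected: True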