as discussed on the PR, remove the apply_conj feature: the risk of breaking existing code is too high. Instead, document the current behavior more accurately
antocuni committed Nov 10, 2020
1 parent a224dbd commit abf9baf
Showing 7 changed files with 24 additions and 30 deletions.
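For context, here is a minimal sketch of the behavior this commit documents rather than changes (an illustrative snippet, not part of the diff; it assumes a PyTorch build from around this commit with complex SVD support on CPU). For complex inputs, the V returned by torch.svd is already conjugated, so the input is rebuilt with V.T and no extra .conj():

    import torch

    t = torch.randn(4, 4, dtype=torch.complex128)
    U, S, V = torch.svd(t, some=False)
    # S is real even for complex input, so cast diag(S) before the complex matmul
    t2 = U @ torch.diag(S).to(t.dtype) @ V.T
    print(torch.allclose(t, t2))  # True: no .conj() on V is needed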
14 changes: 4 additions & 10 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1021,8 +1021,7 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT,
#endif
}

-std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some,
-                                                   bool compute_uv, bool apply_conj) {
+std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some, bool compute_uv) {
std::vector<int64_t> infos(batchCount(self), 0);
int64_t m = self.size(-2), n = self.size(-1);
int64_t k = std::min(m, n);
@@ -1059,25 +1058,21 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
VT_working_copy.transpose_(-2, -1);
-  if (VT_working_copy.is_complex() && apply_conj)
-    VT_working_copy = VT_working_copy.conj();
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}

std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compute_uv) {
TORCH_CHECK(self.dim() >= 2,
"self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
-  bool apply_conj = true;
-  return at::_svd_helper(self, some, compute_uv, apply_conj);
+  return at::_svd_helper(self, some, compute_uv);
}

std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& V,
const Tensor& self, bool some, bool compute_uv) {
TORCH_CHECK(self.dim() >= 2,
"self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
-  bool apply_conj = true;
Tensor U_tmp, S_tmp, V_tmp;
-  std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv, apply_conj);
+  std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv);
U.resize_as_(U_tmp).copy_(U_tmp);
S.resize_as_(S_tmp).copy_(S_tmp);
V.resize_as_(V_tmp).copy_(V_tmp);
@@ -1101,9 +1096,8 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matr
"self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");

bool some = !full_matrices;
-  bool apply_conj = false;
Tensor U, S, V;
-  std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv, apply_conj);
+  std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv);
if (compute_uv) {
Tensor VT = V.transpose(-2, -1);
return std::make_tuple(U, S, VT);
5 changes: 1 addition & 4 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1654,8 +1654,7 @@ AT_ERROR("svd: MAGMA library not found in "
#endif
}

-std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool some,
-                                                    bool compute_uv, bool apply_conj) {
+std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) {
std::vector<int64_t> infos(batchCount(self), 0);
int64_t m = self.size(-2), n = self.size(-1);
int64_t k = std::min(m, n);
@@ -1708,8 +1707,6 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
VT_working_copy.transpose_(-2, -1);
-  if (VT_working_copy.is_complex() && apply_conj)
-    VT_working_copy = VT_working_copy.conj();
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}

2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -6204,7 +6204,7 @@
dispatch:
Math: svd

-- func: _svd_helper(Tensor self, bool some, bool compute_uv, bool apply_conj) -> (Tensor U, Tensor S, Tensor V)
+- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V)
use_c10_dispatcher: full
variants: function
dispatch:
6 changes: 5 additions & 1 deletion test/test_torch.py
@@ -10128,7 +10128,11 @@ def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **op
def test_svd_complex(self, device, dtype):
t = torch.randn((10, 10), dtype=dtype, device=device)
U, S, V = torch.svd(t, some=False)
-        t2 = U @ torch.diag(S).type(dtype) @ V.T.conj()
+        # note: from the math point of view, it is weird that we need to use
+        # V.T instead of V.T.conj(): torch.svd has a buggy behavior for
+        # complex numbers and it's deprecated. You should use torch.linalg.svd
+        # instead.
+        t2 = U @ torch.diag(S).type(dtype) @ V.T
self.assertEqual(t, t2)

def test_lerp(self, device):
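As the new test comment recommends, torch.linalg.svd follows the usual mathematical convention and returns the conjugate transpose directly. A hedged sketch of the suggested replacement (illustrative only; assumes torch.linalg.svd is available, as it is in the PR this commit belongs to):

    import torch

    t = torch.randn(10, 10, dtype=torch.complex128)
    # torch.linalg.svd returns Vh, i.e. the conjugate transpose of V
    U, S, Vh = torch.linalg.svd(t)
    t2 = U @ torch.diag(S).to(t.dtype) @ Vh
    print(torch.allclose(t, t2))  # True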
5 changes: 1 addition & 4 deletions tools/autograd/derivatives.yaml
@@ -1031,10 +1031,7 @@
- name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim)

-# NOTE: currently apply_conj is ignored, but it if used only when operating on
-# complex, which are not supported by svd_backward so far. When we add support
-# for complex, we need to remember to consider it.
-- name: _svd_helper(Tensor self, bool some, bool compute_uv, bool apply_conj) -> (Tensor U, Tensor S, Tensor V)
+- name: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V)
self: svd_backward(grads, self, some, compute_uv, U, S, V)

- name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
18 changes: 12 additions & 6 deletions torch/_torch_docs.py
@@ -7643,15 +7643,21 @@ def merge_dicts(*dicts):
This function returns a namedtuple ``(U, S, V)`` which is the singular value
decomposition of a input real matrix or batches of real matrices :attr:`input` such that
-:math:`input = U \times diag(S) \times V^H`, where :math:`V^H` is the conjugate transpose
+:math:`input = U \times diag(S) \times V^T`, where :math:`V^T` is the transpose
of ``V``.
-In Python, :math:`V^H` is computed by ``v.T.conj()``. Note that for
-non-complex types, ``.conj()`` is a no-op and can be omittted. To summarize,
-the original Tensor can be reconstructed by::
+The original tensor can be reconstructed by::
+
+    U @ diag(S) @ V.T
+
+.. note:: It is worth noting that the code above works unmodified even
+          for complex numbers, i.e. the returned matrix ``V`` is already
+          conjugated. This behavior is probably unexpected from the
+          mathematical point of view, but it is not possible to change it
+          without breaking existing code. New code is encouraged to use
+          ``torch.linalg.svd``, which returns :math:`V^H` instead.
-    U @ diag(S) @ V.T.conj()  # for real and complex numbers
-    U @ diag(S) @ V.T         # only for real numbers
The dtype of ``U`` and ``V`` is the same as the ``input`` matrix. The dtype of
``S`` is always real numbers, even if ``input`` is complex.
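A small illustrative check of the dtype behavior stated in the docs above (not part of the diff):

    import torch

    t = torch.randn(6, 3, dtype=torch.complex64)
    U, S, V = torch.svd(t)
    print(U.dtype, V.dtype)  # torch.complex64 torch.complex64, same as the input
    print(S.dtype)           # torch.float32, singular values are always real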
4 changes: 0 additions & 4 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -1777,10 +1777,6 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
// This makes no assumption on the signs of sigma.
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
-  /* NOTE: currently this function does not support complex numbers. When we
-   * add support for it, we need to consider "bool apply_conj", which
-   * currently is ignored. See the corresponding comment in derivatives.yaml.
-   */
TORCH_CHECK(compute_uv,
"svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
"and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");
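The TORCH_CHECK shown above is the user-visible constraint: gradients flow through torch.svd only when compute_uv=True, and (per the note being removed) svd_backward does not yet support complex inputs. A hedged sketch with a real-valued input:

    import torch

    a = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
    U, S, V = torch.svd(a)        # compute_uv=True is the default
    S.sum().backward()            # singular values are differentiable
    print(a.grad.shape)           # torch.Size([5, 3])

    b = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
    _, S2, _ = torch.svd(b, compute_uv=False)
    try:
        S2.sum().backward()       # hits the TORCH_CHECK in svd_backward
    except RuntimeError as err:
        print(err)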
