Implement torch.linalg.svd #45562

Closed · wants to merge 75 commits · Changes from 32 commits

Commits
3923d4e
linalg.svd, step 1: rename the old svd into linalg_svd, and reimpleme…
antocuni Sep 28, 2020
130cd04
rename svd_backward into linalg_svd_backward, for consistency
antocuni Sep 28, 2020
f39dc15
change the signature of linalg_svd: now it takes full_matrices=true, …
antocuni Sep 29, 2020
e16fde1
add a test for torch.linalg.svd, and write a numpy-compatible wrapper…
antocuni Sep 29, 2020
a14c2fb
WIP: the comment inside _create_U_S_VT was simply wrong: lapack and m…
antocuni Sep 29, 2020
675e122
we can't use transpose_(), else autograd complains that the results h…
antocuni Sep 30, 2020
91ec980
add a TODO
antocuni Sep 30, 2020
5886d58
use ! instead of not
antocuni Sep 30, 2020
f763c8d
partially undo commit 3923d4eab7: keep at::svd as the main function a…
antocuni Oct 9, 2020
b319768
change the return type of linalg_svd(..., compute_uv=False): we retur…
antocuni Oct 9, 2020
1817211
fix flake8
antocuni Oct 9, 2020
97662da
fix the docstring of svd, according to the discussion in issue #45821
antocuni Oct 12, 2020
25caa20
write the docstring for linalg.svd
antocuni Oct 15, 2020
f056875
fix for the complex case: torch.svd should return V but lapack comput…
antocuni Oct 15, 2020
edca9df
improve the docstring
antocuni Oct 16, 2020
0462f79
add comments to make sure we don't forget about this when we add supp…
antocuni Oct 16, 2020
e900ad2
add a test for cdouble but skip it, because it segfaults. Need to fil…
antocuni Oct 16, 2020
e8ce282
attach a docstring also to the underlying C function
antocuni Oct 16, 2020
25b6fef
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Oct 16, 2020
c7bfc07
this seems to be needed, else I get 'derivative for svd not implement…
antocuni Oct 26, 2020
9bcf00b
use dispatch: Math as per @mruberry suggestion
antocuni Oct 26, 2020
7342ece
git merge origin/master
antocuni Oct 26, 2020
25752e6
implement the out= version of torch.linalg.svd
antocuni Oct 29, 2020
10a66d7
this no longer segfaults
antocuni Oct 29, 2020
b1c3cfb
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Nov 10, 2020
6f480c2
no longer needed
antocuni Nov 10, 2020
4038ace
remove merge leftover
antocuni Nov 10, 2020
a224dbd
change the semantics of the out= param if compute_uv==False
antocuni Nov 10, 2020
abf9baf
as discussed on the PR, remove the apply_conj feature: the risk of br…
antocuni Nov 10, 2020
0c0e8c6
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Nov 12, 2020
37ec294
change again the semantics for compute_uv=False: we finally decided t…
antocuni Nov 12, 2020
60e463f
fix flake8
antocuni Nov 12, 2020
2433399
s/self/input in error messages
antocuni Nov 26, 2020
4b4076d
kill this test, now the behavior is tested directly in python
antocuni Nov 26, 2020
5fa4efe
move the svd tests from test_torch.py to test_linalg.py
antocuni Nov 26, 2020
f16f6e7
improve the docs of torch.svd
antocuni Nov 26, 2020
a008c53
try to improve the docs
antocuni Nov 26, 2020
989505f
rephrase
antocuni Nov 26, 2020
80e18d2
fix
antocuni Nov 27, 2020
71d4db1
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 7, 2020
c9c60c2
improve the docs of torch.svd in the same way we did for torch.linalg…
antocuni Dec 7, 2020
7b70521
kill duplicate tests: what happened is that both upstream/master and …
antocuni Dec 7, 2020
e25de98
refactor and fix test_namedtuple_return_api.py
antocuni Dec 8, 2020
5dc359d
mark the changes to _svd_helper as intentional
antocuni Dec 9, 2020
50a28ae
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 9, 2020
8b30bcb
Improve the error message generated by test_overrides.
antocuni Dec 10, 2020
b87b765
add the new torch.linalg.svd to test_overrides
antocuni Dec 10, 2020
22b17c0
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 14, 2020
cb9de75
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 21, 2020
a349569
this is needed after e391dbc1b5
antocuni Dec 21, 2020
e30248e
this doesn't have to be a method
antocuni Dec 21, 2020
3ddb6ac
use torch.empty instead of torch.tensor
antocuni Dec 21, 2020
823e6a8
remove unnecessary import
antocuni Dec 21, 2020
52aadbe
use the proper @dtypes decorator
antocuni Dec 21, 2020
ea0aca4
don't use Tensor
antocuni Dec 21, 2020
9bc46f6
typo
antocuni Dec 21, 2020
da1f2ae
improve docs
antocuni Dec 21, 2020
5445d3f
use the correct nccl version (hopefully)
antocuni Dec 21, 2020
6d04a6e
now the underlying op is _svd_helper, and svd is only a thin layer on…
antocuni Dec 21, 2020
a2e2781
Revert "Improve the error message generated by test_overrides."
antocuni Dec 21, 2020
bcf7461
add an OpInfo for torch.linalg.svd, and adapt sample_inputs_svd to ge…
antocuni Dec 22, 2020
ba92c1a
typo
antocuni Dec 22, 2020
fa927d0
git merge upstream/viable/strict
antocuni Dec 22, 2020
a866c67
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 22, 2020
9d9cd03
specify the aten_name
antocuni Dec 23, 2020
c3e4de6
fix indent
antocuni Dec 23, 2020
ec87163
fix flake8
antocuni Dec 23, 2020
958321e
fix test_namedtuple_return: with the new logic, the 'a' argument is p…
antocuni Dec 23, 2020
944fa6a
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Jan 5, 2021
8dbbfe5
input/svd input
antocuni Jan 5, 2021
be456ab
remove redundant sentences
antocuni Jan 5, 2021
e483647
check that the linalg.svd output tensors are on the correct device
antocuni Jan 5, 2021
24c6506
flake8
antocuni Jan 5, 2021
c39f2ef
skip this test if we don't have magma
antocuni Jan 5, 2021
81510d5
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Jan 7, 2021
47 changes: 44 additions & 3 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1046,7 +1046,7 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some

if (compute_uv) {
if (some) {
VT_working_copy = VT_working_copy.narrow(-1, 0, k);
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}
} else {
VT_working_copy.zero_();
@@ -1056,6 +1056,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
U_working_copy.zero_();
VT_working_copy.zero_();
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}

@@ -1065,12 +1067,51 @@ std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compu
return at::_svd_helper(self, some, compute_uv);
}

std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& VT,
std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& V,
const Tensor& self, bool some, bool compute_uv) {
TORCH_CHECK(self.dim() >= 2,
"self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
Tensor U_tmp, S_tmp, V_tmp;
std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv);
U.resize_as_(U_tmp).copy_(U_tmp);
antocuni marked this conversation as resolved.
Show resolved Hide resolved
S.resize_as_(S_tmp).copy_(S_tmp);
V.resize_as_(V_tmp).copy_(V_tmp);
return std::tuple<Tensor&, Tensor&, Tensor&>(U, S, V);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/* torch.linalg.svd, implemented in terms of torch.svd. There are two main
differences:

1. the 2nd parameter is bool some=True, which is effectively the opposite
of full_matrices=True

2. svd returns V, while linalg.svd returns VT. To accommodate the
difference, we transpose() V upon return
*/

std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) {
TORCH_CHECK(self.dim() >= 2,
"self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");

bool some = !full_matrices;
Tensor U, S, V;
std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv);
if (compute_uv) {
Tensor VT = V.transpose(-2, -1);
return std::make_tuple(U, S, VT);
} else {
Tensor empty_U = at::empty({0}, self.options());
Tensor empty_VT = at::empty({0}, self.options());
return std::make_tuple(empty_U, S, empty_VT);
}
}

std::tuple<Tensor&, Tensor&, Tensor&> linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT,
const Tensor& self, bool full_matrices, bool compute_uv) {
Tensor U_tmp, S_tmp, VT_tmp;
std::tie(U_tmp, S_tmp, VT_tmp) = at::_svd_helper(self, some, compute_uv);
std::tie(U_tmp, S_tmp, VT_tmp) = at::linalg_svd(self, full_matrices, compute_uv);
U.resize_as_(U_tmp).copy_(U_tmp);

mruberry (Collaborator) commented on Dec 28, 2020:

See note on error-checking out above. Here we should require same dtype and same device for now, and use resize_output instead of resize_as_.

antocuni (Contributor Author) replied:

done in commit e483647.
I noticed that resize_output doesn't check the device, so I wrote the check manually. What is the convention for output tensors? Is it generally allowed to pass output tensors on different devices? If not, why isn't the check done by resize_output itself?

Collaborator replied:

This is an excellent question. I actually wrote a brief description of out= handling in the Developer FAQ: https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch. I think it's currently a gap in the tools PyTorch provides that we require some operations to implement this check manually.

@ezyang is actually developing a new architecture that I think solves this issue. Maybe we should extend resize_output, too. Currently it doesn't have the device information necessary to implement this check. @heitorschueroff actually had that same idea recently.

antocuni (Contributor Author) replied:

thank you for the answer, I have a better picture now

S.resize_as_(S_tmp).copy_(S_tmp);
VT.resize_as_(VT_tmp).copy_(VT_tmp);
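To make the relationship spelled out in the `linalg_svd` comment above concrete, here is a minimal Python sketch (not part of the diff; shapes and names are illustrative) showing that `some` is the negation of `full_matrices` and that `linalg.svd` hands back the transposed factor:

```python
import torch

a = torch.randn(5, 3, dtype=torch.float64)

# torch.svd: reduced decomposition via some=True, returns V
u1, s1, v1 = torch.svd(a, some=True)

# torch.linalg.svd: reduced decomposition via full_matrices=False, returns Vh
u2, s2, vh2 = torch.linalg.svd(a, full_matrices=False)

# Both factorizations reconstruct the input
assert torch.allclose(u1 @ torch.diag(s1) @ v1.t(), a)
assert torch.allclose(u2 @ torch.diag(s2) @ vh2, a)

# Both entry points go through the same _svd_helper, so for this real input
# the factors agree up to the transpose that linalg_svd applies on return
assert torch.allclose(vh2, v1.t())
```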
9 changes: 6 additions & 3 deletions aten/src/ATen/native/LinearAlgebraUtils.h
@@ -241,18 +241,21 @@ static inline std::tuple<Tensor, Tensor, Tensor> _create_U_S_VT(const Tensor& in
U_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU));
}

// VT should be a column-major or a batch of column-major matrices
sizes[input.dim() - 2] = n;
sizes[input.dim() - 1] = n;
// VT should be a row-major or a batch of row-major matrices
strides = at::detail::defaultStrides(sizes);
strides[input.dim() - 1] = n;
strides[input.dim() - 2] = 1;
Tensor VT_empty;
if (!input.is_cuda()) {
VT_empty = at::empty(sizes, input.options());
VT_empty = at::empty_strided(sizes, strides, input.options());
} else {
// NB: VT_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd
// (which is the driver routine for the divide and conquer SVD operation)
// takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that
// moves the inputs between devices internally.
VT_empty = at::empty(sizes, input.options().device(at::kCPU));
VT_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU));
}

sizes.pop_back();
4 changes: 3 additions & 1 deletion aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1844,7 +1844,7 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som

if (compute_uv) {
if (some) {
VT_working_copy = VT_working_copy.narrow(-1, 0, k);
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}
} else {
VT_working_copy.zero_();
@@ -1855,6 +1855,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som
S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device()));
VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_();
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}

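The CUDA helper mirrors the CPU change, and the tests later in this PR also check that the outputs land on the same device as the input. A quick, hedged sketch of that property (requires a CUDA build with MAGMA; not part of the diff):

```python
import torch

# Only meaningful on a CUDA build with MAGMA, mirroring the skipCUDAIfNoMagma tests
if torch.cuda.is_available():
    a = torch.randn(4, 6, device="cuda", dtype=torch.float64)
    u, s, vh = torch.linalg.svd(a, full_matrices=False)
    # The factors come back on the input's device and reconstruct it
    assert u.device == a.device and s.device == a.device and vh.device == a.device
    assert torch.allclose(u @ torch.diag(s) @ vh, a)
```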
14 changes: 11 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -6196,15 +6196,15 @@

- func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)
dispatch:
DefaultBackend: svd_out
Math: svd_out

- func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
use_c10_dispatcher: full
variants: method, function
dispatch:
DefaultBackend: svd
Math: svd

- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor)
- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V)
use_c10_dispatcher: full
variants: function
dispatch:
@@ -8954,6 +8954,14 @@
python_module: linalg
variants: function

- func: linalg_svd.U(Tensor self, bool full_matrices=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)
python_module: linalg

- func: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
python_module: linalg
use_c10_dispatcher: full
variants: method, function

- func: linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor
python_module: linalg
variants: function
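The two new entries register both the functional variant and an out= overload under the `linalg` Python module. A small usage sketch of the out= form, assuming preallocated tensors of matching dtype and device (not part of the diff):

```python
import torch

a = torch.randn(5, 3, dtype=torch.float64)

# Preallocated outputs with matching dtype; the out= overload copies into them
U = torch.empty(5, 5, dtype=torch.float64)
S = torch.empty(3, dtype=torch.float64)
Vh = torch.empty(3, 3, dtype=torch.float64)

result = torch.linalg.svd(a, full_matrices=True, out=(U, S, Vh))
assert result[0] is U and result[1] is S and result[2] is Vh
assert torch.allclose(U[:, :3] @ torch.diag(S) @ Vh, a)
```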
1 change: 1 addition & 0 deletions docs/source/linalg.rst
@@ -14,4 +14,5 @@ Functions

.. autofunction:: det
.. autofunction:: norm
.. autofunction:: svd
.. autofunction:: tensorsolve
13 changes: 13 additions & 0 deletions test/cpp/api/functional.cpp
@@ -2858,3 +2858,16 @@ TEST_F(FunctionalTest, BCEWithLogitsLoss) {
ASSERT_TRUE(torch::isfinite(out2).all().item<bool>());
}
}

TEST_F(FunctionalTest, linalg_svd) {
// NOTE: this is only a partial test: it tests that when we pass
// compute_uv=False, the returned U and VT are empty tensors. We need to
// write a C++ test because in Python it has a slightly different behavior
// and it returns (None, S, None) instead. The full logic for svd is
// tested thoroughly in Python.
const auto input = torch::rand({7, 3});
torch::Tensor U, S, VT;
std::tie(U, S, VT) = at::linalg_svd(input, true, false);
ASSERT_EQ(U.numel(), 0) << "U is not empty";
ASSERT_EQ(VT.numel(), 0) << "VT is not empty";
}
52 changes: 52 additions & 0 deletions test/test_linalg.py
@@ -1019,6 +1019,58 @@ def test_nuclear_norm_exceptions_old(self, device):
self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0))
self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2))

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_svd_compute_uv(self, device, dtype):

Collaborator commented:

This test and test_svd_no_compute_uv need to test more sizes, include a batched input.

antocuni (Contributor Author) replied:

Does it? The actual logic of svd is inside _svd_helper, which is already tested explicitly by _test_svd_helper and all the tests which call it. The idea is that the test_linalg_svd_* tests only cover the actual differences, i.e. the signature and return types.
But I am fine with duplicating all the tests if you think it's a good idea.

"""
Test the default case, compute_uv=True. Here we have the very same behavior as
numpy
"""
t = torch.randn((10, 11), device=device, dtype=dtype)
np_t = t.cpu().numpy()
for full_matrices in (True, False):
# check linalg.svd vs numpy
expected = np.linalg.svd(np_t, full_matrices, compute_uv=True)
actual = torch.linalg.svd(t, full_matrices, compute_uv=True)
self.assertEqual(actual, expected)
# check linalg.svd vs linalg.svd(out=...)
out = (torch.empty_like(actual[0]),
torch.empty_like(actual[1]),
torch.empty_like(actual[2]))
out2 = torch.linalg.svd(t, full_matrices, compute_uv=True, out=out)
self.assertEqual(actual, out)
self.assertEqual(actual, out2)

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_svd_no_compute_uv(self, device, dtype):
"""
Test the compute_uv=False case. Here we have a different return type than
numpy: numpy returns S, we return (empty, S, empty)
"""
t = torch.randn((10, 11), device=device, dtype=dtype)
np_t = t.cpu().numpy()

def is_empty(x):
return x.numel() == 0 and x.dtype == t.dtype and x.device == t.device

for full_matrices in (True, False):
# check linalg.svd vs numpy
np_s = np.linalg.svd(np_t, full_matrices, compute_uv=False)
USV = torch.linalg.svd(t, full_matrices, compute_uv=False)
assert is_empty(USV.U)
self.assertEqual(USV.S, np_s)
assert is_empty(USV.V)
# check linalg.svd vs linalg.svd(out=...)
out = (torch.Tensor(), torch.empty_like(USV.S), torch.Tensor())

Collaborator commented:

Don't use torch.Tensor(), which is deprecated. Instead use torch.tensor/torch.empty/torch.empty_like.

Does this mean that an empty CPU tensor is an acceptable out= value when running on CUDA? torch.linalg.svd should verify that the tensors passed to out= are valid whether compute_uv is True or False.

antocuni (Contributor Author) replied:

> Don't use torch.Tensor(), which is deprecated. Instead use torch.tensor/torch.empty/torch.empty_like.

Done in ea0aca4.

> Does this mean that an empty CPU tensor is an acceptable out= value when running on CUDA? torch.linalg.svd should verify that the tensors passed to out= are valid whether compute_uv is True or False.

Currently yes, but that's true for most (if not all) operators defined in BatchLinearAlgebra.cpp, e.g. torch.eig, torch.svd, torch.solve, and in general all operators which use the pattern out.resize_(X).copy_(...).

Collaborator replied:

Yikes. I think you're right and that's a bug. Let's get this case correct and I've updated #49468 to include a bullet point for testing this behavior.

antocuni (Contributor Author) replied:

It looks to me that the best way to implement this behavior would be to add this extra check to resize_output, and make sure to use it everywhere.
I'm happy to work on this but I would prefer to do this in a separate PR, since it probably involves touching a lot of different code, introducing new tests, etc.

USV = torch.linalg.svd(t, full_matrices, compute_uv=False, out=out)
assert USV.U is out[0]
assert USV.S is out[1]
assert USV.V is out[2]
self.assertEqual(USV.S, np_s)

@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
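For reference, a sketch of the batched-input coverage requested in the review thread on test_svd_compute_uv above; the shapes are illustrative, and it checks the reconstruction rather than comparing factors directly (the factors are only unique up to sign):

```python
import torch

# Batch of three 10x11 matrices
t = torch.randn(3, 10, 11, dtype=torch.float64)
k = min(t.shape[-2], t.shape[-1])

for full_matrices in (True, False):
    u, s, vh = torch.linalg.svd(t, full_matrices)
    # Compare the reconstruction rather than the factors themselves
    recon = u[..., :, :k] @ torch.diag_embed(s) @ vh[..., :k, :]
    assert torch.allclose(recon, t)
```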
13 changes: 13 additions & 0 deletions test/test_torch.py
@@ -10023,6 +10023,19 @@ def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **op
guess_rank, actual_rank, size, batches = 2, 2, (17, 4), ()
run_subtest(guess_rank, actual_rank, size, batches, device, jitted)

@onlyCPU

Collaborator commented:

All the svd tests should go in test_linalg.py, even the ones for torch.svd

antocuni (Contributor Author) replied:

done in 5fa4efe.
However, it is a bit weird now: test_linalg.py:test_svd does NOT test torch.linalg.svd; it tests torch.svd.
I tried to improve the situation by renaming the new tests to test_linalg_svd_* and adding comments to separate the two sections.

@skipCPUIfNoLapack
@dtypes(torch.cfloat)
def test_svd_complex(self, device, dtype):
t = torch.randn((10, 10), dtype=dtype, device=device)
U, S, V = torch.svd(t, some=False)
# note: from the math point of view, it is weird that we need to use
# V.T instead of V.T.conj(): torch.svd has a buggy behavior for
# complex numbers and it's deprecated. You should use torch.linalg.svd
# instead.
t2 = U @ torch.diag(S).type(dtype) @ V.T
self.assertEqual(t, t2)

def test_lerp(self, device):
start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)]
for shapes in product(start_end_shapes, start_end_shapes):
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -1031,7 +1031,7 @@
- name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim)

- name: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
- name: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V)
self: svd_backward(grads, self, some, compute_uv, U, S, V)

- name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
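Moving the derivative from svd to _svd_helper means both torch.svd and the new torch.linalg.svd share the same svd_backward. A minimal autograd sketch (not from the PR) that backpropagates through the singular values only, per the docstring notes on when backward through U/V is stable:

```python
import torch

a = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)

# compute_uv stays True: svd_backward needs U and V from the forward pass
u, s, vh = torch.linalg.svd(a, full_matrices=False)
s.sum().backward()

assert a.grad is not None and a.grad.shape == a.shape
```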
19 changes: 18 additions & 1 deletion torch/_torch_docs.py
@@ -7662,7 +7662,24 @@ def merge_dicts(*dicts):

This function returns a namedtuple ``(U, S, V)`` which is the singular value
decomposition of an input real matrix or batches of real matrices :attr:`input` such that
:math:`input = U \times diag(S) \times V^T`.
:math:`input = U \times diag(S) \times V^T`, where :math:`V^T` is the transpose
of ``V``.

The original tensor can be reconstructed by::

U @ diag(S) @ V.T


.. note:: It is worth noting that the code above works unmodified even
for complex numbers, i.e. the returned matrix ``V`` is already
conjugated. This behavior is probably unexpected from the
mathematical point of view, but it is not possible to change it
without breaking existing code. New code is encouraged to use
``torch.linalg.svd`` instead, which returns :math:`V^H` instead.


The dtype of ``U`` and ``V`` is the same as the ``input`` matrix. The dtype of
``S`` is always real, even if ``input`` is complex.

If :attr:`some` is ``True`` (default), the method returns the reduced
singular value decomposition i.e., if the last two dimensions of
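The note added above about ``V`` already being conjugated is easiest to see on a small complex example. A hedged sketch of the two reconstruction formulas (the torch.svd part reflects the behavior documented at this point of the PR):

```python
import torch

a = torch.randn(4, 4, dtype=torch.cfloat)

# torch.linalg.svd returns Vh = V^H, so no extra conj/transpose is needed
u, s, vh = torch.linalg.svd(a)
assert torch.allclose(u @ torch.diag(s).to(a.dtype) @ vh, a, atol=1e-5)

# torch.svd, by contrast, hands back a V that is already conjugated, so the
# `U @ diag(S) @ V.T` reconstruction in the docstring works without .conj()
# (the deprecated behavior the note above warns about).
```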
105 changes: 105 additions & 0 deletions torch/linalg/__init__.py
@@ -140,6 +140,111 @@
(tensor(3.7417), tensor(11.2250))
""")

svd = _add_docstr(_linalg.linalg_svd, r"""
linalg.svd(input, full_matrices=True, compute_uv=True, out=None) -> (Tensor, Tensor, Tensor)

This function returns a namedtuple ``(U, S, Vh)`` which is the singular value
decomposition of an input real matrix or batches of real matrices :attr:`input` such that
:math:`input = U \times diag(S) \times V^H` (where :math:`V^H` is ``Vh``).

.. warning:: **Differences with** :meth:`~torch.svd`:

* :attr:`full_matrices` is the opposite of
:meth:`~torch.svd`'s :attr:`some`. Note that the default value
for both is ``True``, so the default behavior is effectively
the opposite.

* it returns ``Vh``, whereas :meth:`~torch.svd` returns
``V``. The result is that when using :meth:`~torch.svd` you
need to manually transpose and conjugate ``V`` in order to
reconstruct the original matrix.

* If :attr:`compute_uv=False`, it returns empty tensors (i.e.,
with 0 elements) for ``U`` and ``V``, whereas
:meth:`~torch.svd` returns zero-filled tensors.

**Differences with** ``numpy.linalg.svd``:

* if :attr:`compute_uv=False` it returns ``(empty_tensor, S, empty_tensor)``,
whereas numpy returns ``S``.


The dtype of ``U`` and ``V`` is the same as the ``input`` matrix. The dtype of
``S`` is always real, even if ``input`` is complex.

If :attr:`full_matrices` is ``False``, the method returns the reduced singular value decomposition
i.e., if the last two dimensions of :attr:`input` are ``m`` and ``n``, then the returned
`U` and `V` matrices will contain only :math:`min(n, m)` orthonormal columns.

If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be empty tensors, and :attr:`full_matrices` will
be ignored here.

.. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices,
then the singular values of each matrix in the batch are returned in descending order.

.. note:: The implementation of SVD on CPU uses the LAPACK routine `?gesdd` (a divide-and-conquer
algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine
`gesdd` as well.

mruberry (Collaborator) commented on Nov 16, 2020:

This note could be:

"PyTorch's implementation of SVD uses LAPACK's (on CPU) or MAGMA's (on CUDA) `?gesdd` (a divide-and-conquer algorithm) instead of `?gesvd` for speed."

What's up with the question marks before gesdd and gesvd?

Collaborator replied:

Just for information: the MKL documentation for LAPACK uses a question mark in place of the letter indicating which datatype the function operates on ({'s','d','c','z'} -> '?'):
https://software.intel.com/content/www/us/en/develop/documentation/mkl-developer-reference-fortran/top/lapack-routines/naming-conventions-for-lapack-routines.html
In the Netlib documentation for LAPACK, the first datatype letter is replaced with 'x':
https://www.netlib.org/lapack/lug/node24.html

Collaborator replied:

Oh I see. I suppose we could remove the question mark or replace it with the types we support in braces, like: {a, b}gesdd, but the current presentation also seems fine.

antocuni (Contributor Author) replied:

A quick search in the docs seems to suggest that we are not using a consistent naming scheme. E.g.:

  • cholesky_inverse explicitly lists all variations by saying dpotri and spotri

  • torch.geqrf mentions both geqrf and ?geqrf in the same docstring

  • orgqr the same

Collaborator replied:

It's not surprising we're inconsistent, and since we're inconsistent this PR can do whatever you like.

cc @heitorschueroff, too. We should review the linear algebra documentation for consistency ahead of the 1.8 release.


.. note:: Irrespective of the original strides, the returned matrix `U`
will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()`

.. note:: Extra care needs to be taken when backpropagating through `U` and `V`
outputs. Such an operation is really only stable when :attr:`input` is
full rank with all distinct singular values. Otherwise, ``NaN`` can
appear as the gradients are not properly defined. Also, notice that
double backward will usually do an additional backward through `U` and
`V` even if the original backward is only on `S`.

.. note:: When :attr:`full_matrices` = ``False``, the gradients on :code:`U[..., :, min(m, n):]`
and :code:`V[..., :, min(m, n):]` will be ignored in backward as those vectors
can be arbitrary bases of the subspaces.

.. note:: When :attr:`compute_uv` = ``False``, backward cannot be performed since `U` and `V`
from the forward pass are required for the backward operation.

Args:
input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
batch dimensions consisting of :math:`m \times n` matrices.
full_matrices (bool, optional): controls the shape of returned `U` and `V`

Collaborator commented:

It is technically true that the "full_matrices" argument affects the shape of `U` and `V`, but there's probably a better way to describe its effect. This also needs to include the default value (see below).

antocuni (Contributor Author) replied:

I tried to improve but I'm not sure I like the result:

    full_matrices (bool, optional): controls whether to compute the full or reduced decomposition, and
                                    consequently the shape of returned ``U`` and ``V``. Defaults to True.

compute_uv (bool, optional): whether to compute `U` and `V` or not
out (tuple, optional): the output tuple of tensors. If compute_uv=False, the 1st and 3rd
arguments must be tensors, but they are ignored. E.g. you can

Collaborator commented:

"argument" -> "arguments"

This is interesting. So this function will not set them to be empty tensors in this case?

antocuni (Contributor Author) replied:

as of the current code, yes. It seems weird to check and resize the shape of them if they are ultimately ignored.

pass `(torch.Tensor(), out_S, torch.Tensor())`

Example::

>>> import torch
>>> a = torch.randn(5, 3)
>>> a
tensor([[-0.3357, -0.2987, -1.1096],
[ 1.4894, 1.0016, -0.4572],
[-1.9401, 0.7437, 2.0968],
[ 0.1515, 1.3812, 1.5491],
[-1.8489, -0.5907, -2.5673]])
>>>
>>> # reconstruction in the full_matrices=False case
>>> u, s, vh = torch.linalg.svd(a, full_matrices=False)
>>> u.shape, s.shape, vh.shape
(torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3]))
>>> torch.dist(a, u @ torch.diag(s) @ vh)
tensor(1.0486e-06)
>>>
>>> # reconstruction in the full_matrices=True case
>>> u, s, vh = torch.linalg.svd(a)
>>> u.shape, s.shape, vh.shape
(torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3]))
>>> torch.dist(a, u[:, :3] @ torch.diag(s) @ vh)
tensor(1.0486e-06)
>>>
>>> # extra dimensions
>>> a_big = torch.randn(7, 5, 3)
>>> u, s, vh = torch.linalg.svd(a_big, full_matrices=False)
>>> torch.dist(a_big, u @ torch.diag_embed(s) @ vh)
tensor(3.0957e-06)
""")

tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r"""
linalg.tensorsolve(input, other, dims=None, *, out=None) -> Tensor
