Implement torch.linalg.svd #45562

Closed
wants to merge 75 commits into from
Changes from all commits
Commits (75)
3923d4e
linalg.svd, step 1: rename the old svd into linalg_svd, and reimpleme…
antocuni Sep 28, 2020
130cd04
rename svd_backward into linalg_svd_backward, for consistency
antocuni Sep 28, 2020
f39dc15
change the signature of linalg_svd: now it takes full_matrices=true, …
antocuni Sep 29, 2020
e16fde1
add a test for torch.linalg.svd, and write a numpy-compatible wrapper…
antocuni Sep 29, 2020
a14c2fb
WIP: the comment inside _create_U_S_VT was simply wrong: lapack and m…
antocuni Sep 29, 2020
675e122
we can't use transpose_(), else autograd complains that the results h…
antocuni Sep 30, 2020
91ec980
add a TODO
antocuni Sep 30, 2020
5886d58
use ! instead of not
antocuni Sep 30, 2020
f763c8d
partially undo commit 3923d4eab7: keep at::svd as the main function a…
antocuni Oct 9, 2020
b319768
change the return type of linalg_svd(..., compute_uv=False): we retur…
antocuni Oct 9, 2020
1817211
fix flake8
antocuni Oct 9, 2020
97662da
fix the docstring of svd, according to the discussion in issue #45821
antocuni Oct 12, 2020
25caa20
write the docstring for linalg.svd
antocuni Oct 15, 2020
f056875
fix for the complex case: torch.svd should return V but lapack comput…
antocuni Oct 15, 2020
edca9df
improve the docstring
antocuni Oct 16, 2020
0462f79
add comments to make sure we don't forget about this when we add supp…
antocuni Oct 16, 2020
e900ad2
add a test for cdouble but skip it, because it segfaults. Need to fil…
antocuni Oct 16, 2020
e8ce282
attach a docstring also to the underlying C function
antocuni Oct 16, 2020
25b6fef
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Oct 16, 2020
c7bfc07
this seems to be needed, else I get 'derivative for svd not implement…
antocuni Oct 26, 2020
9bcf00b
use dispatch: Math as per @mruberry suggestion
antocuni Oct 26, 2020
7342ece
git merge origin/master
antocuni Oct 26, 2020
25752e6
implement the out= version of torch.linalg.svd
antocuni Oct 29, 2020
10a66d7
this no longer segfaults
antocuni Oct 29, 2020
b1c3cfb
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Nov 10, 2020
6f480c2
no longer needed
antocuni Nov 10, 2020
4038ace
remove merge leftover
antocuni Nov 10, 2020
a224dbd
change the semantics of the out= param if compute_uv==False
antocuni Nov 10, 2020
abf9baf
as discussed on the PR, remove the apply_conj feature: the risk of br…
antocuni Nov 10, 2020
0c0e8c6
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Nov 12, 2020
37ec294
change again the semantics for compute_uv=False: we finally decided t…
antocuni Nov 12, 2020
60e463f
fix flake8
antocuni Nov 12, 2020
2433399
s/self/input in error messages
antocuni Nov 26, 2020
4b4076d
kill this test, now the behavior is tested directly in python
antocuni Nov 26, 2020
5fa4efe
move the svd tests from test_torch.py to test_linalg.py
antocuni Nov 26, 2020
f16f6e7
improve the docs of torch.svd
antocuni Nov 26, 2020
a008c53
try to improve the docs
antocuni Nov 26, 2020
989505f
rephrase
antocuni Nov 26, 2020
80e18d2
fix
antocuni Nov 27, 2020
71d4db1
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 7, 2020
c9c60c2
improve the docs of torch.svd in the same way we did for torch.linalg…
antocuni Dec 7, 2020
7b70521
kill duplicate tests: what happened is that both upstream/master and …
antocuni Dec 7, 2020
e25de98
refactor and fix test_namedtuple_return_api.py
antocuni Dec 8, 2020
5dc359d
mark the changes to _svd_helper as intentional
antocuni Dec 9, 2020
50a28ae
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 9, 2020
8b30bcb
Improve the error message generated by test_overrides.
antocuni Dec 10, 2020
b87b765
add the new torch.linalg.svd to test_overrides
antocuni Dec 10, 2020
22b17c0
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 14, 2020
cb9de75
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 21, 2020
a349569
this is needed after e391dbc1b5
antocuni Dec 21, 2020
e30248e
this doesn't have to be a method
antocuni Dec 21, 2020
3ddb6ac
use torch.empty instead of torch.tensor
antocuni Dec 21, 2020
823e6a8
remove unnecessary import
antocuni Dec 21, 2020
52aadbe
use the proper @dtypes decorator
antocuni Dec 21, 2020
ea0aca4
don't use Tensor
antocuni Dec 21, 2020
9bc46f6
typo
antocuni Dec 21, 2020
da1f2ae
improve docs
antocuni Dec 21, 2020
5445d3f
use the correct nccl version (hopefully)
antocuni Dec 21, 2020
6d04a6e
now the underlying op is _svd_helper, and svd is only a thin layer on…
antocuni Dec 21, 2020
a2e2781
Revert "Improve the error message generated by test_overrides."
antocuni Dec 21, 2020
bcf7461
add an OpInfo for torch.linalg.svd, and adapt sample_inputs_svd to ge…
antocuni Dec 22, 2020
ba92c1a
typo
antocuni Dec 22, 2020
fa927d0
git merge upstream/viable/strict
antocuni Dec 22, 2020
a866c67
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Dec 22, 2020
9d9cd03
specify the aten_name
antocuni Dec 23, 2020
c3e4de6
fix indent
antocuni Dec 23, 2020
ec87163
fix flake8
antocuni Dec 23, 2020
958321e
fix test_namedtuple_return: with the new logic, the 'a' argument is p…
antocuni Dec 23, 2020
944fa6a
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Jan 5, 2021
8dbbfe5
input/svd input
antocuni Jan 5, 2021
be456ab
remove redundant sentences
antocuni Jan 5, 2021
e483647
check that the linalg.svd output tensors are on the correct device
antocuni Jan 5, 2021
24c6506
flake8
antocuni Jan 5, 2021
c39f2ef
skip this test if we don't have magma
antocuni Jan 5, 2021
81510d5
Merge remote-tracking branch 'upstream/master' into antocuni/linalg-svd
antocuni Jan 7, 2021
61 changes: 54 additions & 7 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1411,7 +1411,7 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some

if (compute_uv) {
if (some) {
VT_working_copy = VT_working_copy.narrow(-1, 0, k);
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}
} else {
VT_working_copy.zero_();
@@ -1421,24 +1421,71 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
U_working_copy.zero_();
VT_working_copy.zero_();
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}

std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compute_uv) {
TORCH_CHECK(self.dim() >= 2,
"self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
"svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
return at::_svd_helper(self, some, compute_uv);
}

std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& VT,
std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& V,
const Tensor& self, bool some, bool compute_uv) {
TORCH_CHECK(self.dim() >= 2,
"self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
Tensor U_tmp, S_tmp, VT_tmp;
std::tie(U_tmp, S_tmp, VT_tmp) = at::_svd_helper(self, some, compute_uv);
"svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
Tensor U_tmp, S_tmp, V_tmp;
std::tie(U_tmp, S_tmp, V_tmp) = at::_svd_helper(self, some, compute_uv);
U.resize_as_(U_tmp).copy_(U_tmp);
S.resize_as_(S_tmp).copy_(S_tmp);
VT.resize_as_(VT_tmp).copy_(VT_tmp);
V.resize_as_(V_tmp).copy_(V_tmp);
return std::tuple<Tensor&, Tensor&, Tensor&>(U, S, V);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/* torch.linalg.svd, implemented in terms of torch.svd. There are two main
differences:

   1. the 2nd parameter is bool some=True, which is effectively the opposite
of full_matrices=True

2. svd returns V, while linalg.svd returns VT. To accommodate the
difference, we transpose() V upon return
*/

std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) {
TORCH_CHECK(self.dim() >= 2,
"svd input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");

bool some = !full_matrices;
Tensor U, S, V;
std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv);
if (compute_uv) {
Tensor VT = V.transpose(-2, -1);
return std::make_tuple(U, S, VT);
} else {
Tensor empty_U = at::empty({0}, self.options());
Tensor empty_VT = at::empty({0}, self.options());
return std::make_tuple(empty_U, S, empty_VT);
}
}

static void svd_resize_and_copy(const char *name, const Tensor& src, Tensor &dst) {
TORCH_CHECK(src.device() == dst.device(), "svd output tensor ", name, " is on the wrong device: expected ", src.device(), " got ", dst.device());
at::native::resize_output(dst, src.sizes());
dst.copy_(src);
}

std::tuple<Tensor&, Tensor&, Tensor&> linalg_svd_out(Tensor& U, Tensor& S, Tensor& VT,
const Tensor& self, bool full_matrices, bool compute_uv) {
Tensor U_tmp, S_tmp, VT_tmp;
std::tie(U_tmp, S_tmp, VT_tmp) = at::linalg_svd(self, full_matrices, compute_uv);
svd_resize_and_copy("U", U_tmp, U);
svd_resize_and_copy("S", S_tmp, S);
svd_resize_and_copy("V", VT_tmp, VT);
return std::tuple<Tensor&, Tensor&, Tensor&>(U, S, VT);
}

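For reference, here is a short Python sketch (not part of the diff) of the behavior the functions above implement. It assumes the semantics of this PR: torch.svd takes some= and returns V, torch.linalg.svd takes full_matrices= (the negation of some) and returns VT, and linalg.svd with compute_uv=False hands back empty U/VT placeholders (later PyTorch releases dropped compute_uv in favor of torch.linalg.svdvals).

import torch

a = torch.randn(5, 3)

# torch.svd: reduced factorization with some=True, third output is V
u, s, v = torch.svd(a, some=True)
assert torch.allclose(a, u @ torch.diag(s) @ v.t(), atol=1e-5)

# torch.linalg.svd: full_matrices is the opposite of some, third output is VT
u2, s2, vt = torch.linalg.svd(a, full_matrices=False)
assert torch.allclose(vt, v.t(), atol=1e-5)

# compute_uv=False as implemented here: S is computed, U and VT come back empty
u3, s3, vt3 = torch.linalg.svd(a, full_matrices=False, compute_uv=False)
assert u3.numel() == 0 and vt3.numel() == 0
assert torch.allclose(s3, s, atol=1e-5)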
9 changes: 6 additions & 3 deletions aten/src/ATen/native/LinearAlgebraUtils.h
@@ -261,18 +261,21 @@ static inline std::tuple<Tensor, Tensor, Tensor> _create_U_S_VT(const Tensor& in
U_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU));
}

// VT should be a column-major or a batch of column-major matrices
sizes[input.dim() - 2] = n;
sizes[input.dim() - 1] = n;
// VT should be a row-major or a batch of row-major matrices
strides = at::detail::defaultStrides(sizes);
strides[input.dim() - 1] = n;
strides[input.dim() - 2] = 1;
Tensor VT_empty;
if (!input.is_cuda()) {
VT_empty = at::empty(sizes, input.options());
VT_empty = at::empty_strided(sizes, strides, input.options());
} else {
// NB: VT_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd
// (which is the driver routine for the divide and conquer SVD operation)
// takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that
// moves the inputs between devices internally.
VT_empty = at::empty(sizes, input.options().device(at::kCPU));
VT_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU));
}

sizes.pop_back();
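A small standalone illustration (not from the PR) of the layout trick above: the VT buffer is allocated with swapped strides via empty_strided so that transposing it later is a stride-only view rather than a copy. torch.empty_strided is the Python counterpart of at::empty_strided; n=3 is an arbitrary example size.

import torch

n = 3
# strides of 1 along the row index and n along the column index,
# matching the strides set up in _create_U_S_VT
vt = torch.empty_strided((n, n), (1, n))
print(vt.stride(), vt.is_contiguous())   # (1, 3) False

# transpose swaps sizes and strides without moving data, so the V that
# _svd_helper returns via VT.transpose_(-2, -1) shares the same storage
# and ends up C-contiguous
v = vt.transpose(-2, -1)
print(v.stride(), v.is_contiguous())     # (3, 1) True
assert v.data_ptr() == vt.data_ptr()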
4 changes: 3 additions & 1 deletion aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2194,7 +2194,7 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som

if (compute_uv) {
if (some) {
VT_working_copy = VT_working_copy.narrow(-1, 0, k);
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}
} else {
VT_working_copy.zero_();
@@ -2205,6 +2205,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som
S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device()));
VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_();
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}

15 changes: 12 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -5820,14 +5820,14 @@
- func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
dispatch:
DefaultBackend: svd_out
Math: svd_out

- func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
variants: method, function
dispatch:
DefaultBackend: svd
Math: svd

- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor)
- func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor U, Tensor S, Tensor V)
variants: function
dispatch:
CPU: _svd_helper_cpu
@@ -8962,6 +8962,15 @@
python_module: linalg
variants: function

- func: linalg_svd.U(Tensor self, bool full_matrices=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
python_module: linalg

- func: linalg_svd(Tensor self, bool full_matrices=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
python_module: linalg
use_c10_dispatcher: full
variants: function

- func: linalg_cond(Tensor self, Scalar? p=None) -> Tensor
python_module: linalg
variants: function
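The linalg_svd.U schema above backs the new out= overload. A hedged usage sketch (tensor names and shapes are illustrative, not taken from the PR): svd_resize_and_copy resizes each user-provided out tensor to the result's shape, copies the data in, and rejects out tensors on a different device.

import torch

a = torch.randn(4, 4)

# preallocated outputs are resized and filled in place by linalg_svd_out
U, S, Vh = torch.empty(0), torch.empty(0), torch.empty(0)
torch.linalg.svd(a, full_matrices=True, out=(U, S, Vh))
print(U.shape, S.shape, Vh.shape)   # torch.Size([4, 4]) torch.Size([4]) torch.Size([4, 4])

# an out tensor on the wrong device trips the TORCH_CHECK in svd_resize_and_copy,
# e.g. on a CUDA build:
#   torch.linalg.svd(a.cuda(), out=(U, S, Vh))   # RuntimeError: svd output tensor U is on the wrong device ...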
1 change: 1 addition & 0 deletions docs/source/linalg.rst
@@ -19,6 +19,7 @@ Functions
.. autofunction:: eigvalsh
.. autofunction:: matrix_rank
.. autofunction:: norm
.. autofunction:: svd
.. autofunction:: solve
.. autofunction:: tensorinv
.. autofunction:: tensorsolve
@@ -37,6 +37,7 @@
("aten::ifft", datetime.date(2021, 1, 31)),
("aten::irfft", datetime.date(2021, 1, 31)),
("aten::rfft", datetime.date(2021, 1, 31)),
("aten::_svd_helper", datetime.date(2021, 1, 31)),
("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)),
("aten::_cudnn_rnn", datetime.date(2020, 12, 31)),
("aten::_cudnn_rnn_backward", datetime.date(2020, 12, 31)),