Skip to content

Commit

Permalink
Clean up rank so that it uses Spinv and revert change in init.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ZelboK committed May 14, 2024
1 parent 4eab1c3 commit fec9793
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3437,10 +3437,11 @@ static void linalg_lstsq_out_info(
} else {
auto [U, S, Vh] = at::_linalg_svd(input, false, true, "gesvd");
rank = at::zeros({1}, at::kLong);
rank[0] = (S > rcond).sum();

auto S_pinv = S.reciprocal();
auto s1 = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
S_pinv.masked_fill_(S < rcond * s1, 0);
rank[0] = (S != 0).sum();
auto uhOther = at::matmul(U.adjoint(), other);
if(S_pinv.dim() != uhOther.dim()) {
S_pinv = S_pinv.unsqueeze(-1);
Expand All @@ -3458,7 +3459,7 @@ static void linalg_lstsq_out_info(
bool compute_residuals = true;
if (driver == "gelss" || driver == "gelsd") {
if (input.dim() == 2) {
compute_residuals = (rank.item().toDouble() == n);
compute_residuals = (rank.item().toInt() == n);
} else {
// it is not clear what to do if some matrices have rank < n in case of batched input
// For now let's compute the residuals only if all matrices have rank equal to n
Expand Down
2 changes: 1 addition & 1 deletion third_party/ideep
Submodule ideep updated 1 files
+1 −1 mkl-dnn
4 changes: 2 additions & 2 deletions torch/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@
- `'gelsd'` (tridiagonal reduction and SVD)
- But if you run into memory issues: `'gelss'` (full SVD).
For CUDA inputs, two drivers are available: 'gels' and 'gelss'.
For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank.
See also the `full description of these drivers`_
Expand Down Expand Up @@ -1080,7 +1080,7 @@
Keyword args:
driver (str, optional): name of the LAPACK/MAGMA method to be used.
If `None`, `'gelsy'` is used for CPU inputs, `'gels'` and `'gelss'` for CUDA inputs.
If `None`, `'gelsy'` is used for CPU inputs and `'gels'` for CUDA inputs.
Default: `None`.
Returns:
Expand Down

0 comments on commit fec9793

Please sign in to comment.