
Added linalg.pinv #48399

Closed
wants to merge 35 commits into from
Conversation

Collaborator

@IvanYashchuk commented Nov 23, 2020

This PR adds torch.linalg.pinv.

Changes compared to the original torch.pinverse:

  • New kwarg "hermitian": with hermitian=True, an eigendecomposition is used instead of the singular value decomposition.
  • The rcond argument can now be a Tensor of appropriate shape, applying matrix-wise clipping of singular values.
  • Added an out= variant (it allocates a temporary and makes a copy for now).

Ref. #42666
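For illustration, a minimal usage sketch of the API described above. The tensors and rcond values are made up, and the exact semantics follow the docs added in this PR:

```python
import torch

# Default: pseudo-inverse computed via the singular value decomposition.
a = torch.randn(3, 5)
a_pinv = torch.linalg.pinv(a)

# hermitian=True: an eigendecomposition is used instead of the SVD;
# only valid when each matrix is Hermitian (symmetric for real inputs).
h = torch.randn(4, 4)
h = h + h.t()
h_pinv = torch.linalg.pinv(h, hermitian=True)

# rcond as a Tensor: one clipping threshold per matrix in the batch.
# The singular values of `batch` have shape (2, 3), so a (2, 1) rcond
# broadcasts against them.
batch = torch.randn(2, 3, 3)
rcond = torch.tensor([[1e-15], [1e-3]])
batch_pinv = torch.linalg.pinv(batch, rcond=rcond)
```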

@IvanYashchuk added the module: numpy (Related to numpy support, and also numpy compatibility of our operators) and module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul) labels Nov 23, 2020
@dr-ci
dr-ci bot commented Nov 23, 2020

💊 CI failures summary and remediations

As of commit 48ac5dd (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)


@IvanYashchuk marked this pull request as ready for review December 9, 2020 20:55
@IvanYashchuk requested review from mruberry and removed request for glaringlee December 9, 2020 20:55
@codecov
codecov bot commented Dec 10, 2020

Codecov Report

Merging #48399 (48ac5dd) into master (eb87686) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master   #48399   +/-   ##
=======================================
  Coverage   80.70%   80.71%           
=======================================
  Files        1904     1904           
  Lines      206598   206645   +47     
=======================================
+ Hits       166741   166792   +51     
+ Misses      39857    39853    -4     

@H-Huang added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Dec 10, 2020
"pinverse(", self.scalar_type(), "{", self.sizes(), "}): expected a tensor with 2 or more dimensions "
Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
TORCH_CHECK((at::isFloatingType(input.scalar_type()) || at::isComplexType(input.scalar_type())) && input.dim() >= 2,
"linalg_pinv(", input.scalar_type(), "{", input.sizes(), "}): expected a tensor with 2 or more dimensions "
"of floating types");
Collaborator

"of floating types" is misleading. The requirement is that the tensors have a floating point dtype or a complex dtype.

Also, is this the right place to validate that input and rcond have the same dtype?

Also also, what happens when the dtype is bfloat16 or half?

Collaborator

Oh wait, I'm mistaken. It should not be that the dtype of input and rcond are the same.

Shouldn't it be that rcond must always be the "value type" of input? That is, if input is double rcond should be double, but if input is complex double then rcond should be double, right?
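As an aside, a tiny sketch of what "value type" means here: the real dtype underlying a complex dtype. The value_type helper below is hypothetical, for illustration only, not part of this PR:

```python
import torch

def value_type(dtype):
    # The real dtype underlying a complex dtype; real dtypes map to themselves.
    if dtype.is_complex:
        return torch.empty((), dtype=dtype).real.dtype
    return dtype

print(value_type(torch.complex128))  # torch.float64
print(value_type(torch.float64))     # torch.float64
```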

Collaborator Author

I changed the torch check to accept only float, double, cfloat or cdouble types.

As for rcond, I actually don't know how we should restrict its types. When a plain float is given from the Python interface, it is always converted to a double in C++, and we then create a scalar tensor of type double.
If a tensor is passed directly, then any type should be valid that allows multiplication with the max_val tensor (which can only be of type float or double) and comparison with a tensor of type float or double. So I guess any floating-point and integer types (i.e. everything that is not complex) should be valid here for any allowed input?
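A sketch of the behavior described in this thread, assuming the rcond restriction added in the commits referenced below; the values are illustrative and the exact error type/message may differ:

```python
import torch

a = torch.randn(3, 3, dtype=torch.complex128)

# A Python float rcond is converted to a double scalar tensor internally.
torch.linalg.pinv(a, rcond=1e-10)

# Real tensor rconds (floating-point or integer) are accepted ...
torch.linalg.pinv(a, rcond=torch.tensor(0.0))
torch.linalg.pinv(a, rcond=torch.tensor(2, dtype=torch.int64))

# ... while a complex rcond is rejected.
try:
    torch.linalg.pinv(a, rcond=torch.tensor(0.0, dtype=torch.complex64))
except RuntimeError as e:
    print("rejected:", e)
```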

Collaborator Author

I added a restriction so that rcond does not accept complex types: 4de9786
And added tests to show that it works for all other types: 4ac1eaf

Collaborator

Your analysis makes sense. Restricting it to not be complex sounds good.

return at::empty(input_sizes, input.options());
}

Tensor rcond_ = rcond;
Collaborator

Can you elaborate on what's happening here, and add a comment explaining it?

Collaborator Author

It was not needed here and I removed it. b7a3eb3

// TODO: replace input.svd with linalg_svd
std::tie(U, S, V_conj) = input.svd();
Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
Tensor S_pseudoinv = at::where(S > rcond_ * max_val, S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
Collaborator

Please add parens around rcond_ * max_val so people don't need to worry about the binding of the comparison operator.

Collaborator

Is the .to(input.dtype()) actually necessary here?

Collaborator Author

Yes, .to(input.dtype()) is necessary because S is real-valued and we need to convert it to a complex type to be able to do the matmul; at::matmul and at::where don't work with mixed real and complex types.
Same discussion: #45819 (comment)
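For reference, a Python sketch of what the quoted ATen code computes. The name pinv_via_svd is hypothetical; the real implementation is the C++ shown above:

```python
import torch

def pinv_via_svd(input, rcond=1e-15):
    # A = U diag(S) V^H, so A+ = V diag(S+) U^H, where S+ inverts only the
    # singular values above the rcond * max_singular_value threshold.
    U, S, V = torch.svd(input)
    max_val = S.narrow(-1, 0, 1)  # singular values are sorted in descending order
    rcond = torch.as_tensor(rcond)
    S_pinv = torch.where(S > rcond * max_val, S.reciprocal(),
                         torch.zeros((), dtype=S.dtype, device=S.device))
    # S is real-valued; cast it to the input dtype so the matmul below also
    # works for complex inputs (mixed real/complex matmul is not supported).
    S_pinv = S_pinv.to(input.dtype)
    return V @ torch.diag_embed(S_pinv) @ U.conj().transpose(-2, -1)
```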

@IvanYashchuk
Collaborator Author

@mruberry I've updated this pull request; it should be in better shape now. I've replied to most of the comments above.


Computes the pseudo-inverse (also known as the Moore-Penrose inverse) of a matrix :attr:`input`,
or of each matrix in a batched :attr:`input`.
The pseudo-inverse is computed using singular value decomposition (see :func:`torch.linalg.svd`) by default.
Collaborator

Unfortunately torch.linalg.svd doesn't exist yet. This can reference torch.svd for now.

Collaborator Author

Sure, I can change that.
I just thought that we should start using torch.linalg references even if they don't work for now.

While we're not actually using the newer linear algebra functions internally we should still reference them here.
#48206 (comment)

Collaborator

It is nice to dream... but we should probably keep all the references in the docs working ;)

torch/linalg/__init__.py (outdated review thread, resolved)
[-1.1921e-07+0.0000e+00j, -2.3842e-07-2.9802e-07j,
1.0000e+00-1.7897e-07j]])

Non-default rcond example
Collaborator

This example is tricky. Because we don't see the intermediate steps it's hard to understand the effect that rcond is having. The next example has the same issue.

What are your thoughts? I think we can leave out these examples for now since this PR seems to be almost ready to merge, otherwise, and maybe return to add them later?

Collaborator Author

The documentation says that the pseudo-inverse is calculated using SVD, and that rcond determines which singular values should be set to zero. The first example demonstrates that using the default and some other value for rcond gives different results, as expected. Just using the facts from the docs, I think it's understandable that different rconds in general give different results. I think we should keep this example because examples should contain code snippets demonstrating each parameter. Really understanding the effect of rcond would require studying the math (the topic of Tikhonov regularization), which is definitely not needed here. We're at the level of "something in the input changes, so something in the output should change as well".

The purpose of the second rcond example (the matrix-wise rcond example) is to show what shapes of rcond are acceptable. I agree that it doesn't look transparent and might be difficult to understand, but that's broadcasting: smaller arrays are extended to higher-dimensional arrays. The docs say "rcond must be broadcastable to the singular values of input". I don't see how to help the user understand this bit better.

How about keeping these examples as is (they're working and valid code) and returning to them later to improve them?
I don't mind removing them now, though.
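For concreteness, a minimal version of what the matrix-wise rcond example is getting at; the values are made up, and the shapes just follow the broadcasting rule quoted above:

```python
import torch

a = torch.randn(2, 3, 3)
# The singular values of `a` have shape (2, 3); an rcond of shape (2, 1)
# broadcasts against them and gives each matrix in the batch its own cutoff.
rcond = torch.tensor([[1e-15], [0.5]])
out = torch.linalg.pinv(a, rcond=rcond)
print(out.shape)  # torch.Size([2, 3, 3])
```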

Collaborator Author

@mruberry, so should we remove those rcond examples now and revisit them later, or keep them?

Collaborator

No, it's fine; if you like them, let's keep them.

Collaborator

@mruberry left a comment

Another awesome PR, @IvanYashchuk.

There are two very minor comments and one more significant comment about the examples using rcond. I would really like to merge this ASAP, so maybe we just remove the rcond examples for now and come back to add them later?

@IvanYashchuk
Collaborator Author

@mruberry I think we are ready to merge this PR.

@mruberry
Collaborator

@IvanYashchuk would you just rebase to resolve the merge conflicts and ping me?

@IvanYashchuk
Collaborator Author

@mruberry done.

Contributor

@facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mruberry merged this pull request in 9384d31.

Labels
cla signed, Merged, module: linear algebra, module: numpy, open source, triaged

5 participants