[MPS] Add lu_factor #99269
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99269
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure as of commit 006ea34 with merge base 6e43897.
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Left a few comments.
I'll let the MPS guys review the main implementation though.
```yaml
  variants: function
  dispatch:
    CompositeImplicitAutograd: linalg_lu_factor_out
    MPS: linalg_lu_factor_out_mps
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you implement linalg_lu_factor_ex rather than this one? That way you wouldn't need to add any new backward rule, and all the other goodies that are implemented for linalg.lu_factor will also extend to MPS.
The reason I implemented linalg_lu_factor rather than linalg_lu_factor_ex is that the _ex variant has to return an info tensor, which is not applicable to MPS (on CPU, for example, the info tensor is computed by LAPACK).
I guess I cannot directly return an undefined Tensor() as the info, since there is some post-check logic applied to the info tensor.
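One option I could try (just a sketch, assuming the MPS kernel's per-matrix status can be mapped onto LAPACK-style info codes) is to allocate a concrete, zero-initialized info tensor instead of an undefined one, so the post-checks have something valid to look at:

```cpp
// Hypothetical sketch: satisfy the linalg_lu_factor_ex contract by returning
// a real info tensor. Zero means "no error" in the LAPACK convention, so the
// post-check logic passes whenever the MPS kernel succeeds.
// `A` and `batchSize` are assumed to be in scope as in this PR.
Tensor info = at::zeros({batchSize}, A.options().dtype(at::kInt));
// If the MPS decomposition reports a failure for matrix i, a nonzero code
// could be written into info[i] before returning it.
```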
```cpp
  std::vector<Tensor> status_tensors;
  std::vector<Tensor> pivots_list;

  status_tensors.reserve(batchSize);
  pivots_list.reserve(batchSize);
  for (C10_UNUSED const auto i : c10::irange(batchSize)) {
    status_tensors.push_back(at::zeros(1, kInt, c10::nullopt, kMPS, c10::nullopt));
    pivots_list.push_back(at::zeros(numPivots, kInt, c10::nullopt, kMPS, c10::nullopt));
  }
```
Why not just a single tensor with batchSize elements in the case of status_tensors? Same for pivots_list.
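Something along these lines (a sketch only, reusing the names from the snippet above):

```cpp
// Sketch: single batched allocations instead of one tensor per matrix.
// `batchSize` and `numPivots` are the same variables used in the loop above.
Tensor status = at::zeros({batchSize},
                          at::TensorOptions().dtype(at::kInt).device(at::kMPS));
Tensor pivots = at::zeros({batchSize, numPivots},
                          at::TensorOptions().dtype(at::kInt).device(at::kMPS));
```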
For pivots_list:
For a reason I don't know (the MPS kernel is closed source), if we use the same MTLBuffer (the underlying storage of an MPS tensor) with an offset for each matrix's pivots, the resulting pivot values are incorrect. This is probably because, per the docs, MPSMatrixDecompositionLU operates in-place when the result matrix completely aliases the source matrix.
For status_tensors:
For each system, the kernel requires the status to be encoded as an MTLBuffer input, and there is no option for specifying an offset into that buffer. So the way I came up with was splitting the status into multiple tensors, each with its own MTLBuffer. Maybe there is a better approach.
I do not know enough about MPS, but perhaps @kulinseth can comment on what's the best way to do this. It'd be good if pivots could be a Tensor, that way we wouldn't need to first create a vector and then copy it into the output tensor.
For status_tensors, I don't know why we can't simply return the status tensors, the same as we do for the pivots. This would allow us to implement the _ex variant, which is the expected way of implementing this function.
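To be concrete, once the _ex variant exists, linalg.lu_factor itself can stay composite. A rough sketch of that pattern (not the exact upstream code):

```cpp
// Sketch: lu_factor implemented on top of lu_factor_ex, so a backend only
// needs to provide the _ex kernel. Error checking happens on the info tensor.
std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
  auto [LU, pivots, info] = at::linalg_lu_factor_ex(A, pivot, /*check_errors=*/false);
  TORCH_CHECK((info == 0).all().item<bool>(),
              "linalg.lu_factor: the factorization could not be completed.");
  return std::make_tuple(std::move(LU), std::move(pivots));
}
```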
@qqaatw, you can use the aliasing strategy MPSAliasingStrategyShallNotAlias and provide an offset using the arrayView.
So start with using MPSNDArray to create the object:

```objc
arrayView = [ndArray arrayViewWithCommandBuffer:commandBuffer
                                     descriptor:desc
                                       aliasing:MPSAliasingStrategyShallNotAlias];
```

And then convert the MPSNDArray to an MPSMatrix to be passed to LU solve:

```objc
[[MPSMatrix alloc] initWithBuffer:[ndArray buffer]
                           offset:offset
                       descriptor:matDesc];
```
Sorry, should it be initialized with [ndArray buffer] or [arrayView buffer]? I guess the latter is what you were suggesting.
With the pivot matrix initialized from [arrayView buffer] using MPSAliasingStrategyShallNotAlias, the pivots output remains the same as the pivots before the LU decomposition, which makes it look like the array view is not writable under this strategy. On the other hand, if I specify MPSAliasingStrategyShallAlias, the output is correct for unbatched inputs.
I also tried initializing with [ndArray buffer]; the outputs were incorrect when the inputs were batched.
@kulinseth I can provide the code if you need repro.
https://gist.github.com/qqaatw/3b3cb633c60fcd6abab3fc5f0e468b88#file-repro-mm
```cpp
void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
  using namespace mps;

  TORCH_CHECK(pivot, "linalg_lu_factor(): MPS doesn't allow pivot == False.");
```
Nit: the error message should use the public name, linalg.lu_factor.
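i.e., presumably something like:

```cpp
TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False.");
```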
I agree with @lezcano's earlier comments.
There is a minor correctness issue. That being said, the PR looks mostly right. @kulinseth already approved, so feel free to merge after either fixing the issue properly or simply reverting to the previous flatten (not entirely efficient, but at least correct).
```cpp
    return;
  }

  Tensor A_ = A_t.dim() > 3 ? A_t.view({-1, A_t.size(-2), A_t.size(-1)}) : A_t;
```
This will fail if the view cannot be performed. My point is that you do not need to iterate the strided dimension in contiguous order, nor do you need it to be contiguous. You just need to iterate over every matrix inside it, so you never really need to copy.
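For reference, the simplest correct (if not maximally efficient) fix would be along these lines, a sketch matching the "revert to the previous flatten" option mentioned above:

```cpp
// reshape() returns a view when the strides allow it and silently copies
// otherwise, so it never throws the way view() does on non-viewable inputs.
Tensor A_ = A_t.dim() > 3 ? A_t.reshape({-1, A_t.size(-2), A_t.size(-1)}) : A_t;
```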
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, first few of them are: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 2, 5, linux.g5.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -ic
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 2 checks: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 2, 5, linux.g5.4xlarge.nvidia.gpu), pull / linux-jammy-py3.8-gcc11 / test (distributed, 2, 2, linux.2xlarge). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…u_factor (#165871)
Fixes #165870. Follow-up from #165254. This PR [a] removes the MPS-specific version of `lu_factor` in favor of the version in BatchedLinearAlgebra.cpp which uses `lu_factor_ex`, and [b] updates `lu_factor_ex` error codes to match expectations. When `lu_factor` was first implemented for MPS (#99269), it bypassed the implementation in BatchedLinearAlgebra.cpp since we did not have `lu_factor_ex`. Since #144651 implements `lu_factor_ex`, we can now remove the MPS-specific wrapper.
Pull Request resolved: #165871
Approved by: https://github.com/kulinseth, https://github.com/albanD
Stack from ghstack (oldest at bottom):
🤖 Generated by Copilot at d75cde1
Added MPS support and autograd formulas for LU factorization of tensors. Implemented the `linalg_lu_factor` and `linalg_lu_factor.out` functions for the MPS backend in `LinearAlgebra.mm` and added tests in `test_mps.py`. Added the corresponding dispatch entries in `native_functions.yaml` and the backward and forward formulas in `derivatives.yaml`.
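For reference, a hypothetical usage sketch of the new op through the ATen C++ API (shapes and device placement chosen arbitrarily; not code from this PR):

```cpp
#include <ATen/ATen.h>

int main() {
  // A batch of four 5x5 matrices on the MPS device.
  at::Tensor A = at::randn({4, 5, 5},
                           at::TensorOptions().dtype(at::kFloat).device(at::kMPS));
  // Returns the packed LU factors and the pivot indices.
  auto [LU, pivots] = at::linalg_lu_factor(A, /*pivot=*/true);
  return 0;
}
```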