-
Couldn't load subscription status.
- Fork 25.7k
[MPS] Add lu_factor #99269
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MPS] Add lu_factor #99269
Changes from all commits
f7c9441
659be58
b9af2b9
6c7b9a6
f966d1a
de4efd1
76884ab
d75cde1
0f97fcc
de705f3
0bda969
ce8d9cc
cad764a
006ea34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,11 +17,15 @@ | |
| #include <ATen/ops/addr_native.h> | ||
| #include <ATen/ops/baddbmm_native.h> | ||
| #include <ATen/ops/bmm_native.h> | ||
| #include <ATen/ops/linalg_lu_factor_native.h> | ||
| #include <ATen/ops/linalg_solve_triangular_native.h> | ||
| #include <ATen/ops/mm_native.h> | ||
| #include <ATen/ops/stack.h> | ||
| #include <ATen/ops/triangular_solve_native.h> | ||
| #endif | ||
|
|
||
| #include <algorithm> | ||
|
|
||
| namespace at::native { | ||
| namespace mps { | ||
| namespace { | ||
|
|
@@ -127,6 +131,116 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) | |
|
|
||
| } // anonymous namespace | ||
|
|
||
// Computes the LU factorization (with partial pivoting) of `A` on the MPS backend,
// writing the packed L/U factors into `LU` and the 1-indexed pivot indices into `pivots`.
// Backed by MPSMatrixDecompositionLU; batched inputs are flattened and encoded one
// matrix at a time onto a single command buffer.
//
// Preconditions enforced here:
//   - A and LU must not be complex (MPS limitation).
//   - pivot must be true (MPS does not support factorization without pivoting).
static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
  using namespace mps;

  TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
              "linalg.lu_factor(): MPS doesn't support complex types.");
  TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False.");

  Tensor A_t = A;
  uint64_t aRows = A_t.size(-2);
  uint64_t aCols = A_t.size(-1);
  uint64_t aElemSize = A_t.element_size();
  // One pivot per row eliminated: min(rows, cols), matching LAPACK's getrf convention.
  uint64_t numPivots = std::min(aRows, aCols);
  // pivots has the batch shape of A with the last two (matrix) dims replaced by numPivots.
  std::vector<int64_t> pivot_sizes(A_t.sizes().begin(), A_t.sizes().end() - 2);
  pivot_sizes.push_back(numPivots);
  resize_output(pivots, pivot_sizes);

  // Empty input: outputs are already resized; nothing to factorize.
  if (A_t.numel() == 0) {
    return;
  }

  // Collapse all leading batch dims into one so the kernel loop below only
  // deals with a single (batch, rows, cols) layout.
  Tensor A_ = A_t.dim() > 3 ? A_t.flatten(0, -3) : A_t;

  uint64_t batchSize = A_.dim() > 2 ? A_.size(0) : 1;
  // Per-sample scratch tensors: MPSMatrixDecompositionLU reports a status and
  // pivot vector per encoded matrix. (See PR discussion: a single batched
  // buffer was attempted but per-matrix tensors proved necessary here.)
  std::vector<Tensor> status_tensors;
  std::vector<Tensor> pivots_list;

  status_tensors.reserve(batchSize);
  pivots_list.reserve(batchSize);
  for (C10_UNUSED const auto i : c10::irange(batchSize)) {
    status_tensors.push_back(at::zeros(1, kInt, c10::nullopt, kMPS, c10::nullopt));
    pivots_list.push_back(at::zeros(numPivots, kInt, c10::nullopt, kMPS, c10::nullopt));
  }

  // Since the MPSMatrixDecompositionLU functions in-place if the result matrix completely aliases the source matrix,
  // We copy LU from A as the new A.
  resize_output(LU, A_.sizes());
  if (!LU.is_same(A_)) {
    A_ = LU.copy_(A_);
  } else {
    A_ = LU;
  }

  // The buffer-offset arithmetic below assumes a densely packed row-major layout.
  TORCH_INTERNAL_ASSERT(A_.is_contiguous())

  id<MTLBuffer> aBuffer = getMTLBufferStorage(A_);

  MPSStream* mpsStream = getCurrentMPSStream();
  id<MTLDevice> device = MPSDevice::getInstance()->device();

  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
      MPSMatrixDecompositionLU* filter = [[[MPSMatrixDecompositionLU alloc] initWithDevice:device
                                                                                      rows:aRows
                                                                                   columns:aCols] autorelease];

      // Descriptor shared by source and result: same shape/stride since the op runs in place.
      MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
                                                                                    columns:aCols
                                                                                   matrices:batchSize
                                                                                   rowBytes:aCols * aElemSize
                                                                                matrixBytes:aRows * aCols * aElemSize
                                                                                   dataType:getMPSDataType(A_)];
      // MPS reports pivots as a 1 x numPivots row of uint32 per matrix.
      MPSMatrixDescriptor* pivotsMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:1
                                                                                    columns:numPivots
                                                                                   matrices:1
                                                                                   rowBytes:numPivots * sizeof(uint32_t)
                                                                                matrixBytes:numPivots * sizeof(uint32_t)
                                                                                   dataType:MPSDataTypeUInt32];

      for (const auto i : c10::irange(batchSize)) {
        // Byte offset of the i-th matrix within the (flattened) batch buffer.
        const uint64_t aBatchOffset = i * aRows * aCols;
        MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
                                                              offset:(A_.storage_offset() + aBatchOffset) * aElemSize
                                                          descriptor:sourceMatrixDesc] autorelease];
        MPSMatrix* pivotIndices = [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i])
                                                              offset:0
                                                          descriptor:pivotsMatrixDesc] autorelease];
        // Result aliases the source buffer exactly, triggering the in-place path.
        MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
                                                                offset:(A_.storage_offset() + aBatchOffset) * aElemSize
                                                            descriptor:sourceMatrixDesc] autorelease];
        id<MTLBuffer> statusBuffer = getMTLBufferStorage(status_tensors[i]);
        [filter encodeToCommandBuffer:commandBuffer
                         sourceMatrix:sourceMatrix
                         resultMatrix:solutionMatrix
                         pivotIndices:pivotIndices
                               status:statusBuffer];
      }
    }
  });
  // Reassemble per-sample pivot vectors into the caller-visible batch shape.
  auto stacked_pivots = A_.dim() > 2 ? at::stack(pivots_list) : pivots_list[0];
  if (A_t.dim() > 3) {
    // Restore the original (un-flattened) batch dims on both outputs.
    resize_output(LU, A_t.sizes());
    pivots.copy_(stacked_pivots.view(pivot_sizes));
  } else {
    pivots.copy_(stacked_pivots);
  }
  pivots += 1; // PyTorch's `pivots` is 1-index.

  // Surface any per-sample decomposition failure. Note: .item<int>() forces a
  // GPU sync per sample; status values are documented by
  // MPSMatrixDecompositionStatus (0 == success).
  for (const auto i : c10::irange(status_tensors.size())) {
    int status = status_tensors[i].item<int>();
    TORCH_CHECK(
        status == 0,
        "lu_factor(): LU factorization failure at the ",
        i + 1,
        " sample with status: ",
        status,
        ". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
  }
}
|
|
||
| static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) { | ||
| using namespace mps; | ||
| using CachedGraph = MPSBinaryCachedGraph; | ||
|
|
@@ -753,4 +867,16 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, | |
| result.copy_(out); | ||
| } | ||
|
|
||
// Out-of-place variant entry point registered for the MPS dispatch key:
// runs the shared implementation, then returns the caller's own output
// tensors by reference, as the out= convention requires.
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
  mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots);
  return std::tuple<Tensor&, Tensor&>{LU, pivots};
}
|
|
||
// Functional variant entry point registered for the MPS dispatch key.
// Outputs start as empty placeholders; the shared implementation resizes
// them to the correct shapes (pivots is integer-typed per the LAPACK contract).
std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
  auto LU = at::empty({0}, A.options());
  auto pivots = at::empty({0}, A.options().dtype(kInt));
  mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots);
  return {std::move(LU), std::move(pivots)};
}
|
|
||
| } // namespace at::native | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13797,10 +13797,16 @@ | |
| - func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) | ||
| python_module: linalg | ||
| variants: function | ||
| dispatch: | ||
| CompositeImplicitAutograd: linalg_lu_factor | ||
| MPS: linalg_lu_factor_mps | ||
|
|
||
| - func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) | ||
| python_module: linalg | ||
| variants: function | ||
| dispatch: | ||
| CompositeImplicitAutograd: linalg_lu_factor_out | ||
| MPS: linalg_lu_factor_out_mps | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you implement There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason I implement There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess I cannot directly return an undefined Tensor |
||
|
|
||
| - func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) | ||
| python_module: linalg | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.