From f7c9441bf2746f2e07c88e566b172577ecf17d51 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Mon, 17 Apr 2023 01:36:51 +0800 Subject: [PATCH 1/3] [MPS] Add lu_solve, lu_factor [ghstack-poisoned] --- .../native/mps/operations/LinearAlgebra.mm | 353 ++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 11 + test/test_mps.py | 4 +- .../_internal/opinfo/definitions/linalg.py | 2 +- 4 files changed, 367 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 37625f4ebef8..8366b2901503 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -14,8 +14,12 @@ #include #include #include +#include +#include #endif +#include + namespace at::native { namespace mps { /* @@ -91,6 +95,214 @@ void prepare_matrices_for_broadcasting(const Tensor* bias, enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE }; +std::vector linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) { + using namespace mps; + + Tensor A_t = A; + uint64_t aRows = A_t.size(-2); + uint64_t aCols = A_t.size(-1); + uint64_t aElemSize = A_t.element_size(); + uint64_t numPivots = std::min(aRows, aCols); + uint64_t batchSize = A_t.sizes().size() > 2 ? A_t.size(0) : 1; + resize_output(LU, A_t.sizes()); + + auto pivot_sizes = A_t.sizes().vec(); + pivot_sizes.pop_back(); + pivot_sizes[-1] = numPivots; + resize_output(pivots, pivot_sizes); + + std::vector status_tensors; + std::vector pivots_list; + + status_tensors.reserve(batchSize); + pivots_list.reserve(batchSize); + for (const auto i : c10::irange(batchSize)) { + status_tensors.push_back(at::zeros(1, + kInt, + c10::nullopt, + kMPS, + c10::nullopt)); + pivots_list.push_back(at::zeros(numPivots, + kInt, + c10::nullopt, + kMPS, + c10::nullopt)); + } + + std::cout << "pivots size" << pivot_sizes << "acols" << aCols << "aRows" << aRows << "aelesize" << aElemSize << std::endl; + if (A_t.numel() == 0 || LU.numel() == 0) { + LU.zero_(); + return status_tensors; + } + + Tensor A_ = A_t; + if (!A_t.is_contiguous()) { + A_ = A_t.clone(at::MemoryFormat::Contiguous); + } + // Since the LUDecomposition functions in-place if the result matrix completely aliases the source matrix, + // We copy LU from A as the new A. + A_ = LU.copy_(A_); + + id aBuffer = getMTLBufferStorage(A_); + //id pivotsBuffer = getMTLBufferStorage(pivots); + + MPSStream* mpsStream = getCurrentMPSStream(); + id device = MPSDevice::getInstance()->device(); + + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id commandBuffer = mpsStream->commandBuffer(); + MPSMatrixDecompositionLU* filter = [[[MPSMatrixDecompositionLU alloc] initWithDevice:device + rows: aRows + columns: aCols] autorelease]; + + MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows + columns:aCols + matrices:batchSize + rowBytes:aCols * aElemSize + matrixBytes:aRows * aCols * aElemSize + dataType:getMPSDataType(A_)]; + MPSMatrixDescriptor* pivotsMatrixDesc = + [MPSMatrixDescriptor matrixDescriptorWithRows:1 + columns:numPivots + matrices:batchSize + rowBytes:numPivots * sizeof(uint32_t) + matrixBytes:numPivots * sizeof(uint32_t) + dataType:MPSDataTypeUInt32]; + + for (const auto i : c10::irange(batchSize)) { + const uint64_t aBatchOffset = i * aRows * aCols; + MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer + offset:(A_.storage_offset() + aBatchOffset) * aElemSize + descriptor:sourceMatrixDesc] autorelease]; + MPSMatrix* pivotIndices = + [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i]) + offset:0 + descriptor:pivotsMatrixDesc] autorelease]; + MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer + offset:(A_.storage_offset() + aBatchOffset) * aElemSize + descriptor:sourceMatrixDesc] autorelease]; + id statusBuffer = getMTLBufferStorage(status_tensors[i]); + [filter encodeToCommandBuffer:commandBuffer + sourceMatrix:sourceMatrix + resultMatrix:solutionMatrix + pivotIndices:pivotIndices + status:statusBuffer + ]; + } + mpsStream->commit(true); + } + }); + auto stacked_pivots = A_.dim() > 2 ? at::stack(pivots_list) : pivots_list[0]; + pivots.copy_(stacked_pivots); + pivots += 1; // PyTorch's `pivots` is 1-index. + + return status_tensors; +} + +void linalg_lu_solve_out_mps_impl(const Tensor& LU, + const Tensor& pivots, + const Tensor& B, + bool left, + bool adjoint, + const Tensor& out){ + using namespace mps; + + if (LU.numel() == 0 || B.numel() == 0 || out.numel() == 0) { + out.zero_(); + return; + } + + Tensor A_ = LU.is_contiguous() ? LU : LU.clone(at::MemoryFormat::Contiguous); + Tensor B_ = B.is_contiguous() ? B : B.clone(at::MemoryFormat::Contiguous); + Tensor out_ = out.is_contiguous() ? out : out.clone(at::MemoryFormat::Contiguous); + Tensor pivots_ = pivots.is_contiguous() ? pivots : pivots.clone(at::MemoryFormat::Contiguous); + + std::cout << "Out:" << out_ << std::endl; + std::cout << "A:" << A_ << std::endl; + std::cout << "B:" << B_ << std::endl; + id aBuffer = getMTLBufferStorage(A_); + id bBuffer = getMTLBufferStorage(B_); + id outBuffer = getMTLBufferStorage(out_); + id pivotIndicesBuffer = getMTLBufferStorage(pivots_); + MPSStream* mpsStream = getCurrentMPSStream(); + id device = MPSDevice::getInstance()->device(); + + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id commandBuffer = mpsStream->commandBuffer(); + uint64_t batchSize = A_.sizes().size() > 2 ? A_.size(0) : 1; + uint64_t aRows = A_.size(-2); + uint64_t bRows = B_.size(-2); + uint64_t aCols = A_.size(-1); + uint64_t bCols = B_.size(-1); + uint64_t aElemSize = A_.element_size(); + uint64_t bElemSize = B_.element_size(); + uint64_t numPivots = pivots_.size(-1); + + MPSMatrixSolveLU* filter = [[[MPSMatrixSolveLU alloc] initWithDevice:device + transpose:adjoint ? true : false + order:left ? bRows : bCols + numberOfRightHandSides:left ? bCols : bRows] autorelease]; + + MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows + columns:aCols + matrices:batchSize + rowBytes:aCols * aElemSize + matrixBytes:aRows * aCols * aElemSize + dataType:getMPSDataType(A_)]; + MPSMatrixDescriptor* rightHandSideMatrixDesc = + [MPSMatrixDescriptor matrixDescriptorWithRows:bRows + columns:bCols + matrices:batchSize + rowBytes:bCols * bElemSize + matrixBytes:bRows * bCols * bElemSize + dataType:getMPSDataType(B_)]; + MPSMatrixDescriptor* pivotsMatrixDesc = + [MPSMatrixDescriptor matrixDescriptorWithRows:1 + columns:numPivots + matrices:batchSize + rowBytes:numPivots * sizeof(uint32_t) + matrixBytes:numPivots * sizeof(uint32_t) + dataType:MPSDataTypeUInt32]; + for (const auto i : c10::irange(batchSize)) { + const uint64_t aBatchOffset = i * aRows * aCols; + const uint64_t bBatchOffset = i * bRows * bCols; + const uint64_t pivotsBatchOffset = i * numPivots; + std::cout << "aBatchOffset " << aBatchOffset << "bBatchOffset " << bBatchOffset << "pivots " << pivotsBatchOffset << std::endl; + + MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer + offset:(A_.storage_offset() + aBatchOffset) * aElemSize + descriptor:sourceMatrixDesc] autorelease]; + MPSMatrix* rightHandSideMatrix = + [[[MPSMatrix alloc] initWithBuffer:bBuffer + offset:(B_.storage_offset() + bBatchOffset) * bElemSize + descriptor:rightHandSideMatrixDesc] autorelease]; + MPSMatrix* pivotIndices = + [[[MPSMatrix alloc] initWithBuffer:pivotIndicesBuffer + offset:(pivots_.storage_offset() + pivotsBatchOffset) * sizeof(uint32_t) + descriptor:pivotsMatrixDesc] autorelease]; + MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer + offset:(out_.storage_offset() + bBatchOffset) * bElemSize + descriptor:rightHandSideMatrixDesc] autorelease]; + + [filter encodeToCommandBuffer:commandBuffer + sourceMatrix:sourceMatrix + rightHandSideMatrix:rightHandSideMatrix + pivotIndices:pivotIndices + solutionMatrix:solutionMatrix]; + } + mpsStream->commit(true); + } + }); + //if (!out.is_contiguous()) { + out.copy_(out_); + std::cout << "Post A:" << A_ << std::endl; + std::cout << "Post B:" << B_ << std::endl; + std::cout << "Post out:" << out_ << std::endl; + //} +} + Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) { using namespace mps; using CachedGraph = MPSBinaryCachedGraph; @@ -626,6 +838,96 @@ void prepare_matrices_for_broadcasting(const Tensor* bias, }); return out; } +/* +Tensor& linalg_solve_out_mps_impl(const Tensor& A, + const Tensor& B, + bool left, + const Tensor& result) { + using namespace mps; + + checkInputsSolver(A, B, left, "linalg.solve"); + Tensor A_t, B_t; + std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, nullptr); + at::native::resize_output(out, B_t.sizes()); + + if (A.numel() == 0 || B.numel() == 0 || out.numel() == 0) { + out.zero_(); + return out; + } + + Tensor A_ = A_t; + Tensor B_ = B_t; + if (!A_t.is_contiguous()) { + A_ = A_t.clone(at::MemoryFormat::Contiguous); + } + if (!B_t.is_contiguous()) { + B_ = B_t.clone(at::MemoryFormat::Contiguous); + } + Tensor + id aBuffer = getMTLBufferStorage(A_); + id bBuffer = getMTLBufferStorage(B_); + id outBuffer = getMTLBufferStorage(out); + MPSStream* mpsStream = getCurrentMPSStream(); + id device = MPSDevice::getInstance()->device(); + + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + id commandBuffer = mpsStream->commandBuffer(); + uint64_t batchSize = A_.sizes().size() > 2 ? A_.size(0) : 1; + uint64_t aRows = A_.size(-2); + uint64_t bRows = B_.size(-2); + uint64_t aCols = A_.size(-1); + uint64_t bCols = B_.size(-1); + uint64_t aElemSize = A_.element_size(); + uint64_t bElemSize = B_.element_size(); + + MPSMatrixSolveLU* filter = [[[MPSMatrixSolveLU alloc] initWithDevice:device + transpose:false + order:left ? bRows : bCols + numberOfRightHandSides:left ? bCols : bRows] autorelease]; + + MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows + columns:aCols + matrices:batchSize + rowBytes:aCols * aElemSize + matrixBytes:aRows * aCols * aElemSize + dataType:getMPSDataType(A_)]; + MPSMatrixDescriptor* rightHandSideMatrixDesc = + [MPSMatrixDescriptor matrixDescriptorWithRows:bRows + columns:bCols + matrices:batchSize + rowBytes:bCols * bElemSize + matrixBytes:bRows * bCols * bElemSize + dataType:getMPSDataType(B_)]; + for (const auto i : c10::irange(batchSize)) { + const uint64_t aBatchOffset = i * aRows * aCols; + const uint64_t bBatchOffset = i * bRows * bCols; + MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer + offset:(A_t.storage_offset() + aBatchOffset) * aElemSize + descriptor:sourceMatrixDesc] autorelease]; + MPSMatrix* rightHandSideMatrix = + [[[MPSMatrix alloc] initWithBuffer:bBuffer + offset:(B_t.storage_offset() + bBatchOffset) * bElemSize + descriptor:rightHandSideMatrixDesc] autorelease]; + MPSMatrix* pivotIndices = + [[[MPSMatrix alloc] initWithBuffer:pivotIndicesBuffer + offset:(B_t.storage_offset() + bBatchOffset) * bElemSize + descriptor:rightHandSideMatrixDesc] autorelease]; + MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer + offset:(out.storage_offset() + bBatchOffset) * bElemSize + descriptor:rightHandSideMatrixDesc] autorelease]; + + [filter encodeToCommandBuffer:commandBuffer + sourceMatrix:sourceMatrix + rightHandSideMatrix:rightHandSideMatrix + pivotIndices: + solutionMatrix:solutionMatrix]; + } + mpsStream->commit(true); + } + }); + return out; +}*/ } // namespace mps @@ -853,5 +1155,56 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, result.resize_(out.sizes()); result.copy_(out); } +/* +TORCH_IMPL_FUNC(linalg_lu_solve_out_mps)(const Tensor& LU, + const Tensor& pivots, + const Tensor& B, + bool left, + bool adjoint, + const Tensor& result) { + +}*/ + +std::tuple linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) { + auto status_tensors = mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots); + for (const auto i : c10::irange(status_tensors.size())) { + TORCH_CHECK(status_tensors[i].item() == 0, "lu_factor(): LU factorization failure:", status_tensors[i].item()); + } + return std::tie(LU, pivots); +} +std::tuple linalg_lu_factor_mps(const Tensor& A, bool pivot) { + Tensor LU = at::empty({0}, A.options()); + Tensor pivots = at::empty({0}, A.options().dtype(kInt)); + auto status_tensors = mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots); + for (const auto i : c10::irange(status_tensors.size())) { + TORCH_CHECK(status_tensors[i].item() == 0, "lu_factor(): LU factorization failure:", status_tensors[i].item()); + } + return std::make_tuple(std::move(LU), std::move(pivots)); +} + +TORCH_IMPL_FUNC(linalg_lu_solve_out_mps)(const Tensor& LU, + const Tensor& pivots, + const Tensor& B, + bool left, + bool adjoint, + const Tensor& result) { + // Trivial case + if (result.numel() == 0) { + return; + } + + // Solve A^H X = B^H. Then we return X^H + if (!left) { + adjoint = !adjoint; + result.transpose_(-2, -1); + } + + // Copy B (or B^H) into result + //if (!result.is_same(B)) { + // result.copy_(left ? B : B.mH()); + //} + + mps::linalg_lu_solve_out_mps_impl(LU, pivots, B, left, adjoint, result); +} } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f60165ae4d68..4f7b419211f0 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13075,10 +13075,16 @@ - func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) python_module: linalg variants: function + dispatch: + CompositeImplicitAutograd: linalg_lu_factor + MPS: linalg_lu_factor_mps - func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) python_module: linalg variants: function + dispatch: + CompositeImplicitAutograd: linalg_lu_factor_out + MPS: linalg_lu_factor_out_mps - func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) python_module: linalg @@ -13117,6 +13123,7 @@ structured: True dispatch: CPU, CUDA: linalg_lu_solve_out + MPS: linalg_lu_solve_out_mps # linalg.det - func: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) @@ -13455,9 +13462,13 @@ - func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor python_module: linalg + #dispatch: + # MPS: linalg_solve_mps - func: linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!) python_module: linalg + #dispatch: + # MPS: linalg_solve_out_mps - func: linalg_tensorinv(Tensor self, int ind=2) -> Tensor python_module: linalg diff --git a/test/test_mps.py b/test/test_mps.py index 0ad5acab5851..105a684ee9d6 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -450,7 +450,7 @@ def mps_ops_modifier(ops): 'linalg.lu': None, 'linalg.lu_factor': None, 'linalg.lu_factor_ex': None, - 'linalg.lu_solve': None, + #'linalg.lu_solve': None, 'linalg.matrix_norm': [torch.float32], 'linalg.norm': [torch.float32], 'linalg.normsubgradients_at_zero': [torch.float32], @@ -465,7 +465,7 @@ def mps_ops_modifier(ops): 'logcumsumexp': None, 'logdet': None, 'lu': None, - 'lu_solve': None, + #'lu_solve': None, 'lu_unpack': None, 'masked.cumprod': None, 'masked.median': None, diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index b8ace02b4a35..27e3b1cc2393 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -279,7 +279,7 @@ def clone(X, requires_grad): is_linalg_lu_solve = op_info.name == "linalg.lu_solve" - batches = ((), (0,), (2,)) + batches = ((),) #(0,), (2,)) ns = (3, 1, 0) nrhs = (4, 1, 0) From 659be58e4885937f929df0fbb88ba00064fd0330 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Mon, 17 Apr 2023 12:09:19 +0800 Subject: [PATCH 2/3] Update on "[MPS] Add lu_solve, lu_factor" [ghstack-poisoned] --- .../native/mps/operations/LinearAlgebra.mm | 109 ++++++++---------- 1 file changed, 49 insertions(+), 60 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 8366b2901503..2c4a46edb15b 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -12,10 +12,10 @@ #include #include #include -#include -#include #include +#include #include +#include #endif #include @@ -116,20 +116,13 @@ void prepare_matrices_for_broadcasting(const Tensor* bias, status_tensors.reserve(batchSize); pivots_list.reserve(batchSize); - for (const auto i : c10::irange(batchSize)) { - status_tensors.push_back(at::zeros(1, - kInt, - c10::nullopt, - kMPS, - c10::nullopt)); - pivots_list.push_back(at::zeros(numPivots, - kInt, - c10::nullopt, - kMPS, - c10::nullopt)); + for (C10_UNUSED const auto i : c10::irange(batchSize)) { + status_tensors.push_back(at::zeros(1, kInt, c10::nullopt, kMPS, c10::nullopt)); + pivots_list.push_back(at::zeros(numPivots, kInt, c10::nullopt, kMPS, c10::nullopt)); } - std::cout << "pivots size" << pivot_sizes << "acols" << aCols << "aRows" << aRows << "aelesize" << aElemSize << std::endl; + std::cout << "pivots size" << pivot_sizes << "acols" << aCols << "aRows" << aRows << "aelesize" << aElemSize + << std::endl; if (A_t.numel() == 0 || LU.numel() == 0) { LU.zero_(); return status_tensors; @@ -144,8 +137,8 @@ void prepare_matrices_for_broadcasting(const Tensor* bias, A_ = LU.copy_(A_); id aBuffer = getMTLBufferStorage(A_); - //id pivotsBuffer = getMTLBufferStorage(pivots); - + // id pivotsBuffer = getMTLBufferStorage(pivots); + MPSStream* mpsStream = getCurrentMPSStream(); id device = MPSDevice::getInstance()->device(); @@ -153,8 +146,8 @@ void prepare_matrices_for_broadcasting(const Tensor* bias, @autoreleasepool { id commandBuffer = mpsStream->commandBuffer(); MPSMatrixDecompositionLU* filter = [[[MPSMatrixDecompositionLU alloc] initWithDevice:device - rows: aRows - columns: aCols] autorelease]; + rows:aRows + columns:aCols] autorelease]; MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows columns:aCols @@ -162,23 +155,21 @@ void prepare_matrices_for_broadcasting(const Tensor* bias, rowBytes:aCols * aElemSize matrixBytes:aRows * aCols * aElemSize dataType:getMPSDataType(A_)]; - MPSMatrixDescriptor* pivotsMatrixDesc = - [MPSMatrixDescriptor matrixDescriptorWithRows:1 - columns:numPivots - matrices:batchSize - rowBytes:numPivots * sizeof(uint32_t) - matrixBytes:numPivots * sizeof(uint32_t) - dataType:MPSDataTypeUInt32]; - + MPSMatrixDescriptor* pivotsMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:1 + columns:numPivots + matrices:batchSize + rowBytes:numPivots * sizeof(uint32_t) + matrixBytes:numPivots * sizeof(uint32_t) + dataType:MPSDataTypeUInt32]; + for (const auto i : c10::irange(batchSize)) { const uint64_t aBatchOffset = i * aRows * aCols; MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer offset:(A_.storage_offset() + aBatchOffset) * aElemSize descriptor:sourceMatrixDesc] autorelease]; - MPSMatrix* pivotIndices = - [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i]) - offset:0 - descriptor:pivotsMatrixDesc] autorelease]; + MPSMatrix* pivotIndices = [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i]) + offset:0 + descriptor:pivotsMatrixDesc] autorelease]; MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer offset:(A_.storage_offset() + aBatchOffset) * aElemSize descriptor:sourceMatrixDesc] autorelease]; @@ -187,8 +178,7 @@ void prepare_matrices_for_broadcasting(const Tensor* bias, sourceMatrix:sourceMatrix resultMatrix:solutionMatrix pivotIndices:pivotIndices - status:statusBuffer - ]; + status:statusBuffer]; } mpsStream->commit(true); } @@ -201,11 +191,11 @@ void prepare_matrices_for_broadcasting(const Tensor* bias, } void linalg_lu_solve_out_mps_impl(const Tensor& LU, - const Tensor& pivots, - const Tensor& B, - bool left, - bool adjoint, - const Tensor& out){ + const Tensor& pivots, + const Tensor& B, + bool left, + bool adjoint, + const Tensor& out) { using namespace mps; if (LU.numel() == 0 || B.numel() == 0 || out.numel() == 0) { @@ -258,18 +248,18 @@ void linalg_lu_solve_out_mps_impl(const Tensor& LU, rowBytes:bCols * bElemSize matrixBytes:bRows * bCols * bElemSize dataType:getMPSDataType(B_)]; - MPSMatrixDescriptor* pivotsMatrixDesc = - [MPSMatrixDescriptor matrixDescriptorWithRows:1 - columns:numPivots - matrices:batchSize - rowBytes:numPivots * sizeof(uint32_t) - matrixBytes:numPivots * sizeof(uint32_t) - dataType:MPSDataTypeUInt32]; + MPSMatrixDescriptor* pivotsMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:1 + columns:numPivots + matrices:batchSize + rowBytes:numPivots * sizeof(uint32_t) + matrixBytes:numPivots * sizeof(uint32_t) + dataType:MPSDataTypeUInt32]; for (const auto i : c10::irange(batchSize)) { const uint64_t aBatchOffset = i * aRows * aCols; const uint64_t bBatchOffset = i * bRows * bCols; const uint64_t pivotsBatchOffset = i * numPivots; - std::cout << "aBatchOffset " << aBatchOffset << "bBatchOffset " << bBatchOffset << "pivots " << pivotsBatchOffset << std::endl; + std::cout << "aBatchOffset " << aBatchOffset << "bBatchOffset " << bBatchOffset << "pivots " + << pivotsBatchOffset << std::endl; MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer offset:(A_.storage_offset() + aBatchOffset) * aElemSize @@ -282,9 +272,10 @@ void linalg_lu_solve_out_mps_impl(const Tensor& LU, [[[MPSMatrix alloc] initWithBuffer:pivotIndicesBuffer offset:(pivots_.storage_offset() + pivotsBatchOffset) * sizeof(uint32_t) descriptor:pivotsMatrixDesc] autorelease]; - MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer - offset:(out_.storage_offset() + bBatchOffset) * bElemSize - descriptor:rightHandSideMatrixDesc] autorelease]; + MPSMatrix* solutionMatrix = + [[[MPSMatrix alloc] initWithBuffer:outBuffer + offset:(out_.storage_offset() + bBatchOffset) * bElemSize + descriptor:rightHandSideMatrixDesc] autorelease]; [filter encodeToCommandBuffer:commandBuffer sourceMatrix:sourceMatrix @@ -295,7 +286,7 @@ void linalg_lu_solve_out_mps_impl(const Tensor& LU, mpsStream->commit(true); } }); - //if (!out.is_contiguous()) { + // if (!out.is_contiguous()) { out.copy_(out_); std::cout << "Post A:" << A_ << std::endl; std::cout << "Post B:" << B_ << std::endl; @@ -863,7 +854,7 @@ void linalg_lu_solve_out_mps_impl(const Tensor& LU, if (!B_t.is_contiguous()) { B_ = B_t.clone(at::MemoryFormat::Contiguous); } - Tensor + Tensor id aBuffer = getMTLBufferStorage(A_); id bBuffer = getMTLBufferStorage(B_); id outBuffer = getMTLBufferStorage(out); @@ -1168,7 +1159,8 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, std::tuple linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) { auto status_tensors = mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots); for (const auto i : c10::irange(status_tensors.size())) { - TORCH_CHECK(status_tensors[i].item() == 0, "lu_factor(): LU factorization failure:", status_tensors[i].item()); + TORCH_CHECK( + status_tensors[i].item() == 0, "lu_factor(): LU factorization failure:", status_tensors[i].item()); } return std::tie(LU, pivots); } @@ -1177,17 +1169,14 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, Tensor pivots = at::empty({0}, A.options().dtype(kInt)); auto status_tensors = mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots); for (const auto i : c10::irange(status_tensors.size())) { - TORCH_CHECK(status_tensors[i].item() == 0, "lu_factor(): LU factorization failure:", status_tensors[i].item()); + TORCH_CHECK( + status_tensors[i].item() == 0, "lu_factor(): LU factorization failure:", status_tensors[i].item()); } return std::make_tuple(std::move(LU), std::move(pivots)); } -TORCH_IMPL_FUNC(linalg_lu_solve_out_mps)(const Tensor& LU, - const Tensor& pivots, - const Tensor& B, - bool left, - bool adjoint, - const Tensor& result) { +TORCH_IMPL_FUNC(linalg_lu_solve_out_mps) +(const Tensor& LU, const Tensor& pivots, const Tensor& B, bool left, bool adjoint, const Tensor& result) { // Trivial case if (result.numel() == 0) { return; @@ -1198,9 +1187,9 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, adjoint = !adjoint; result.transpose_(-2, -1); } - + // Copy B (or B^H) into result - //if (!result.is_same(B)) { + // if (!result.is_same(B)) { // result.copy_(left ? B : B.mH()); //} From 006ea34c2c937547590a47ace98797690ae27533 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Wed, 19 Jun 2024 21:07:23 -0700 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- aten/src/ATen/native/mps/operations/LinearAlgebra.mm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index c49b55aab225..dfe0605add92 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -151,7 +151,7 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L return; } - Tensor A_ = A_t.dim() > 3 ? A_t.view({-1, A_t.size(-2), A_t.size(-1)}) : A_t; + Tensor A_ = A_t.dim() > 3 ? A_t.flatten(0, -3) : A_t; uint64_t batchSize = A_.dim() > 2 ? A_.size(0) : 1; std::vector status_tensors;