126 changes: 126 additions & 0 deletions aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -17,11 +17,15 @@
#include <ATen/ops/addr_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/linalg_lu_factor_native.h>
#include <ATen/ops/linalg_solve_triangular_native.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/stack.h>
#include <ATen/ops/triangular_solve_native.h>
#endif

#include <algorithm>

namespace at::native {
namespace mps {
namespace {
@@ -127,6 +131,116 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output)

} // anonymous namespace

static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
using namespace mps;

TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
"linalg.lu_factor(): MPS doesn't support complex types.");
TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False.");

Tensor A_t = A;
uint64_t aRows = A_t.size(-2);
uint64_t aCols = A_t.size(-1);
uint64_t aElemSize = A_t.element_size();
uint64_t numPivots = std::min(aRows, aCols);
std::vector<int64_t> pivot_sizes(A_t.sizes().begin(), A_t.sizes().end() - 2);
pivot_sizes.push_back(numPivots);
resize_output(pivots, pivot_sizes);

if (A_t.numel() == 0) {
return;
}

Tensor A_ = A_t.dim() > 3 ? A_t.flatten(0, -3) : A_t;

uint64_t batchSize = A_.dim() > 2 ? A_.size(0) : 1;
std::vector<Tensor> status_tensors;
std::vector<Tensor> pivots_list;

status_tensors.reserve(batchSize);
pivots_list.reserve(batchSize);
for (C10_UNUSED const auto i : c10::irange(batchSize)) {
status_tensors.push_back(at::zeros(1, kInt, c10::nullopt, kMPS, c10::nullopt));
pivots_list.push_back(at::zeros(numPivots, kInt, c10::nullopt, kMPS, c10::nullopt));
}
Review thread on lines +157 to +165:

Collaborator: Why not just a tensor with batchSize elements in the case of status_tensors? Same for pivots_list.

Author (@qqaatw): For pivots_list: if we use the same MTLBuffer (the underlying storage of an MPS tensor) with a per-matrix offset for each pivot matrix, the resulting pivot values come out incorrect. I don't know exactly why, as the MPS kernel is closed source; it is probably because MPSMatrixDecompositionLU operates in-place when the result matrix completely aliases the source matrix, per the docs.

For status_tensors: for each system, the kernel requires the status to be encoded as an MTLBuffer input, which provides no option for specifying an offset into the buffer. The only approach I could come up with was splitting the status into multiple tensors, each with its own MTLBuffer. Maybe there is a better approach.

Collaborator (@lezcano, May 2, 2023): I do not know enough about MPS, but perhaps @kulinseth can comment on what's the best way to do this. It'd be good if pivots could be a Tensor; that way we wouldn't need to first create a vector and then copy it into the output tensor.

For status_tensors, I don't know why we can't simply return the status tensors, same as we do for the pivots. This would allow us to implement the _ex variant, which is the expected way of implementing this function.

Collaborator: @qqaatw, you can use the aliasing strategy MPSAliasingStrategyShallNotAlias and provide an offset using the arrayView.

So start by using MPSNDArray to create the object:

  arrayView = [ndArray arrayViewWithCommandBuffer:commandBuffer
                                       descriptor:desc
                                         aliasing:MPSAliasingStrategyShallNotAlias];

and then convert the MPSNDArray to an MPSMatrix to be passed to the LU solve:

  [[MPSMatrix alloc] initWithBuffer:[ndArray buffer]
                             offset:offset
                         descriptor:matDesc];

Author (@qqaatw): Sorry, should it be initialized with [ndArray buffer] or [arrayView buffer]? I guess the latter is what you were suggesting.

With the pivot matrix initialized from [arrayView buffer] using MPSAliasingStrategyShallNotAlias, the pivots output remains the same as the pivots before the LU decomposition, which suggests the array view is not writable with this strategy. On the other hand, if I specify MPSAliasingStrategyShallAlias, the output is correct with unbatched inputs.

I also tried initializing with [ndArray buffer]; the outputs were incorrect if the inputs were batched.

// MPSMatrixDecompositionLU factorizes in-place when the result matrix completely aliases the source matrix,
// so we copy A into LU and use LU as the new A.
resize_output(LU, A_.sizes());
if (!LU.is_same(A_)) {
A_ = LU.copy_(A_);
} else {
A_ = LU;
}

TORCH_INTERNAL_ASSERT(A_.is_contiguous());

id<MTLBuffer> aBuffer = getMTLBufferStorage(A_);

MPSStream* mpsStream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();

dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
MPSMatrixDecompositionLU* filter = [[[MPSMatrixDecompositionLU alloc] initWithDevice:device
rows:aRows
columns:aCols] autorelease];

MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
columns:aCols
matrices:batchSize
rowBytes:aCols * aElemSize
matrixBytes:aRows * aCols * aElemSize
dataType:getMPSDataType(A_)];
MPSMatrixDescriptor* pivotsMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:1
columns:numPivots
matrices:1
rowBytes:numPivots * sizeof(uint32_t)
matrixBytes:numPivots * sizeof(uint32_t)
dataType:MPSDataTypeUInt32];

for (const auto i : c10::irange(batchSize)) {
const uint64_t aBatchOffset = i * aRows * aCols;
MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
offset:(A_.storage_offset() + aBatchOffset) * aElemSize
descriptor:sourceMatrixDesc] autorelease];
MPSMatrix* pivotIndices = [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i])
offset:0
descriptor:pivotsMatrixDesc] autorelease];
MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
offset:(A_.storage_offset() + aBatchOffset) * aElemSize
descriptor:sourceMatrixDesc] autorelease];
id<MTLBuffer> statusBuffer = getMTLBufferStorage(status_tensors[i]);
[filter encodeToCommandBuffer:commandBuffer
sourceMatrix:sourceMatrix
resultMatrix:solutionMatrix
pivotIndices:pivotIndices
status:statusBuffer];
}
}
});
auto stacked_pivots = A_.dim() > 2 ? at::stack(pivots_list) : pivots_list[0];
if (A_t.dim() > 3) {
resize_output(LU, A_t.sizes());
pivots.copy_(stacked_pivots.view(pivot_sizes));
} else {
pivots.copy_(stacked_pivots);
}
pivots += 1; // PyTorch's `pivots` is 1-indexed.

for (const auto i : c10::irange(status_tensors.size())) {
int status = status_tensors[i].item<int>();
TORCH_CHECK(
status == 0,
"lu_factor(): LU factorization failure at the ",
i + 1,
" sample with status: ",
status,
". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
}
}
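The `pivots += 1` above shifts the kernel's pivot indices to the 1-indexed, LAPACK-style convention that torch.linalg.lu_factor documents. A minimal Python sketch of what those pivots mean; the helper name pivots_to_permutation is illustrative and not part of this PR. Applying the swaps in order yields the row permutation perm with A[perm] == L @ U.

import torch

def pivots_to_permutation(pivots, m):
    # pivots[i] == p means: row i was swapped with row p - 1 (1-indexed, LAPACK-style).
    perm = list(range(m))
    for i, p in enumerate(pivots.tolist()):
        j = p - 1
        perm[i], perm[j] = perm[j], perm[i]
    return torch.tensor(perm)

A = torch.randn(4, 4)
LU, pivots = torch.linalg.lu_factor(A)
perm = pivots_to_permutation(pivots, A.size(-2))
L = torch.tril(LU, -1) + torch.eye(4)   # unit-diagonal lower factor
U = torch.triu(LU)                      # upper factor
torch.testing.assert_close(A[perm], L @ U)  # row-permuted A equals L @ U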

static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
using namespace mps;
using CachedGraph = MPSBinaryCachedGraph;
@@ -753,4 +867,16 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper,
result.copy_(out);
}

std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots);
return std::tie(LU, pivots);
}

std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
Tensor LU = at::empty({0}, A.options());
Tensor pivots = at::empty({0}, A.options().dtype(kInt));
mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots);
return std::make_tuple(std::move(LU), std::move(pivots));
}

} // namespace at::native
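Taken together, the kernels above give torch.linalg.lu_factor a native MPS path. A short usage sketch from Python; the shapes are illustrative, and the CPU cross-check assumes both backends make the same pivot choices, which holds for typical well-conditioned inputs.

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

A = torch.randn(3, 4, 4, device=device)   # batched input
LU, pivots = torch.linalg.lu_factor(A)    # dispatches to linalg_lu_factor_mps on MPS

print(LU.shape)      # torch.Size([3, 4, 4])
print(pivots.shape)  # torch.Size([3, 4]) -- min(rows, cols) pivots per matrix
print(pivots.dtype)  # torch.int32, 1-indexed

# Cross-check against the CPU implementation (assumes matching pivot choices).
LU_cpu, pivots_cpu = torch.linalg.lu_factor(A.cpu())
torch.testing.assert_close(LU.cpu(), LU_cpu)
torch.testing.assert_close(pivots.cpu(), pivots_cpu)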
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -13797,10 +13797,16 @@
- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor
MPS: linalg_lu_factor_mps

- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor_out
MPS: linalg_lu_factor_out_mps
Review thread on the linalg_lu_factor.out dispatch:

Collaborator: Could you implement linalg_lu_factor_ex rather than this one? That way you wouldn't need to add any new backward rule, and all the other goodies that are implemented for linalg.lu_factor would also extend to MPS.

Author: The reason I implement linalg_lu_factor is that linalg_lu_factor_ex has to return an info tensor, which is not applicable to MPS (the info tensor is computed by LAPACK on the CPU, for example).

I guess I cannot directly return an undefined Tensor() as the info, since some post-check logic is applied to the info tensor.
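A small CPU-side sketch of the API difference discussed above (this PR wires up only the non-_ex variant for MPS):

import torch

A = torch.randn(4, 4)

# linalg.lu_factor returns only (LU, pivots) and raises if the factorization fails ...
LU, pivots = torch.linalg.lu_factor(A)

# ... while linalg.lu_factor_ex additionally returns the LAPACK-style `info`
# tensor (0 on success, > 0 if a zero pivot was encountered) and, by default,
# does not check it for errors.
LU_ex, pivots_ex, info = torch.linalg.lu_factor_ex(A, check_errors=False)
print(info)  # tensor(0, dtype=torch.int32) on success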


- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
python_module: linalg
2 changes: 1 addition & 1 deletion test/test_mps.py
@@ -94,6 +94,7 @@ def mps_ops_grad_modifier(ops):
'cdist': [torch.float32],
'masked.scatter': [torch.float16, torch.float32],
'index_fill': [torch.float16, torch.float32], # missing `aten::_unique`.
'linalg.lu_factor': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
'aminmax': [torch.float32, torch.float16],
'polar': [torch.float32],

@@ -731,7 +732,6 @@ def mps_ops_modifier(ops):
'linalg.lstsq': None,
'linalg.lstsqgrad_oriented': None,
'linalg.lu': None,
'linalg.lu_factor': None,
'linalg.lu_factor_ex': None,
'linalg.lu_solve': None,
'linalg.matrix_norm': [torch.float32],
5 changes: 5 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -1038,6 +1038,11 @@
LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot)
output_differentiability: [True, False, False]

- name: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
A: lu_factor_ex_backward(grad, LU, pivots, pivot)
LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot)
output_differentiability: [True, False]
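This entry reuses the existing lu_factor_ex formulas for the direct linalg_lu_factor dispatch added above; only LU is differentiable, matching output_differentiability: [True, False]. A small sketch, run on CPU here (on CPU the composite path through lu_factor_ex applies an equivalent formula):

import torch

A = torch.randn(4, 4, requires_grad=True)
LU, pivots = torch.linalg.lu_factor(A)

LU.sum().backward()          # gradients flow through LU ...
print(A.grad.shape)          # torch.Size([4, 4])
print(pivots.requires_grad)  # False -- the integer pivots carry no gradient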

- name: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U)
A: linalg_lu_backward(grad_L, grad_U, P, L, U, pivot)
L: std::get<0>(linalg_lu_jvp(A_t, P, L, U, pivot))