Matrix Multiplication Performance and Slicing Issues with PyTorch MPS Backend #122123

Open
Jckwind opened this issue Mar 18, 2024 · 1 comment
Labels
module: mps Related to Apple Metal Performance Shaders framework triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

Jckwind commented Mar 18, 2024

🐛 Describe the bug

The PyTorch MPS backend has a long-standing bug, along with performance problems, in matrix multiplication and tensor slicing. The bug has already been reported in GitHub issues #111634, #116769, and #122045.

The following examples demonstrate the runtime errors encountered:

Example 1:

import torch

A = torch.randn(32769, 4, device='mps')
B = A[1:25] @ torch.randn(4, 10, device='mps')

Error:

[MPSNDArrayDescriptor sliceDimension:withSubrange:] error: subRange.start (8192) is not less than length of dimension[1] (24)

Example 2:
Note: this example fails once the temporary fix implemented in PR #117549 is reverted.

import torch

A = torch.randn(1, 100000, device="mps")
B = torch.randn(10, 1, device="mps")
A = A[:, 16384:32769]
print(torch.mm(B, A))

Error:

error: subRange.start (24576) is not less than length of dimension[0] (16385)
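
Both examples can be turned into a small regression check that compares the device result against a CPU reference. This is only a sketch: the helper names (`example_1`, `example_2`, `matches_cpu`) and the tolerances are my own choices, not anything from the PyTorch test suite, and the script falls back to CPU when MPS is unavailable. On an affected build the MPS call may abort the process instead of returning, so a crash also counts as a failure.

```python
import torch


def example_1(a, w):
    # Row slice followed by a matrix product (Example 1).
    return a[1:25] @ w


def example_2(a, b):
    # Column slice followed by torch.mm (Example 2).
    return torch.mm(b, a[:, 16384:32769])


def matches_cpu(fn, inputs, device) -> bool:
    # Inputs are created on CPU so both devices see identical values.
    expected = fn(*inputs)
    actual = fn(*(t.to(device) for t in inputs))
    # Tolerances are an assumption; adjust them for your dtype.
    return torch.allclose(actual.cpu(), expected, rtol=1e-4, atol=1e-5)


torch.manual_seed(0)
# Fall back to CPU so the sketch still runs on machines without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

results = {
    "example_1": matches_cpu(
        example_1, (torch.randn(32769, 4), torch.randn(4, 10)), device
    ),
    "example_2": matches_cpu(
        example_2, (torch.randn(1, 100000), torch.randn(10, 1)), device
    ),
}
print(device, results)
```

Creating the operands on CPU and moving them with `.to(device)` keeps the comparison meaningful, since random tensors generated separately on each device would not share values.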

Furthermore, matrix multiplication on the MPS backend is significantly slower than on the CPU backend (roughly 4x in the run below), as the following example shows:

import timeit

import torch

a_cpu = torch.rand(250, device='cpu')
b_cpu = torch.rand((250, 250), device='cpu')
a_mps = torch.rand(250, device='mps')
b_mps = torch.rand((250, 250), device='mps')

print('cpu', timeit.timeit(lambda: a_cpu @ b_cpu, number=100_000))
print('mps', timeit.timeit(lambda: a_mps @ b_mps, number=100_000))

Output:

cpu 1.95958621
mps 7.72093108
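
Those totals work out to roughly 20 µs per call on CPU and 77 µs per call on MPS. One caveat when timing MPS: work is submitted to the GPU asynchronously, so a plain `timeit` loop may not measure completed work. Below is a minimal sketch of a timing helper that takes an optional synchronization callback. The helper name `time_op` and its parameters are hypothetical, and the workload at the bottom is a stand-in so the block runs without a GPU.

```python
import timeit
from typing import Callable, Optional


def time_op(
    fn: Callable[[], object],
    sync: Optional[Callable[[], None]] = None,
    number: int = 1000,
    warmup: int = 10,
) -> float:
    """Return mean seconds per call of fn, waiting on sync after each call."""

    def run_once() -> None:
        fn()
        if sync is not None:
            sync()  # block until the device has finished the queued work

    # Warm-up runs keep one-time setup costs out of the measurement.
    for _ in range(warmup):
        run_once()
    total = timeit.timeit(run_once, number=number)
    return total / number


# Stand-in workload (not a tensor op) so the sketch runs anywhere.
per_call = time_op(lambda: sum(range(100)), number=1000)
print(f"{per_call * 1e6:.2f} us per call")
```

For the benchmark above, the call would look like `time_op(lambda: a_mps @ b_mps, sync=torch.mps.synchronize)`, with `sync` left as `None` for the CPU tensors. Synchronizing after every call measures end-to-end latency per operation; to measure throughput instead, synchronize once after the whole loop.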

Previous attempts to address this issue, such as the fixes proposed in PR #117549 and PR #117319, have provided temporary solutions but remain incomplete.

Improving the performance and reliability of the MPS backend is crucial for PyTorch, as it is a beginner-friendly framework and many engineers work on Apple Silicon. Enabling efficient local development on MacBook GPUs is essential for both beginners and open-source research.

I'm committed to working on resolving these issues. I have been dedicating time to debugging and improving the MPS backend code, which has not received significant attention in the past two years. If I find a fix or performance improvement in the meantime, I will open a pull request to contribute to the project.

I came across a similar issue with a similar error message, which I solved in PR #121645, but I have not been able to carry that fix over to this issue.

So far, I've found a few files worth examining as possible culprits:

  1. aten/src/ATen/native/mps/operations/View.mm
  2. aten/src/ATen/native/mps/operations/LinearAlgebra.mm
  3. aten/src/ATen/native/mps/operations/Copy.mm
  4. aten/src/ATen/native/TensorShape.cpp
  5. aten/src/ATen/native/mps/operations/Distributions.mm

Other PRs and issues I came across while researching this that could be related: #110120, #109557, #114838

P.S. A general question, since I'm new to MPS: are we taking advantage of the MPS kernels, like this one for matrix multiplication? If not, would it be trivial to refactor?

Versions

Collecting environment information...
PyTorch version: 2.3.0a0+git82bb063
Is debug build: True
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.3.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.26.4
Libc version: N/A

Python version: 3.9.18 (main, Sep 11 2023, 08:25:10) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.3.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Apple M3 Max

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.10.0
[pip3] torch==2.3.0a0+gitf37ab64
[conda] numpy 1.26.4 pypi_0 pypi
[conda] optree 0.10.0 pypi_0 pypi
[conda] torch 2.3.0a0+gitf37ab64 dev_0

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr

Jckwind commented Mar 18, 2024

cc: @malfet @kulinseth

bdhirsh added the triaged and module: mps labels Mar 19, 2024