Enable batch matmul for result sizes > 2**32 when the tensor can be split along batch axis #133430
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133430
Need some guidance on how we want to test this: in order to reach the original failure, a 96GB configuration is required. With a 64GB config we hit another error when trying to allocate the initial buffers. In order to validate that the MPS implementation matches the CPU implementation we'd need a 128GB config, because the output tensors in the failing test case are around ~50GB each, so my 96GB config kills the process when asserting that the CPU and MPS outputs are the same. This can however be worked around by reducing the tensors a bit: batch_size = 11 is just enough to take the new code path, and that seems to run all the way through, including the result validation, on a 96GB config.

@kulinseth @malfet Do we want to guarantee that our Mac configurations, now and going forward, are at least 96GB ones? Or should we try to work around this in some other way? The first option I can think of would be to add a skip / xfail decorator if we detect that the test is running on a config with less than 96GB, as sketched below. I'm thinking it might have to be a skip, since if the Python process hogs too much memory the OS just kills it off instead of the test failing on an assert about not being able to allocate more memory.
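(For reference, a minimal sketch of what such a skip decorator could look like, assuming total system memory can be queried via psutil; the 96GB threshold constant and the helper name are illustrative and not part of the PR or the existing test suite.)

```python
import unittest
import psutil  # assumed to be available in the test environment

MIN_TOTAL_MEMORY = 96 * 1024**3  # ~96GB: enough for the large bmm plus the CPU reference

def requires_96gb(fn):
    """Skip (rather than fail) on low-memory machines, since the OS may kill the
    Python process outright before any allocation assert can fire."""
    total = psutil.virtual_memory().total
    return unittest.skipIf(
        total < MIN_TOTAL_MEMORY,
        f"requires >= 96GB of memory, found {total / 1024**3:.0f}GB",
    )(fn)
```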
@jhavukainen I'm a bit confused about the repro case, because llama-3.1 should never run in full 32-bit mode: it will be extremely slow, and at least for the 8B model it would have zero advantage over the bfloat16 variant, which has a 2x smaller memory footprint.
Well, yes, that is a bit confusing. But it comes directly from the developer's provided repro script, where their llama v3.1 pipeline gives
and leads to the failing op also being of type f32. But now that you mention it, I tested the behavior when running with the half data types and there is clearly some other issue we need to tackle there, since MPSMatrixMultiplication is doing some op underneath in those cases, leading to the same error in a slightly different context. And for bfloat16 specifically it raises an assert for an unsupported data type, so I'll need to consult the MPS team on a workaround for that. While I wait for their feedback to see what should be done, I'll add a check for the data type to raise the unsupported error for the half types for now. Since fp32 is the specific case the developer asked about, this PR would still unblock them.
So the FP16 issue happens on the PyTorch side, caused by us, since we are calling
indexing limit. This change enables the operation to run if we can split it batch-wise into multiple encodes, avoiding the indexing limit.
…n expanding it to FP16 and BF16.
Force-pushed from 289879f to 60b3b22
Updated the PR to also work on float16 and bfloat16. With this, LLaMA 3.1 works on all the supported float data types. Now I'll just need to figure out a preprocessor directive to hide it from pre-MacOS15 machines, since the compiler will refuse the
@malfet @kulinseth I ended up adding the tests for fp16 and bf16 since those have a smaller memory footprint. The tests will still take quite a bit more time than many of our other unit tests, around 2 minutes per data type on a powerful machine. This is mostly down to how long the CPU takes to compute the reference value and do the comparison. Let me know if this is a concern and if there are other ideas on what would be a better test. Other than that, this PR should be ready for review.
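(A rough sketch of what such a test could look like, assuming batch_size = 11 still crosses the 2**32 output-size threshold for the half dtypes and therefore exercises the split path; the class name and tolerances are illustrative and not taken from the actual test.)

```python
import unittest
import torch

class TestLargeBmmMPS(unittest.TestCase):
    def _run_large_bmm(self, dtype, batch_size=11):
        # batch_size=11 with 20064x20064 outputs crosses the 2**32 limit,
        # so the MPS backend takes the batch-split path.
        a = torch.randn(batch_size, 20064, 128, dtype=dtype, device="mps")
        b = torch.randn(batch_size, 128, 20064, dtype=dtype, device="mps")
        res_mps = torch.bmm(a, b)
        # CPU reference in float32; this is the slow part of the test.
        res_cpu = torch.bmm(a.cpu().float(), b.cpu().float())
        torch.testing.assert_close(res_mps.cpu().float(), res_cpu, rtol=1e-2, atol=1e-2)

    def test_large_bmm_fp16(self):
        self._run_large_bmm(torch.float16)

    def test_large_bmm_bf16(self):
        self._run_large_bmm(torch.bfloat16)
```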
static Tensor& tiled_bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) {
#if defined(__MAC_15_0)
This would never get compiled, would it? (As our builders still target MacOS-13 as the minimum supported OS.)
You'll need to use `if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS))`
I see. That's a bummer. Then we'll need to do this in a roundabout way by adding the API definitions that come with MacOS15 to the code somewhere. Otherwise the runtime check won't help, as the compiler kills the build beforehand.
Done.
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
MacOS15 now passing.
@pytorchbot merge -f "MPS + Lint are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: PR #133430 has not been reviewed yet.
@pytorchbot merge -f "MPS + Lint are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…along batch axis (pytorch#133430)

Fixes pytorch#131865. Addresses the issue seen when running the llama v3.1 8B parameter model on the MPS backend, where the batch matmul output size can go over the 32-bit indexing limit of MPS tensors, causing an assert.

Test case to reproduce the issue with the dimensions encountered in llama v3.1 and verify this fix works around it:

```
import torch
device = 'mps'
a = torch.randn([32, 20064, 128], dtype=torch.float32, device=device)
b = torch.randn([32, 128, 20064], dtype=torch.float32, device=device)
res = torch.bmm(a, b)
```

Notably, the current change only works as long as the individual output matrix in the bmm does not exceed 2**32 elements; this lets us split up the computation along the batch axis to avoid going over the limit. Added a TORCH_CHECK to raise an error if the individual matrix dimensions are too large for this op to handle, until a more general workaround tiling the matmuls is available.

Pull Request resolved: pytorch#133430
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
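(A rough Python illustration of the batch-axis splitting described above. The real change lives in the MPS backend's native code; this sketch only shows why the split is numerically equivalent, and the chunk size is a hypothetical parameter. Each batch element of a bmm is an independent matmul, so computing the batch in chunks gives the same result while keeping each encoded slice under the limit.)

```python
import torch

def bmm_batch_split(a: torch.Tensor, b: torch.Tensor, chunk: int = 4) -> torch.Tensor:
    """Compute torch.bmm in chunks along the batch axis so each chunk's output
    stays under the 2**32-element indexing limit (chunk size is illustrative)."""
    out = torch.empty(a.shape[0], a.shape[1], b.shape[2], dtype=a.dtype, device=a.device)
    for start in range(0, a.shape[0], chunk):
        end = min(start + chunk, a.shape[0])
        # Each slice is an independent bmm over a subset of the batch dimension.
        torch.bmm(a[start:end], b[start:end], out=out[start:end])
    return out
```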
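(A quick sanity check of the limits for the llama v3.1 dimensions above, not part of the PR: each individual output matrix fits under 2**32 elements, but the full batched output does not, which is exactly the case the batch-axis split covers.)

```python
single = 20064 * 20064   # 402,564,096 elements per output matrix
full = 32 * single       # 12,882,051,072 elements for the whole batched output
limit = 2 ** 32          # 4,294,967,296
assert single < limit < full  # per-batch encodes stay under the limit; the full bmm does not
```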