
Conversation

jhavukainen
Collaborator

Fixes #131865. Addresses the issue seen when running the llama v3.1 8B parameter model on the MPS backend, where the batch matmul output size can exceed the 32-bit indexing limit of MPS tensors, causing an assert.

Test case to reproduce the issue with the dimensions encountered in llama v3.1 and verify this fix works around it:

```
import torch
device = 'mps'
a = torch.randn([32, 20064, 128], dtype=torch.float32, device=device)
b = torch.randn([32, 128, 20064], dtype=torch.float32, device=device)
res = torch.bmm(a, b)
```

Notably, the current change only works as long as each individual output matrix in the bmm stays under 2**32 elements; this lets us split the computation along the batch axis to avoid going over the limit.

Added a TORCH_CHECK to raise an error if the individual matrix dimensions are too large for this op to handle, until a more general workaround that tiles the matmuls is available.
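
For illustration, here is a rough Python-level sketch of the batch-splitting idea; the actual change lives in the MPS backend, so the helper below is purely hypothetical and only mirrors the logic:

```
# Python-level illustration of the batch-splitting idea; the actual fix is in
# the MPS backend, so this helper is hypothetical and only mirrors the logic.
import torch

ELEM_LIMIT = 2**32  # MPS 32-bit indexing limit on tensor elements

def bmm_batch_split(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    batch, m, _ = a.shape
    n = b.shape[-1]
    # For the llama v3.1 shapes: 20064 * 20064 ~= 4.0e8 elements per output
    # matrix (fits), but 32 * 20064 * 20064 ~= 1.3e10 elements in total (does not).
    if m * n > ELEM_LIMIT:
        raise RuntimeError("single output matrix exceeds the 32-bit indexing limit; "
                           "tiling the matmul itself is not supported yet")
    # Largest number of batches whose combined output stays under the limit.
    batches_per_chunk = max(1, ELEM_LIMIT // (m * n))
    out = torch.empty((batch, m, n), dtype=a.dtype, device=a.device)
    for start in range(0, batch, batches_per_chunk):
        end = min(start + batches_per_chunk, batch)
        out[start:end] = torch.bmm(a[start:end], b[start:end])
    return out
```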


pytorch-bot bot commented Aug 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133430

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/mps (Run MPS tests, subset of trunk) and release notes: mps (Release notes category) labels Aug 14, 2024
@jhavukainen
Collaborator Author

Need some guidance on how we want to test this:

In order to reach the original failure, a 96GB configuration is required. With a 64GB config we hit another error when trying to allocate the initial buffers: RuntimeError: Invalid buffer size: 47.99 GB. This is due to the maximum memory fraction PyTorch is allowed to use.

In order to validate that the MPS implementation matches the CPU implementation, we'll need a 128GB config. This is because the output tensors in the failing test case are around ~50GB each, which leads to the OS killing the process on my 96GB config when asserting that the CPU and MPS outputs match. This can, however, be worked around by shrinking the tensors a bit: batch_size = 11 is just enough to take the new code path, and that seems to run all the way through, including result validation, on a 96GB config.

@kulinseth @malfet Do we want to guarantee that our Mac configurations, now and going forward, are at least 96GB ones? Or should we work around this in some other way? The first option I can think of would be to add a skip/xfail decorator if we detect that the test is running on a config with less than 96GB. I'm thinking it might have to be a skip, since if the Python process hogs too much memory the OS appears to just kill it rather than the test failing on an allocation assert.
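
For reference, a minimal sketch of the kind of memory-gated skip described above (hypothetical test code, not what this PR adds; it assumes psutil is available and treats 96GB as an illustrative threshold):

```
# Hypothetical memory-gated skip for the original fp32 repro; the psutil-based
# check and the 96GB threshold are illustrative assumptions, not the decorator
# used in the PyTorch test suite.
import unittest

import psutil
import torch

MIN_UNIFIED_MEM_GB = 96  # what the fp32 repro needs just to allocate its buffers

def total_memory_gb() -> float:
    # Apple Silicon uses unified memory, so host RAM is what MPS allocates from.
    return psutil.virtual_memory().total / (1024 ** 3)

class TestLargeBmmMPS(unittest.TestCase):
    @unittest.skipIf(total_memory_gb() < MIN_UNIFIED_MEM_GB,
                     f"needs at least {MIN_UNIFIED_MEM_GB}GB of unified memory")
    def test_bmm_fp32_over_32bit_indexing(self):
        a = torch.randn([32, 20064, 128], dtype=torch.float32, device="mps")
        b = torch.randn([32, 128, 20064], dtype=torch.float32, device="mps")
        out = torch.bmm(a, b)
        # A full CPU comparison would need ~128GB, so only sanity-check the shape here.
        self.assertEqual(out.shape, torch.Size([32, 20064, 20064]))
```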

@malfet
Contributor

malfet commented Aug 15, 2024

@jhavukainen I'm a bit confused about the repro case, because llama-3.1 should never run in full 32-bit mode, as it will be extremely slow and, at least for the 8B model, would have zero advantage over the bfloat16 variant, which has a 2x smaller memory footprint.

@jhavukainen
Collaborator Author

@jhavukainen I'm a bit confused about the repro case, because llama-3.1 should never run in full 32-bit mode, as it will be extremely slow and, at least for the 8B model, would have zero advantage over the bfloat16 variant, which has a 2x smaller memory footprint.

Well, yes, that is a bit confusing. But it comes directly from the developer's repro script, where their llama v3.1 pipeline gives

```
print(pipeline.torch_dtype)
# torch.float32
```

and the op that fails is also of type f32: aten::bmm_out_mps_impl:f32[32,20064,128]:f32[32,128,20064].

But now that you mention it, I tested the behavior with the half data types, and there is clearly some other issue we need to tackle there, since MPSMatrixMultiplication is doing some op underneath in those cases, leading to the same error in a slightly different context. And for bfloat16 specifically it raises an assert for an unsupported data type, so I'll need to consult the MPS team on a workaround for that.

While I wait for their feedback on what should be done, I'll add a data type check that raises the unsupported-type error for the half types for now. Since fp32 is the specific case the developer asked about, this PR would still unblock them.

@jhavukainen
Collaborator Author

So the FP16 issue happens on the PyTorch side, since we are calling mps::copy_cast_mps with a tensor that is too large for MPS to handle. That is something we should be able to address.

@bdhirsh bdhirsh added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 16, 2024
…indexing limit. This change enables the operation to run if we can split it batchwise into multiple encodes, avoiding the indexing limit.
@jhavukainen
Collaborator Author

Updated the PR to also work on float16 and bfloat16. With this, LLaMA 3.1 works on all the supported float data types. Now I just need to figure out a preprocessor directive to hide it from pre-MacOS 15 machines, since the compiler will refuse MPSNDArrayDescriptor.preferPackedRows on those machines.

@jhavukainen
Collaborator Author

@malfet @kulinseth I ended up adding the tests for fp16 and bf16, since those have a smaller memory footprint. The tests will still take quite a bit more time than many of our other unit tests, around 2 minutes per data type on a powerful machine. This is mostly down to how long the CPU takes to compute the reference value and do the comparison. Let me know if this is a concern and if there are other ideas on what would be a better test.
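
For context, the reduced-precision checks described above might look roughly like this (a sketch only, not the actual tests added in this PR; the batch size, tolerances, and the float32 upcast of the CPU reference are assumptions):

```
# Rough sketch of a dtype-parameterized large-bmm check; not the actual tests
# from this PR. batch=11 is just enough to cross the 2**32-element limit, and
# the CPU reference is upcast to float32 before comparison.
import torch

def check_large_bmm(dtype: torch.dtype, batch: int = 11) -> None:
    a = torch.randn([batch, 20064, 128], dtype=dtype, device="mps")
    b = torch.randn([batch, 128, 20064], dtype=dtype, device="mps")
    out = torch.bmm(a, b)
    # Computing this CPU reference is what dominates the per-dtype runtime.
    ref = torch.bmm(a.cpu().float(), b.cpu().float())
    torch.testing.assert_close(out.cpu().float(), ref, rtol=1e-2, atol=1e-2)

for dt in (torch.float16, torch.bfloat16):
    check_large_bmm(dt)
```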

Other than that this PR should be ready to review.

@malfet
Contributor

malfet commented Aug 29, 2024

Other than that this PR should be ready to review.
@jhavukainen hasn't it just failed on MacOS 15? See https://github.com/pytorch/pytorch/actions/runs/10620567384/job/29441019041

```
}

static Tensor& tiled_bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) {
#if defined(__MAC_15_0)
```
Contributor

This would never get compiled, would it? (As our builders still target MacOS-13 as the minimum supported OS.)
You'll need to use `if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS))` instead.

Collaborator Author

I see. That's a bummer. Then we'll need to do this in a roundabout way by adding the API definitions introduced in MacOS 15 to the code somewhere. Otherwise the runtime check won't help, as the compiler kills the build beforehand.

Collaborator Author

Done.

@jhavukainen
Collaborator Author

MacOS 15 is now passing.

@malfet
Contributor

malfet commented Aug 30, 2024

@pytorchbot merge -f "MPS + Lint are green"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: PR #133430 has not been reviewed yet

Details for Dev Infra team: Raised by workflow job

Failing merge rule: Core Maintainers

@malfet
Contributor

malfet commented Aug 30, 2024

@pytorchbot merge -f "MPS + Lint are green"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
…along batch axis (pytorch#133430)

Pull Request resolved: pytorch#133430
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
@github-actions github-actions bot deleted the dev/joona/tile_bmm branch October 3, 2024 02:06
