torch.multinomial on MPS crashes with 'Error: total bytes of NDArray > 2**32' #86279

Closed
malfet opened this issue Oct 5, 2022 · 2 comments
Labels
module: mps (Related to Apple Metal Performance Shaders framework)
module: regression (It used to work, and now it doesn't)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@malfet
Contributor

malfet commented Oct 5, 2022

🐛 Describe the bug

After #80760 added an MPS version of the multinomial op, sampling with replacement fails with a non-recoverable error for inputs with 32K (32,768) or more elements:

 % python -c "import torch;print(torch.multinomial(torch.ones(1, 32768, device='mps'), 2, replacement=True))"
/AppleInternal/Library/BuildRoots/4883e71d-37bd-11ed-b0ef-b25c5e9b9057/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:724: failed assertion `[MPSNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32'
zsh: abort      python -c 

Versions

PyTorch nightly

cc @kulinseth @albanD @DenisVieriu97 @razarmehr @abhudev

@malfet malfet added the module: regression and module: mps labels Oct 5, 2022
@bdhirsh bdhirsh added the triaged label Oct 6, 2022
@malfet
Contributor Author

malfet commented Jan 9, 2024

Please note that this was fixed in macOS 14.

@kulinseth
Collaborator

This has been fixed in macOS 14 and later:

(py39) kulinseth@Mac-308348 pytorch % python -c "import torch;print(torch.multinomial(torch.ones(1, 32768, device='mps'), 2, replacement=True))"
tensor([[6389, 8318]], device='mps:0')
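For anyone stuck on a macOS release older than 14, one possible workaround is to run the sampling on the CPU and move the result back to the MPS device. The sketch below assumes, based only on the two comments above, that macOS 14 is the cutoff. The helper names (`macos_major_version`, `mps_multinomial_needs_cpu_fallback`, `safe_multinomial`) are hypothetical and not part of PyTorch; `platform.mac_ver()` is from the Python standard library and returns an empty release string on non-macOS systems.

```python
import platform
from typing import Optional


def macos_major_version(release: Optional[str] = None) -> Optional[int]:
    """Return the macOS major version (e.g. 14), or None when not on macOS.

    `release` defaults to platform.mac_ver()[0], which is '' on other OSes.
    """
    if release is None:
        release = platform.mac_ver()[0]
    if not release:
        return None
    return int(release.split(".")[0])


def mps_multinomial_needs_cpu_fallback(release: Optional[str] = None) -> bool:
    """Hypothetical guard: True on macOS older than 14, where this thread
    reports the MPSNDArray assertion; False on macOS 14+ or other OSes."""
    major = macos_major_version(release)
    return major is not None and major < 14


def safe_multinomial(probs, num_samples, replacement=False):
    """Hypothetical wrapper around torch.multinomial.

    On pre-14 macOS, MPS inputs are sampled on the CPU and the indices are
    moved back to the original device; otherwise torch.multinomial runs as-is.
    """
    import torch  # imported lazily so the version helpers work without torch

    if probs.device.type == "mps" and mps_multinomial_needs_cpu_fallback():
        out = torch.multinomial(probs.cpu(), num_samples, replacement=replacement)
        return out.to(probs.device)
    return torch.multinomial(probs, num_samples, replacement=replacement)
```

Usage would mirror the reproduction, e.g. `safe_multinomial(torch.ones(1, 32768, device='mps'), 2, replacement=True)`. Falling back on every call on older macOS is deliberately conservative: the report only demonstrates the crash with `replacement=True` and 32,768 elements, and the thread does not pin down the exact threshold.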


3 participants