Batched matmul gives incorrect result on MPS devices #111634
Comments
Can reproduce it on my end regardless of PyTorch version; looks like a MetalPerformanceShaders.framework bug to me.
Can still reproduce it on Sonoma.
I tried to reproduce this issue, and it seems that the bug only manifests when the input sizes exceed certain thresholds. For example,
The bug does not show up for Env:
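One way to locate a size threshold like the one described above is a binary search over a single dimension. The sketch below assumes the failure is monotone in that size: every size at or above some threshold fails and every size below it passes. `fails` is a hypothetical callable you would supply, for example one that runs the batched matmul on MPS and compares it with a CPU result.

```python
from typing import Callable


def smallest_failing_size(fails: Callable[[int], bool], lo: int, hi: int) -> int:
    """Return the smallest n in [lo, hi] for which fails(n) is True.

    Assumes monotonicity: fails(n) is False below some threshold and True
    at or above it. fails(hi) must be True, otherwise there is nothing to find.
    """
    if not fails(hi):
        raise ValueError("fails(hi) must be True")
    while lo < hi:
        mid = (lo + hi) // 2
        if fails(mid):
            hi = mid      # threshold is at or below mid
        else:
            lo = mid + 1  # threshold is strictly above mid
    return lo


if __name__ == "__main__":
    # Toy predicate standing in for a real MPS-vs-CPU comparison.
    print(smallest_failing_size(lambda n: n >= 37, 1, 100))
```

Because the reported mismatch counts changed between repeated runs, a real `fails` predicate may not be strictly monotone; running each probe a few times before trusting the result is a reasonable precaution.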
This issue is fixed by #116769.
🐛 Describe the bug
When the dimensions are large enough, batched matmul gives the wrong answer on MPS devices.
Minimal example:
This should give a tensor of 0s, but it instead gives a tensor in which 50,505,735 of the entries are 1. If the operation is performed a second time with the same tensors, the number of 1s changes to 182,632,455.
The dimensions in this example are minimal, i.e. the code runs correctly if any one of them is made smaller.
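The original minimal example is not preserved here. As a hedged sketch of how such a check can be written, the function below runs `torch.bmm` on a target device and counts the entries that differ from a CPU reference. The shapes are placeholders, not the dimensions from the report, and the 0/1 inputs stored as `float32` are an assumption chosen so that every dot product is a small integer that `float32` represents exactly (for `k < 2**24`), which makes an exact comparison meaningful.

```python
import torch


def count_bmm_mismatches(batch: int, n: int, k: int, m: int,
                         device: str = "mps", seed: int = 0) -> int:
    """Count entries where torch.bmm on `device` differs from the CPU result.

    Inputs hold only 0s and 1s, so each output entry is an integer <= k that
    float32 stores exactly; a correct kernel must match the CPU reference exactly.
    """
    g = torch.Generator().manual_seed(seed)
    a = torch.randint(0, 2, (batch, n, k), generator=g).to(torch.float32)
    b = torch.randint(0, 2, (batch, k, m), generator=g).to(torch.float32)
    ref = torch.bmm(a, b)                               # CPU reference result
    out = torch.bmm(a.to(device), b.to(device)).cpu()   # result on the target device
    return int((out != ref).sum().item())


if __name__ == "__main__":
    # Fall back to CPU so the script still runs on machines without MPS.
    dev = "mps" if torch.backends.mps.is_available() else "cpu"
    # Hypothetical placeholder shape (batch, n, k, m); scale up to probe large sizes.
    shape = (4, 256, 256, 256)
    # Run twice: the report observed different mismatch counts on repeated runs.
    for run in range(2):
        print(f"run {run}: {count_bmm_mismatches(*shape, device=dev)} mismatches on {dev}")
```

On a correct backend every run should report 0 mismatches; a nonzero or run-to-run varying count on `mps` matches the symptom described in this issue.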
Versions
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.5.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.27.7
Libc version: N/A
Python version: 3.10.8 (main, Nov 24 2022, 08:08:27) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-13.5.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Pro
Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.1
[pip3] torch==2.1.0
[conda] numpy 1.26.1 pypi_0 pypi
[conda] torch 2.1.0 pypi_0 pypi
cc @ezyang @gchanan @zou3519 @kadeng @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev