Batched matmul gives incorrect result on MPS devices #111634

Closed
jacobhilton opened this issue Oct 20, 2023 · 4 comments
Labels
high priority, module: correctness (silent), module: mps, triaged

Comments

@jacobhilton

jacobhilton commented Oct 20, 2023

🐛 Describe the bug

When the dimensions are large enough, batched matmul gives the wrong answer on MPS devices.

Minimal example:

import torch

zeros = torch.zeros(911, 9, 1, device=torch.device("mps"))
ones = torch.ones(1, 32769, device=torch.device("mps"))
zeros @ ones

This should give a tensor of 0s, but it instead gives a tensor in which 50,505,735 of the entries are 1. If the operation is performed a second time with the same tensors, the number of 1s changes to 182,632,455.
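To put those counts in perspective: under PyTorch's matmul broadcasting, the 3-D left operand acts as a batch of 911 matrices of shape 9 × 1, each multiplied by the same 1 × 32769 matrix, so the result has shape (911, 9, 32769). The sketch below re-derives that shape and the total element count in plain Python. `broadcast_batch` and `matmul_shape` are hypothetical helpers written for illustration, not PyTorch functions, and they only cover operands with at least two dimensions.

```python
from itertools import zip_longest
from math import prod


def broadcast_batch(a, b):
    """Broadcast two batch-shape tuples with NumPy-style rules (aligned from the right)."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast batch shapes {a} and {b}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def matmul_shape(lhs, rhs):
    """Result shape of lhs @ rhs; this sketch only covers operands with ndim >= 2."""
    if len(lhs) < 2 or len(rhs) < 2:
        raise ValueError("sketch only handles operands with at least 2 dims")
    (n, k_lhs), (k_rhs, m) = lhs[-2:], rhs[-2:]
    if k_lhs != k_rhs:
        raise ValueError(f"inner dimensions differ: {k_lhs} vs {k_rhs}")
    # Batch dims broadcast; the last two dims follow ordinary matrix multiplication.
    return broadcast_batch(lhs[:-2], rhs[:-2]) + (n, m)


shape = matmul_shape((911, 9, 1), (1, 32769))  # (911, 9, 32769)
total = prod(shape)  # 268673031
for wrong in (50_505_735, 182_632_455):  # counts reported in the issue
    print(f"{wrong:,} of {total:,} entries are 1 ({wrong / total:.1%})")
```

With these numbers, roughly 19% of the 268,673,031 entries were wrong on the first run and about 68% on the second.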

The dimensions in this example are minimal, i.e. the code runs correctly if any one of them is made smaller.
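A claim like "minimal in every dimension" can be checked mechanically. The sketch below is a hypothetical helper (`is_one_minimal` is not part of PyTorch) that takes a shape and a caller-supplied `reproduces` predicate and confirms the bug appears at that shape but not when any single dimension is reduced by one. It assumes the bug is monotone in each dimension, so checking `d - 1` stands in for "any smaller value"; that assumption is not established by the issue.

```python
from typing import Callable, Tuple

Dims = Tuple[int, ...]


def is_one_minimal(dims: Dims, reproduces: Callable[[Dims], bool]) -> bool:
    """Return True if `dims` reproduces the bug but no single dimension reduced by one does.

    Assumes the bug is monotone in each dimension (growing a dimension never
    makes it disappear), so checking d - 1 stands in for "any smaller value".
    """
    dims = tuple(dims)
    if not reproduces(dims):
        return False
    for i, d in enumerate(dims):
        if d <= 1:
            continue  # illustrative choice: never shrink a dimension to 0
        smaller = dims[:i] + (d - 1,) + dims[i + 1:]
        if reproduces(smaller):
            return False
    return True
```

For this issue, `dims` would be (batch, rows, inner, cols) = (911, 9, 1, 32769), and a `reproduces` predicate on an affected Mac would multiply `torch.zeros(batch, rows, inner, device="mps")` by `torch.ones(inner, cols, device="mps")` and report whether any entry is nonzero, ideally over several runs because the failure is nondeterministic.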

Versions

PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.5.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.27.7
Libc version: N/A

Python version: 3.10.8 (main, Nov 24 2022, 08:08:27) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-13.5.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Pro

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.1
[pip3] torch==2.1.0
[conda] numpy 1.26.1 pypi_0 pypi
[conda] torch 2.1.0 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

malfet added the module: correctness (silent), module: mps, and high priority labels Oct 23, 2023
soulitzer added the triaged label and removed the triage review label Oct 23, 2023
@malfet
Contributor

malfet commented Oct 31, 2023

Can reproduce it on my end regardless of PyTorch version; looks like a MetalPerformanceShaders.framework bug to me.

@malfet
Contributor

malfet commented Dec 19, 2023

Can still reproduce it on Sonoma

@Essoz

Essoz commented Feb 18, 2024

I tried to reproduce this issue, and it seems that the bug only manifests when the input sizes exceed certain thresholds.

For example,

import torch
zeros = torch.zeros(size_zero, 1, device=torch.device("mps"))
ones = torch.ones(1, 32769, device=torch.device("mps"))
zeros @ ones

The bug does not show up for size_zero < 1025, and happens nondeterministically (with about 90% probability) when size_zero >= 1025.
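Because the failure appears in only about 90% of runs above the threshold, pinning down a boundary like 1025 by hand is tedious. One way to automate it is a bisection that retries each probe a few times. This is a sketch assuming (1) the bug is monotone in size_zero and (2) runs are independent; `find_threshold` and the `fails_once` predicate are hypothetical names, not PyTorch APIs.

```python
from typing import Callable


def find_threshold(lo: int, hi: int, fails_once: Callable[[int], bool], trials: int = 5) -> int:
    """Smallest n in (lo, hi] at which the bug is observed, assuming monotonicity.

    Invariant: the bug is never seen at `lo` and is seen at `hi`.
    Each size is tried up to `trials` times because a single run
    only shows the failure some of the time.
    """
    def reproduces(n: int) -> bool:
        return any(fails_once(n) for _ in range(trials))

    while hi - lo > 1:
        mid = (lo + hi) // 2
        if reproduces(mid):
            hi = mid  # bug seen: threshold is at or below mid
        else:
            lo = mid  # bug not seen: threshold is above mid
    return hi
```

With a 90% per-run failure rate, five independent trials all miss with probability 0.1^5 = 10^-5, so a size above the threshold is very unlikely to be misclassified. Here `fails_once(n)` would build `torch.zeros(n, 1, device="mps")` and `torch.ones(1, 32769, device="mps")`, multiply them, and return whether any entry is nonzero.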

Env:
MacBook Air M2 16 GiB RAM.
Sonoma 14.2.1
torch 2.0.0

@kulinseth
Collaborator

This issue is fixed with #116769

>>> import torch
>>> zeros = torch.zeros(911, 9, 1, device=torch.device("mps"))
>>> ones = torch.ones(1, 32769, device=torch.device("mps"))
>>> zeros @ ones
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='mps:0')
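The printed tensor above elides most entries with `...`, so on its own it cannot show that every entry is zero. A sketch of a more direct check counts the nonzero entries on each of several runs (the original report saw the count change between runs). `pick_device` and `count_nonzero_per_run` are hypothetical helper names; the script falls back to CPU when MPS is unavailable, in which case it only exercises the CPU path. With the default shapes from the issue, the result alone needs roughly 1 GB of float32 memory.

```python
import torch


def pick_device() -> torch.device:
    # Prefer MPS when present; fall back to CPU so the script still runs elsewhere.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def count_nonzero_per_run(lhs_shape=(911, 9, 1), rhs_shape=(1, 32769), trials=2):
    """Compute zeros @ ones `trials` times and return the nonzero count of each result.

    Every entry of the product should be exactly 0, so any nonzero count
    indicates the silent-correctness bug described in this issue.
    """
    device = pick_device()
    zeros = torch.zeros(lhs_shape, device=device)
    ones = torch.ones(rhs_shape, device=device)
    counts = []
    for _ in range(trials):
        result = zeros @ ones
        # .item() copies the scalar back to the host, which waits for the device.
        counts.append(int(torch.count_nonzero(result).item()))
    return counts
```

On a build that includes the fix, `count_nonzero_per_run()` should return a list of zeros; any nonzero value means the bug is still present for that PyTorch and macOS combination.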


5 participants