MPS torch.where() is giving objectively incorrect results, leading to critical calculation errors #122916

Open
aradley opened this issue Mar 28, 2024 · 2 comments
Labels
module: correctness (silent) issue that returns an incorrect result silently
module: mps Related to Apple Metal Performance Shaders framework
module: 64-bit Problems related to incorrectly using 32-bit integers when 64-bit is needed (e.g., 8G tensors)
triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

aradley commented Mar 28, 2024

🐛 Describe the bug

I think I have an example of how MPS can get completely different results from CPU. Hopefully the simplicity of this example will be clear and helpful. This may be related to a previous issue noted on this forum (#84936).

import numpy as np
import torch
mps_device = torch.device("mps")

## Create a numpy matrix with many zeros
np.random.seed(0)
Numpy_Test = np.random.random(200000000)
indices = np.random.choice(np.arange(Numpy_Test.size), replace=False,size=int(Numpy_Test.size * 0.6))
Numpy_Test[indices] = 0
Numpy_Matrix = Numpy_Test.reshape((20000,10000))

## Get the indices of non-zero values in the matrix, and convert these indices into a numpy array
indices = np.where(Numpy_Matrix != 0)
indices = np.asarray(indices)

## Use numpy, torch, or a torch.mps object to find where indices[1] == 8000
# Using np.where
np.where(indices[1] == 8000)[0]
array([   19165,    27061,    39165, ..., 79979029, 79987021, 79995171])

# Using torch.where
torch.where(torch.from_numpy(indices)[1] == 8000)[0]
tensor([   19165,    27061,    39165,  ..., 79979029, 79987021, 79995171])

# Using torch.where with an MPS tensor
torch.where(torch.from_numpy(indices)[1].to(mps_device) == 8000)[0]
tensor([   19165,    27061,    39165,  ..., 79979032, 79987024, 79995168], device='mps:0')

Notice how the first two examples (np.where and torch.where on the CPU) give the same results, but the tensor moved to MPS gives different ones: the last three indices shown are each off by a few positions.

If I've not made an obvious mistake, this is a clear example of how MPS completely ruins calculations, because in this case, the indexes change, and all downstream calculations become meaningless.

Versions

torch version v0.2.1 and v0.2.0

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr

@malfet malfet added module: mps Related to Apple Metal Performance Shaders framework module: correctness (silent) issue that returns an incorrect result silently module: 64-bit Problems related to incorrectly using 32-bit integers when 64-bit is needed (e.g., 8G tensors) triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Mar 28, 2024

alpoge commented Mar 30, 2024

hey solved this :D! i'll put in a pull request w the fix v soon (sry im a bit slow, am a pure mathematician and thus never actually pr'ed before...)

basically the issue is that in the mpsGraph scatter operation the tensor that is written from, which in the case of this torch.where call (which turns out to be a torch.nonzero call) is a list of coordinates which are int32, is secretly being cast to a float32 behind the scenes. you'll notice that the outputs always have the top 24 bits correct! (and indeed casting an int to a float starts rounding it past 2^(24)!)
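The rounding described above can be reproduced on the CPU without any MPS hardware. The following is a minimal sketch that assumes the lossy step behaves like a plain int32 -> float32 -> int32 round trip; it uses NumPy only and does not call into MPSGraph. The sample values are the indices printed in the original report.

```python
import numpy as np

# One small index plus the last three indices from the CPU output in the report.
cpu_indices = np.array([19165, 79979029, 79987021, 79995171], dtype=np.int32)

# Assumed model of the bug: the indices pass through float32 on the way out.
round_tripped = cpu_indices.astype(np.float32).astype(np.int32)

print(round_tripped)
# [   19165 79979032 79987024 79995168]  (matches the MPS output in the report)

# float32 has a 24-bit significand, so between 2**26 and 2**27
# neighbouring representable values are 8 apart.
print(np.spacing(np.float32(79979029)))
# 8.0
```

Indices below 2^24 survive the round trip unchanged, which is consistent with the first few values in the MPS output being correct.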

so all that is required is to split the coordinates tensor into two in the mpsGraph calls --- one modulo 2^(23), say, and one (integer-)divided by 2^(23), scatter those, and then add them back up
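The split-and-recombine idea can be sketched in NumPy as follows. This is only an illustration under assumptions: `lossy_float32_pass` is a hypothetical stand-in for the suspected float32 cast, `SPLIT` and `split_pass` are made-up names, and none of this is the actual MPSGraph code or the eventual fix. It assumes non-negative int32 indices.

```python
import numpy as np

SPLIT = 2**23  # hypothetical split point; both halves stay below 2**24


def lossy_float32_pass(values):
    """Stand-in for the suspected cast: int32 -> float32 -> int32 (an assumption)."""
    return values.astype(np.float32).astype(np.int32)


def split_pass(values):
    """Send the low and high parts through the lossy step separately, then recombine."""
    low = values % SPLIT    # always < 2**23, exactly representable in float32
    high = values // SPLIT  # < 256 for any non-negative int32
    return lossy_float32_pass(high) * SPLIT + lossy_float32_pass(low)


indices = np.array([19165, 79979029, 79987021, 79995171, 2**30 + 1], dtype=np.int32)

print(np.array_equal(lossy_float32_pass(indices), indices))  # False: large values are rounded
print(np.array_equal(split_pass(indices), indices))          # True: both halves are exact
```

Because each half is smaller than 2^24, float32 represents it exactly, so the recombined value equals the original index.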

i should have the fix for this requested very soon!!! all credit to @Jckwind for spreading the word about this (and for getting me up to speed)!

hopefully a number of these other MPS arithmetic issues are related, we shall see...

@kulinseth
Collaborator

Thanks @alpoge for the fix. We are looking into whether there is a more efficient way to do this that lets us use the full int32 index range.
