MPS torch.where() is giving objectively incorrect results, leading to critical calculation errors #122916
Labels
module: correctness (silent)
issue that returns an incorrect result silently
module: mps
Related to Apple Metal Performance Shaders framework
module: 64-bit
Problems related to incorrectly using 32-bit integers when 64-bit is needed (e.g., 8G tensors)
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
I think I have an example of how MPS can produce completely different results from the CPU. Hopefully the simplicity of this example will make it clear and helpful. This may be related to a previously reported issue (#84936).
Notice how the first two examples, np.where and torch.where on the CPU, give the same results, but the same call on the tensor converted to MPS gives different results.
Unless I've made an obvious mistake, this is a clear example of how MPS can completely corrupt calculations: here the returned indices change, so every downstream calculation that uses them becomes meaningless.
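The original comparison is not reproduced here, so the following is only a minimal sketch of the same kind of check. It assumes a 1-D float32 array and a simple `a > threshold` condition; both are illustrative choices, not taken from the report. NumPy's `np.where` serves as the reference, then single-argument `torch.where` runs on the CPU and, if available, on MPS, and the script reports whether each backend returns the same indices. The helper names `where_indices_numpy`, `where_indices_torch` and `compare_backends` are hypothetical. float32 is used because MPS does not support float64 tensors.

```python
import numpy as np

try:
    import torch
except ImportError:  # without PyTorch, only the NumPy reference is computed
    torch = None


def where_indices_numpy(a: np.ndarray, threshold: float) -> np.ndarray:
    """Indices of elements greater than `threshold`, computed by NumPy (reference)."""
    return np.where(a > threshold)[0]


def where_indices_torch(a: np.ndarray, threshold: float, device: str) -> np.ndarray:
    """Same query via single-argument torch.where on `device`, copied back to NumPy."""
    t = torch.from_numpy(a).to(device)
    (idx,) = torch.where(t > threshold)  # one index tensor for a 1-D input
    return idx.cpu().numpy()


def compare_backends(a: np.ndarray, threshold: float) -> dict:
    """Map each available backend to True if its indices match NumPy's."""
    results = {"numpy": where_indices_numpy(a, threshold)}
    if torch is not None:
        results["torch-cpu"] = where_indices_torch(a, threshold, "cpu")
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            results["torch-mps"] = where_indices_torch(a, threshold, "mps")
    reference = results["numpy"]
    return {name: bool(np.array_equal(idx, reference)) for name, idx in results.items()}


if __name__ == "__main__":
    # float32 because MPS does not support float64 tensors
    a = np.random.default_rng(0).random(1_000_000, dtype=np.float32)
    for backend, matches in compare_backends(a, 0.5).items():
        print(f"{backend}: {'same as NumPy' if matches else 'MISMATCH'}")
```

On an affected machine, a `MISMATCH` line for `torch-mps` would correspond to the behaviour described above; where PyTorch or MPS is unavailable, only the available backends are checked. Given the `module: 64-bit` label, it may also be worth rerunning with much larger arrays.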
Versions
torch version v0.2.1 and v0.2.0
cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr