MPS device can train or evaluate models producing unacceptable output due to "fast math" optimization #84936

Closed
mallman opened this issue Sep 13, 2022 · 30 comments
Assignees
Labels
module: mps Related to Apple Metal Performance Shaders framework module: numerical-stability Problems related to numerical stability of operations triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@mallman

mallman commented Sep 13, 2022

🐛 Describe the bug

I am getting radically different results running the CURL model on the cpu versus the mps device (on pytorch 1.12.1 and nightly). I stepped through the debugger to find the underlying difference in calculation. In short, it appears that mps is using the "fast math" version of the standard library, leading to unacceptable results for some models. This is not a bug with CURL.

Here's a minimal example that doesn't work the way I expect:

import torch

mps_device = torch.device("mps")

x = 0.1

cpu_tensor = torch.exp(torch.tensor(x))
mps_tensor = torch.exp(torch.tensor(x, device=mps_device))

print(cpu_tensor - cpu_tensor) # prints 0
print(mps_tensor - mps_tensor) # prints 0
print(cpu_tensor - mps_tensor) # prints 1.1921e-07
print(cpu_tensor - mps_tensor.cpu()) # prints 1.1921e-07
print(cpu_tensor.to(mps_device) - mps_tensor) # prints 1.1921e-07

Versions

Here's my pytorch 1.12.1 environment info:

PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.102)
CMake version: version 3.23.2
Libc version: N/A

Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00)  [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-13.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] adabelief-pytorch==0.2.1
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.3
[pip3] torch==1.12.1
[pip3] torch-tb-profiler==0.4.0
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.13.1
[pip3] torchviz==0.0.2
[conda] adabelief-pytorch         0.2.1                    pypi_0    pypi
[conda] numpy                     1.23.3                   pypi_0    pypi
[conda] torch                     1.12.1                   pypi_0    pypi
[conda] torch-tb-profiler         0.4.0                    pypi_0    pypi
[conda] torchmetrics              0.9.3                    pypi_0    pypi
[conda] torchvision               0.13.1                   pypi_0    pypi
[conda] torchviz                  0.0.2                    pypi_0    pypi

And here's my pytorch nightly env:

PyTorch version: 1.13.0.dev20220913
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.102)
CMake version: version 3.23.2
Libc version: N/A

Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00)  [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-13.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] adabelief-pytorch==0.2.1
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.3
[pip3] torch==1.13.0.dev20220913
[pip3] torch-tb-profiler==0.4.0
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.13.1
[pip3] torchviz==0.0.2
[conda] adabelief-pytorch         0.2.1                    pypi_0    pypi
[conda] numpy                     1.23.3                   pypi_0    pypi
[conda] torch                     1.13.0.dev20220913          pypi_0    pypi
[conda] torch-tb-profiler         0.4.0                    pypi_0    pypi
[conda] torchmetrics              0.9.3                    pypi_0    pypi
[conda] torchvision               0.13.1                   pypi_0    pypi
[conda] torchviz                  0.0.2                    pypi_0    pypi
@ngimel
Collaborator

ngimel commented Sep 13, 2022

1e-7 is expected due to slightly different fp precision

@ngimel ngimel added module: mps Related to Apple Metal Performance Shaders framework module: numerical-stability Problems related to numerical stability of operations triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Sep 13, 2022
@mallman
Author

mallman commented Sep 14, 2022

1e-7 is expected due to slightly different fp precision

I thought they were both fp32?

@mallman
Author

mallman commented Sep 14, 2022

To better illustrate the problem that motivated me to post this issue, I can present some output from CURL. This first image is from the CPU:

[image: a2803-060810_075208_GM6A0020, PSNR 27.756, SSIM 0.982]

This is from the mps device:

[image: a2803-060810_075208_GM6A0020, PSNR 23.659, SSIM 0.968]

The difference in output comes down to a cascade of tiny differences in fp values that avalanche into large differences. I think these calculations need to be exactly the same on the cpu and mps device to ensure correctness. Otherwise, I don't see how the mps device can be relied on.
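
To make the avalanche concrete, here's a toy sketch (not CURL; just an iterated map driven by sin, so any tiny per-call discrepancy gets amplified):

import torch

mps_device = torch.device("mps")

# Iterate the chaotic sine map x -> sin(pi * x) on cpu and mps in parallel.
# If the two sin implementations ever differ by even one ulp, the chaotic
# dynamics amplify that difference until the trajectories disagree entirely.
x_cpu = torch.full((1,), 0.1)
x_mps = x_cpu.to(mps_device)
for step in range(1, 61):
    x_cpu = torch.sin(torch.pi * x_cpu)
    x_mps = torch.sin(torch.pi * x_mps)
    if step % 20 == 0:
        diff = (x_cpu - x_mps.cpu()).abs().item()
        print(f"step {step:2d}: |cpu - mps| = {diff:.3e}")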

@tmm1

tmm1 commented Sep 16, 2022

Is a PR to flip these to NO the correct next step here?

[options setFastMathEnabled: YES];

(the same line appears in three places in the PyTorch sources)

@mallman
Author

mallman commented Sep 16, 2022

Is a PR to flip these to NO the correct next step here?

I've looked into this as well, and I don't think this will make a difference. In fact, IIRC I recompiled PyTorch with these set to NO, and there was no change in behavior.

As far as I can tell, PyTorch's MPS device is backed by Apple's MPSGraph framework. What we need is a way to "recompile" or select a version of MPSGraph that has fast math disabled, assuming my hypothesis that it's enabled for those shaders is correct.

Unfortunately, like many of Apple's frameworks, MPSGraph is closed source. Otherwise, I'd be digging in there already.

I know there are some Apple engineers that work on this codebase. I believe they will be able to diagnose and address this issue.

Hmmm... this makes me wonder if there's a way to inspect whether a shader has been compiled with fast math enabled without having access to its source code... hmmm...

@mallman
Author

mallman commented Oct 6, 2022

The attached file is an Xcode project which demonstrates that the value of exp(0.1) computed by the MPSGraph API is the same as the value computed by the "fast" variant of the standard Metal exp function. I believe this pretty much confirms that the MPSGraph kernels are compiled with Metal "fast math" enabled, breaking IEEE 754 conformance. This is going to require a fix from Apple. I've submitted a bug report through Apple's "Feedback Assistant" including the attached project. I will post back here with an update when I hear from Apple.

For any Apple engineers watching this issue, my feedback report number is FB11657311.

MPSImprecision.zip

@mallman
Author

mallman commented Oct 10, 2022

The attached file is an Xcode project which demonstrates that the value of exp(0.1) computed by the MPSGraph API is the same as the value computed by the "fast" variant of the standard Metal exp function. I believe this pretty much confirms that the MPSGraph kernels are compiled with Metal "fast math" enabled, breaking IEEE 754 conformance. This is going to require a fix from Apple. I've submitted a bug report through Apple's "Feedback Assistant" including the attached project. I will post back here with an update when I hear from Apple.

For any Apple engineers watching this issue, my feedback report number is FB11657311.

MPSImprecision.zip

I need to refine my assessment of the problem. The problem is not with IEEE 754 conformance. Rather, I'd suggest the problem is that the implementation of certain transcendental/elementary functions with Metal's "fast math" library produce different results for the same inputs compared to the standard C libm library. So I think the fix is to ensure identity in output between the MPSGraph implementations of libm functions and the CPU implementations of the libm functions themselves. And again, this is an issue for Apple.
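
For concreteness, this is the kind of cpu-vs-mps comparison I have in mind (a quick sketch, not a rigorous test):

import torch

mps_device = torch.device("mps")
x = torch.linspace(-5, 5, steps=10000)

# Compare a few elementwise transcendental ops between the CPU (libm)
# implementations and whatever the mps backend dispatches to.
for name in ("exp", "tanh", "sin", "log1p", "sqrt"):
    fn = getattr(torch, name)
    inp = x.abs() if name in ("log1p", "sqrt") else x  # stay in each function's domain
    cpu_out = fn(inp)
    mps_out = fn(inp.to(mps_device)).cpu()
    print(f"{name:5s} max |cpu - mps| = {(cpu_out - mps_out).abs().max().item():.3e}")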

@mallman
Author

mallman commented Jun 5, 2023

@kulinseth This issue was reported for macOS 13. Now that macOS Sonoma (14) beta has been released, can you comment on whether this issue has been (or will be) addressed in that version? I'm unlikely to install the Sonoma beta myself for the time being, but am hopeful we will have a solution in this release.

@kulinseth
Collaborator

@kulinseth This issue was reported for macOS 13. Now that macOS Sonoma (14) beta has been released, can you comment on whether this issue has been (or will be) addressed in that version? I'm unlikely to install the Sonoma beta myself for the time being, but am hopeful we will have a solution in this release.

@mallman , we have fixed the "fast-math" issue even on MacOS Ventura. Can you please try 13.4 Ventura and see if the issue still persists?

@mallman
Author

mallman commented Jun 8, 2023

@kulinseth This issue was reported for macOS 13. Now that macOS Sonoma (14) beta has been released, can you comment on whether this issue has been (or will be) addressed in that version? I'm unlikely to install the Sonoma beta myself for the time being, but am hopeful we will have a solution in this release.

@mallman , we have fixed the "fast-math" issue even on MacOS Ventura. Can you please try 13.4 Ventura and see if the issue still persists?

I tried the python script in this issue's description, and I tried the Xcode project I attached (MPSImprecision.zip). Both still show the same difference in calculation between cpu and mps. I'm using Ventura 13.4, Xcode 14.3.1 and Pytorch 2.0.0.

So, yes, the issue still persists. Are you seeing different behavior on your end?

FWIW, my feedback report (FB11657311) has been marked as "Potential fix identified - For a future OS update". That hasn't changed in a long time (not sure when).

@mallman
Author

mallman commented Jun 30, 2023

@kulinseth This issue was reported for macOS 13. Now that macOS Sonoma (14) beta has been released, can you comment on whether this issue has been (or will be) addressed in that version? I'm unlikely to install the Sonoma beta myself for the time being, but am hopeful we will have a solution in this release.

@mallman , we have fixed the "fast-math" issue even on MacOS Ventura. Can you please try 13.4 Ventura and see if the issue still persists?

@kulinseth As I indicated, this issue still exists in macOS 13.4. Can you validate that this issue is fixed (or will be fixed) in Sonoma? It is a significant impediment to using the mps device where precise definitions of transcendental functions are required.

@KouseiHongqing

I encountered the same issue, where there was a significant discrepancy between the final results when running my network on MPS compared to the nvGPU and CPU versions. Upon debugging, I found that there were errors in every step of the inference in the transformer layer, which resulted in very poor final results. This issue is very severe and directly affects the viability of using MPS.
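
For anyone who wants to reproduce this kind of debugging, here is a rough sketch of a layer-by-layer comparison (a toy stack of layers standing in for my actual network):

import torch
from torch import nn

torch.manual_seed(0)

# Identical toy models on cpu and mps.
model_cpu = nn.Sequential(
    nn.Linear(64, 64), nn.GELU(), nn.LayerNorm(64),
    nn.Linear(64, 64), nn.Tanh(),
).eval()
model_mps = nn.Sequential(
    nn.Linear(64, 64), nn.GELU(), nn.LayerNorm(64),
    nn.Linear(64, 64), nn.Tanh(),
).eval()
model_mps.load_state_dict(model_cpu.state_dict())
model_mps = model_mps.to("mps")

# Record every submodule's output on both devices with forward hooks,
# then print the max absolute difference per submodule.
cpu_outs, mps_outs = {}, {}

def make_hook(name, store):
    def hook(_module, _inputs, output):
        store[name] = output.detach().float().cpu()
    return hook

for (name, m_cpu), (_, m_mps) in zip(model_cpu.named_modules(), model_mps.named_modules()):
    if name:  # skip the top-level container itself
        m_cpu.register_forward_hook(make_hook(name, cpu_outs))
        m_mps.register_forward_hook(make_hook(name, mps_outs))

x = torch.randn(8, 64)
with torch.no_grad():
    model_cpu(x)
    model_mps(x.to("mps"))

for name, cpu_t in cpu_outs.items():
    print(f"layer {name}: max |cpu - mps| = {(cpu_t - mps_outs[name]).abs().max().item():.3e}")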

@mallman
Author

mallman commented Sep 16, 2023

I encountered the same issue, where there was a significant discrepancy between the final results when running my network on MPS compared to the nvGPU and CPU versions. Upon debugging, I found that there were errors in every step of the inference in the transformer layer, which resulted in very poor final results. This issue is very severe and directly affects the viability of using MPS.

I agree. There are two things I'm aware of which are holding back MPS from perfect (or close enough) training/prediction fidelity with CPU in Pytorch.

  1. Apple's MPSGraph framework gives all the evidence that it was built with Metal "fast math". The tolerances for the "fast math" pragma are significantly looser than without it. That's what this particular issue highlights.
  2. Metal does not support a 64-bit floating point data type.

I would say point 1 is more important, and easier to fix. Point 2 is relevant in some but not all models. In some cases, lower precision is perfectly acceptable and is desirable for savings in execution time and storage. It would be nice to see an implementation of bfloat16, as well. (Actually it looks like bfloat16 is coming to Sonoma, at least on supported hardware.) I'm very pessimistic about Apple supporting a 64-bit fp type in Metal anytime soon. I believe their current GPU hardware is limited to 32-bit fp precision, and the neural processor is 16-bit.
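
To illustrate point 2, here's a quick sketch (the exact error type and message may vary by PyTorch version):

import torch

mps_device = torch.device("mps")

# float32 (and float16) tensors work on the mps device...
x32 = torch.ones(4, device=mps_device)
print(x32.dtype)  # torch.float32

# ...but float64 does not, because Metal has no 64-bit float type.
try:
    torch.ones(4, dtype=torch.float64, device=mps_device)
except (TypeError, RuntimeError) as err:
    print("float64 on mps failed:", err)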

Incidentally, @KouseiHongqing, what floating point precision are you using to run/train your network on cpu and nvgpu? (I have very little experience with nVidia accelerators, but I believe they support fp64.) It would be interesting to see if you get acceptable results using 32-bit floating point on cpu and nvidia. At least then you can assume you won't need metal 64-bit floating point support. If you get inadequate results with fp32, then addressing the issue with MPS "fast math" likely won't help your use case.

As you can see from earlier in this thread, an Apple engineer claimed this issue had been fixed in Ventura. Even though my testing showed it hadn't been fixed, I hope this means they are at least taking this issue seriously.

Cheers.

@KouseiHongqing

I encountered the same issue, where there was a significant discrepancy between the final results when running my network on MPS compared to the nvGPU and CPU versions. Upon debugging, I found that there were errors in every step of the inference in the transformer layer, which resulted in very poor final results. This issue is very severe and directly affects the viability of using MPS.

I agree. There are two things I'm aware of which are holding back MPS from perfect (or close enough) training/prediction fidelity with CPU in Pytorch.

1. Apple's MPSGraph framework gives all the evidence that it was built with Metal "fast math". The tolerances for the "fast math" pragma are significantly looser than without it. That's what this particular issue highlights.

2. Metal does not support a 64-bit floating point data type.

I would say point 1 is more important, and easier to fix. Point 2 is relevant in some but not all models. In some cases, lower precision is perfectly acceptable and is desirable for savings in execution time and storage. It would be nice to see an implementation of bfloat16, as well. (Actually it looks like bfloat16 is coming to Sonoma, at least on supported hardware.) I'm very pessimistic about Apple supporting a 64-bit fp type in Metal anytime soon. I believe their current GPU hardware is limited to 32-bit fp precision, and the neural processor is 16-bit.

Incidentally, @KouseiHongqing, what floating point precision are you using to run/train your network on cpu and nvgpu? (I have very little experience with nVidia accelerators, but I believe they support fp64.) It would be interesting to see if you get acceptable results using 32-bit floating point on cpu and nvidia. At least then you can assume you won't need metal 64-bit floating point support. If you get inadequate results with fp32, then addressing the issue with MPS "fast math" likely won't help your use case.

As you can see from earlier in this thread, an Apple engineer claimed this issue had been fixed in Ventura. Even though my testing showed it hadn't been fixed, I hope this means they are at least taking this issue seriously.

Cheers.

@mallman I am using PyTorch's mixed-precision computing, which includes FP16 and FP32. However, they both have varying degrees of error, and I don't think it's due to fast math. As an alternative, I am running inference using the CPU on my Mac Studio, which is three times slower than running it with MPS.
I am saddened to see that this issue has not been resolved after one year. I hope the officials will fix this problem as soon as possible.
Cheers

@KouseiHongqing

In the latest version of PyTorch (nightly 0918), my precision issue has been resolved. Many thanks for the support.

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu

@atabakd
Contributor

atabakd commented Oct 8, 2023

I am following this tutorial on diffusion models, and even just adding noise to the images gives completely different results on cpu vs mps, on Sonoma 14.0 (23A344), M1 Max.

[images: the CPU and MPS noised outputs look completely different]
pytorch                   2.2.0.dev20231008         py3.9_0    pytorch-nightly
torchaudio                2.2.0.dev20231008        py39_cpu    pytorch-nightly
torchvision               0.17.0.dev20231008        py39_cpu    pytorch-nightly

@kulinseth
Collaborator

I am following this tutorial on diffusion models, and even just adding noise to the images gives completely different results on cpu vs mps, on Sonoma 14.0 (23A344), M1 Max.

pytorch 2.2.0.dev20231008 py3.9_0 pytorch-nightly
torchaudio 2.2.0.dev20231008 py39_cpu pytorch-nightly
torchvision 0.17.0.dev20231008 py39_cpu pytorch-nightly

@atabakd , can you please file a separate issue?

@mallman
Author

mallman commented Nov 8, 2023

I'm genuinely mystified why this issue hasn't been resolved. Is the mps codebase plagued by such numerous or complex bugs that this one can't be addressed in a year's time? Is this not so easily fixed?

I'm not suggesting someone else fix a problem I could fix myself. I actually put in my due diligence before even filing this issue, looking for a fix in Pytorch. And I've always been happy to contribute fixes to open source software, but this mps integration is with closed source software from an opaque partner organization. As far as anyone outside Apple knows, we are as close to or as far from a resolution to this bug as we were a year ago.

Am I being unfair in expressing my frustration?

@aradley

aradley commented Mar 19, 2024

I think I have a related and very simple example of how MPS can get completely different results from CPU. Hopefully the simplicity of this issue will make it clear and helpful.

import numpy as np
import torch
mps_device = torch.device("mps")

## Create a numpy matrix with many zeros
np.random.seed(0)
Numpy_Test = np.random.random(200000000)
indices = np.random.choice(np.arange(Numpy_Test.size), replace=False,size=int(Numpy_Test.size * 0.6))
Numpy_Test[indices] = 0
Numpy_Matrix = Numpy_Test.reshape((20000,10000))

## Get the indices of non-zero values in the matrix, and convert these indices into a numpy array
indices = np.where(Numpy_Matrix != 0)
indices = np.asarray(indices)

## Use numpy, torch, or a torch.mps object to find where indices[1] == 8000
# Using np.where
np.where(indices[1] == 8000)[0]
array([   19165,    27061,    39165, ..., 79979029, 79987021, 79995171])

# Using torch.where
torch.where(torch.from_numpy(indices)[1] == 8000)[0]
tensor([   19165,    27061,    39165,  ..., 79979029, 79987021, 79995171])

# Using torch.where with an MPS object
torch.where(torch.from_numpy(indices)[1].to(mps_device) == 8000)[0]
tensor([   19165,    27061,    39165,  ..., 79979032, 79987024, 79995168], device='mps:0')

Notice how the first two np.where and torch.where examples give the same results, but when using the tensor converted to MPS we get different results?

If I've not made an obvious mistake, this is a clear example of how MPS completely ruins calculations, because in this case, the indexes change, and all downstream calculations become meaningless.

@MaratZakirov

MaratZakirov commented May 6, 2024

I had an issue not with numerical stability (as I initially thought) but with MPS memory. When I didn't immediately convert data to CPU and instead collected it into a list, it tended to exhaust MPS memory. So if you work with the MPS device you should do the conversion as soon as possible. Do your calculations on MPS (which is really great and saves a ton of time) and then go back to CPU mode.
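
A minimal sketch of that pattern, with a toy model standing in for the real one:

import torch
from torch import nn

mps_device = torch.device("mps")
model = nn.Linear(128, 10).to(mps_device)

results = []
with torch.no_grad():
    for _ in range(100):
        batch = torch.randn(64, 128, device=mps_device)
        out = model(batch)
        # Move each result to the CPU immediately instead of accumulating
        # MPS tensors in the Python list, which keeps GPU memory alive.
        results.append(out.cpu())

print(torch.cat(results).shape)  # torch.Size([6400, 10])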

@ezyang ezyang changed the title CPU and MPS floating point math is different (in a significant way) torch.where is incorrect on MPS when there are 2**24 elements in tensor May 13, 2024
@malfet
Contributor

malfet commented May 15, 2024

@atabakd , can you please file a separate issue?

This one was fixed a while back, trying to find a PR number, but it had something to do with signed vs unsigned dtypes

@albanD
Collaborator

albanD commented May 15, 2024

I think there are quite a few different issues being discussed here. To help with tracking, I think that:

  • The original report at the very top about ~1e-7 difference between cpu and mps is expected. This is true for all devices and even happens for different threading settings on cpu, etc. See https://pytorch.org/docs/main/notes/numerical_accuracy.html and other related notes there for details.
  • The two other issues about rescale and torch.where were respectively addressed and moved to a separate specialized issue.

So I think we can close this issue?

@mallman
Author

mallman commented May 23, 2024

I think there are quite a few different issues being discussed here. To help with tracking, I think that:

This issue isn't related to the document you reference. This is an issue about MPS using imprecise implementations of transcendental functions (so-called "fast math" from the Metal standard library).

The title of this issue is misleading. I will fix that.

  • The two other issues about rescale and torch.where were respectively addressed and moved to a separate specialized issue.

So I think we can close this issue?

IMO, we can close this issue if one of two things happens:

  1. The MPS library which Pytorch uses is recompiled without the "fast math" flag, or,
  2. We acknowledge/document that the Pytorch mps device can produce relatively inaccurate output because it uses relatively imprecise implementations of transcendental functions, such as exp. Models which use the metal functions from the "fast math" implementation of the standard library should not be expected to produce acceptable results.

Of course, I have a preference for (1), but the implementation is opaque (closed source), so there's really nothing anyone can do about this unless they have access to the source code at Apple.

Alternatively, we can decide/discover a different underlying reason for this problem. That would require someone with access to the source code for MPS, i.e. an Apple employee.

@mallman mallman changed the title torch.where is incorrect on MPS when there are 2**24 elements in tensor MPS device can train or evaluate models with producing unacceptable output due to "fast math" optimization May 23, 2024
@mallman mallman changed the title MPS device can train or evaluate models with producing unacceptable output due to "fast math" optimization MPS device can train or evaluate models producing unacceptable output due to "fast math" optimization May 23, 2024
@mallman
Author

mallman commented May 23, 2024

I think there are quite a few different issues being discussed here. To help with tracking, I think that:

  • The original report at the very top about ~1e-7 difference between cpu and mps is expected. This is true for all devices and even happens for different threading settings on cpu, etc. See https://pytorch.org/docs/main/notes/numerical_accuracy.html and other related notes there for details.
  • The two other issues about rescale and torch.where were respectively addressed and moved to a separate specialized issue.

So I think we can close this issue?

I think another possible source of confusion around this specific Pytorch issue is that many other Pytorch users added cases of bugs they encountered, some of which were apparently not caused by or related to the "fast math" setting for MPS.

For example, this issue is not about "torch.where is incorrect on MPS when there are 2**24 elements in tensor".

@malfet
Contributor

malfet commented May 28, 2024

The MPS library which Pytorch uses is recompiled without the "fast math" flag,

An alternative approach would be to provide explicit Metal implementations for the operators, compiled without fast math. But I'm curious to see a bit more justification and the tradeoffs for doing so. Or, it could be an opt-in for the MPS backend (i.e. torch.backends.mps.use_fast_math = False would dispatch to that kernel; this way we don't need to justify a perf drop).

@mallman
Author

mallman commented May 29, 2024

The MPS library which Pytorch uses is recompiled without the "fast math" flag,

An alternative approach would be to provide explicit Metal implementations for the operators, compiled without fast math. But I'm curious to see a bit more justification and the tradeoffs for doing so. Or, it could be an opt-in for the MPS backend (i.e. torch.backends.mps.use_fast_math = False would dispatch to that kernel; this way we don't need to justify a perf drop).

Justification:

#84936 (comment)

The task of this algorithm is image enhancement. The second picture clearly has a severe purplish cast that makes for a qualitatively inferior, if not bizarre, result.

What good is "fast" if it "quickly" gets you a bad result? When I use a program for scientific computing (or any computation, really), I want to have confidence that I will get good results. Really, I think the question should be: what is the justification for "fast" but "bad" results?

@albanD
Collaborator

albanD commented May 30, 2024

What good is "fast" if it "quickly" gets you a bad result? When I use a program for scientific computing (or any computation, really), I want to have confidence that I will get good results. Really, I think the question should be: what is the justification for "fast" but "bad" results?

Within PyTorch, we don't have any justification for it and we always want all operations to be as precise as specified by the format standard. We do disable by default any low precision compute even at the cost of performance (see tf32 discussion for the latest major event there).
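
(For reference, this is the kind of default I mean; a quick sketch, and the exact defaults have shifted across releases:)

import torch

# In recent releases, TF32 tensor-core matmuls on CUDA are an explicit opt-in;
# full float32 precision is the default.
print(torch.backends.cuda.matmul.allow_tf32)   # False unless the user opts in
torch.set_float32_matmul_precision("high")     # opt in to faster, less precise matmuls
torch.set_float32_matmul_precision("highest")  # back to the precise default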

The problem here, as I understand it, is that the MPS library, which is not under our control and has a very long release cycle, is not providing us with the right tools. So our short term options are to rewrite a new kernel from scratch, have that function error out, or have an imprecise kernel.

@kulinseth can you confirm what is the long term plan here and if it is indeed to ensure that all kernels in MPS will be compiled without fast-math (or at least will have a version without fast math that can be used here)?

@malfet
Contributor

malfet commented May 31, 2024

I have no idea how to estimate errors by looking at images, but I can write code. For example, running the following shows that CPU, MPS, Metal fast, and Metal precise all yield slightly different results:

import Metal
import MetalPerformanceShadersGraph


func calculateExpMetal(device: MTLDevice, ibuf: MTLBuffer, obuf: MTLBuffer, nelem: Int, fastMathEnabled: Bool = false) {
  let shader_source = """
  #include <metal_stdlib>
  using namespace metal;

  kernel void do_exp(constant float *input [[buffer(0)]],
                     device float *output [[buffer(1)]],
                     uint thread_index [[thread_position_in_grid]]) {
    output[thread_index] = exp(input[thread_index]);
  }
  """
  let options = MTLCompileOptions()
  options.languageVersion = .version3_1
  options.fastMathEnabled = fastMathEnabled
  let library = try! device.makeLibrary(source:shader_source, options:options)
  guard let mfunc = library.makeFunction(name: "do_exp") else { fatalError("Can't find function") }
  guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
  guard let cmdBuffer = queue.makeCommandBuffer() else { fatalError("Can't make command buffer") }
  guard let computeEncoder = cmdBuffer.makeComputeCommandEncoder() else { fatalError("Can't make compute encoder") }
  computeEncoder.setComputePipelineState(try! device.makeComputePipelineState(function: mfunc))
  computeEncoder.setBuffer(ibuf, offset:0, index: 0)
  computeEncoder.setBuffer(obuf, offset:0, index: 1)
  computeEncoder.dispatchThreads(MTLSizeMake(nelem, 1, 1), threadsPerThreadgroup:MTLSizeMake(nelem, 1, 1))
  computeEncoder.endEncoding()
  cmdBuffer.commit()
  cmdBuffer.waitUntilCompleted()
}

func calculateExpMPS(device: MTLDevice, ibuf: MTLBuffer, obuf: MTLBuffer, nelem: Int) {
  let graph = MPSGraph()
  let inputPlaceholder = graph.placeholder(shape: [nelem as NSNumber], dataType: .float32, name: nil)
  let expNode = graph.exponent(with: inputPlaceholder, name: nil)
  let mpsInputBuffer = MPSGraphTensorData(ibuf, shape: [nelem as NSNumber], dataType: .float32)
  let mpsOutputBuffer = MPSGraphTensorData(obuf, shape: [nelem as NSNumber], dataType: .float32)
  guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
  graph.run(with: queue, feeds: [inputPlaceholder: mpsInputBuffer], targetOperations: nil, resultsDictionary: [expNode: mpsOutputBuffer])
}

guard let device = MTLCopyAllDevices().first else { fatalError("No Metal device found") }

let nelem = 256
guard let ibuf = device.makeBuffer(length:nelem * MemoryLayout<Float>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }
let ibuf_data = ibuf.contents().assumingMemoryBound(to: Float.self)
for i in 0..<nelem {
    ibuf_data[i] = log(Float(i)*0.1 + 0.1)
}

guard let obuf_fast = device.makeBuffer(length:nelem * MemoryLayout<Float>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }
guard let obuf_prec = device.makeBuffer(length:nelem * MemoryLayout<Float>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }
guard let obuf_mps = device.makeBuffer(length:nelem * MemoryLayout<Float>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }
calculateExpMPS(device: device, ibuf: ibuf, obuf: obuf_mps, nelem: nelem)
calculateExpMetal(device: device, ibuf: ibuf, obuf: obuf_fast, nelem: nelem, fastMathEnabled: true)
calculateExpMetal(device: device, ibuf: ibuf, obuf: obuf_prec, nelem: nelem, fastMathEnabled: false)

let obuf_fast_data = obuf_fast.contents().assumingMemoryBound(to: Float.self)
let obuf_prec_data = obuf_prec.contents().assumingMemoryBound(to: Float.self)
let obuf_mps_data = obuf_mps.contents().assumingMemoryBound(to: Float.self)
print("i, prec_cpu, fast_cpu, mtl_fast_prec, mps_prec, mps_fast")
for i in 0..<nelem {
    let cpu_exp = exp(ibuf_data[i])
    let fast_prec_diff = abs(obuf_fast_data[i] - obuf_prec_data[i])
    let mps_prec_diff = abs(obuf_mps_data[i] - obuf_prec_data[i])
    let mps_fast_diff = abs(obuf_mps_data[i] - obuf_fast_data[i])
    let prec_cpu_diff = abs(obuf_prec_data[i] - cpu_exp)
    let fast_cpu_diff = abs(obuf_fast_data[i] - cpu_exp)
    print("\(i), \(prec_cpu_diff), \(fast_cpu_diff), \(fast_prec_diff), \(mps_prec_diff), \(mps_fast_diff)")
}

And MPS results are closer to precise (on macOS 14.5) than to fast, as one can see from the attached chart.

Next steps: implement a canonical function that computes exp via a long Taylor series.
And if anyone knows whether there is an IEEE standard for exp, please share it here.
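
As a starting point, here's a rough Python sketch of such a reference (naive Taylor series, no argument reduction, evaluated in double precision via Python floats):

import math

def exp_taylor(x: float, terms: int = 40) -> float:
    # Reference exp via the Taylor series sum_{k>=0} x^k / k!.
    # Fine for modest |x|; a real reference would first reduce the
    # argument (x = n*ln2 + r) and only expand the small remainder r.
    total, term = 1.0, 1.0
    for k in range(1, terms):
        term *= x / k
        total += term
    return total

for x in (0.1, 1.0, -2.5):
    print(x, exp_taylor(x), math.exp(x), abs(exp_taylor(x) - math.exp(x)))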

@kulinseth
Collaborator

kulinseth commented Jun 10, 2024

IMO, we can close this issue if one of two things happens:

  1. The MPS library which Pytorch uses is recompiled without the "fast math" flag, or,

Of course, I have a preference for (1), but the implementation is opaque (closed source), so there's really nothing anyone can do about this unless they have access to the source code at Apple.

Alternatively, we can decide/discover a different underlying reason for this problem. That would require someone with access to the source code for MPS, i.e. an Apple employee.

Hi @mallman, MPS is a framework in the OS which has been updated to use metal::precise for most of the ops which need higher precision. PyTorch is using the same framework, and all the operations which use MPS kernels should see this increased precision since Ventura 13.4 or the Sonoma OS versions. Following are the operations which I can confirm have been updated:

| Ops | Notes (high precision == metal::precise) |
| --- | --- |
| FourierTransform | cospi / sinpi are in high precision |
| Random | cospi / sinpi are in high precision |
| ATan2 | high precision |
| Pow | high precision |
| Exp | Custom precision (higher than metal::fast but lower than metal::precise) |
| Exp2 | high precision |
| Exp10 | Custom precision (higher than metal::fast but lower than metal::precise) |
| logarithm | high precision |
| logarithmBase2 | high precision |
| logarithmBase10 | high precision |
| squareRoot | high precision |
| reverseSquareRoot | high precision |
| ceilOp | high precision |
| floorOp | high precision |
| roundOp | high precision |
| sinOp | high precision |
| cosOp | high precision |
| tanOp | high precision |
| sinhOp | high precision |
| coshOp | high precision |
| asinOp | high precision |
| acosOp | high precision |
| atanOp | high precision |
| asinhOp | high precision |
| acoshOp | high precision |
| atanhOp | high precision |

For perf reasons, these were kept in fast precision mode:

| Ops | Notes (fast precision == metal::fast) |
| --- | --- |
| square | fast precision |
| reciprocal | fast precision |
| signOp | fast precision |
| signbitOp | fast precision |
| tanhOp | fast precision |
| erf | fast precision |
| Complex type operations | fast precision |

As @malfet said, if there are use-cases we can add Metal kernels in precise mode. The image is a good example, but we would need a use-case which shows a training convergence error or a problem with inference that is caused by a numerical issue. I am happy to pre-emptively move exp and tanh to precise mode. Let me propose a PR for it.

malfet added a commit that referenced this issue Jun 11, 2024
To improve accuracy, use `precise::exp()` (and `precise::sin()`/`precise::cos()` for complex flavor)

Fix bug in non-contiguous tensors handling

Fixes #84936 

[ghstack-poisoned]
@mallman
Author

mallman commented Jun 13, 2024

@kulinseth Thank you for addressing this issue. I will report back if the underlying problem is not fixed.

Cheers.

TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this issue Jun 14, 2024
To improve accuracy, use `precise::exp()` (and `precise::sin()`/`precise::cos()` for complex flavor)
Reuse `test_exp1` to check that accuracy of `exp` ops is sometimes closer to CPU

Fix bug in non-contiguous tensors handling

Fixes pytorch#84936
Pull Request resolved: pytorch#128421
Approved by: https://github.com/kulinseth
ghstack dependencies: pytorch#128373, pytorch#128375
ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this issue Jun 14, 2024
To improve accuracy, use `precise::exp()` (and `precise::sin()`/`precise::cos()` for complex flavor)
Reuse `test_exp1` to check that accuracy of `exp` ops is sometimes closer to CPU

Fix bug in non-contiguous tensors handling

Fixes pytorch#84936
Pull Request resolved: pytorch#128421
Approved by: https://github.com/kulinseth
ghstack dependencies: pytorch#128373, pytorch#128375