MPS: add environment variable, PYTORCH_MPS_GEMM_PREFER_FAST_MATH, allowing combination of MPS Fast mode and Metal SDPA kernels for optimal performance balance #167424
When testing the opt-in Metal kernels and MPS Fast Math mode, I found that GEMM performs better with the MPS implementation, while SDPA performs best with the Metal implementation.
I've introduced an environment variable, PYTORCH_MPS_GEMM_PREFER_FAST_MATH, that can be enabled alongside MPS Fast Math and the Metal kernels to tell torch that GEMM should prefer MPS.

A small benchmark:
Fast-math GEMM run (env PYTORCH_MPS_FAST_MATH=1 python3 scripts/mps_microbench.py --benchmarks gemm):
• 512³: 0.00093 s (0.29 TFLOP/s)
• 1024³: 0.00118 s (1.82 TFLOP/s)
• 2048³: 0.00317 s (5.42 TFLOP/s)
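For reference, the TFLOP/s figures above match the conventional 2·N³ FLOP count for a square N×N×N GEMM. The sketch below assumes that convention; the formula `scripts/mps_microbench.py` actually uses is not shown here, and `gemm_tflops` is a hypothetical helper name.

```python
def gemm_tflops(n: int, seconds: float) -> float:
    """Throughput of an n x n x n matrix multiply in TFLOP/s.

    Assumes the usual count of 2 * n**3 floating-point operations
    (one multiply and one add per inner-product term).
    """
    flops = 2 * n ** 3
    return flops / seconds / 1e12


# Timings copied from the fast-math GEMM run above.
for n, t in [(512, 0.00093), (1024, 0.00118), (2048, 0.00317)]:
    print(f"{n}^3: {gemm_tflops(n, t):.2f} TFLOP/s")
```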
Metal-only SDPA run (env PYTORCH_MPS_PREFER_METAL=1 python3 scripts/mps_microbench.py --benchmarks sdpa):
• B1/H8/L512/D64: 0.00092 s (0.28 GTokens/s)
• B2/H16/L1024/D64: 0.00380 s (0.55 GTokens/s)
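The GTokens/s figures appear consistent with dividing B·H·L·D by the elapsed time. This is inferred from the reported numbers only, not from the benchmark script, so treat the definition below as an assumption; `sdpa_gtokens` is a hypothetical helper name.

```python
def sdpa_gtokens(batch: int, heads: int, seq_len: int, head_dim: int,
                 seconds: float) -> float:
    """SDPA throughput in units of 1e9 per second.

    Assumption: the reported metric counts batch * heads * seq_len * head_dim
    elements per call, which reproduces the figures quoted above.
    """
    return batch * heads * seq_len * head_dim / seconds / 1e9


# Timings copied from the Metal-only SDPA run above.
print(f"{sdpa_gtokens(1, 8, 512, 64, 0.00092):.2f}")
print(f"{sdpa_gtokens(2, 16, 1024, 64, 0.00380):.2f}")
```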
Measured averages (float16, 8 iterations GEMM / 6 SDPA):
• Vanilla: GEMM 2048³ ≈ 0.00254 s (6.76 TFLOP/s); SDPA B2/H16/L1024/D64 ≈ 0.00317 s (0.66 GTokens/s).
• Fast Math: GEMM 2048³ ≈ 0.00240 s (7.16 TFLOP/s); SDPA ≈ 0.00180 s (1.16 GTokens/s).
• Prefer Metal: GEMM 2048³ ≈ 0.00973 s (1.77 TFLOP/s); SDPA ≈ 0.00172 s (1.22 GTokens/s).
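A minimal sketch of how the three variables might be combined when launching a benchmark run. The value `"1"` for PYTORCH_MPS_GEMM_PREFER_FAST_MATH is an assumption by analogy with the other two variables, and `combined_env` is a hypothetical helper, not part of this PR.

```python
import os


def combined_env(base=None):
    """Return a copy of the environment with all three MPS switches enabled.

    Assumption: each variable is enabled by setting it to "1".
    """
    env = dict(os.environ if base is None else base)
    env.update({
        "PYTORCH_MPS_FAST_MATH": "1",              # MPS Fast Math mode
        "PYTORCH_MPS_PREFER_METAL": "1",           # opt-in Metal kernels (SDPA)
        "PYTORCH_MPS_GEMM_PREFER_FAST_MATH": "1",  # keep GEMM on the MPS path
    })
    return env


# Example launch, passing the environment to a child process:
# import subprocess
# subprocess.run(
#     ["python3", "scripts/mps_microbench.py", "--benchmarks", "gemm"],
#     env=combined_env(),
#     check=True,
# )
```

Setting the variables in a child process environment, rather than inside an already running interpreter, avoids depending on when torch reads them.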