
Conversation

bghira commented Nov 9, 2025

When testing the opt-in Metal kernels and MPS fast-math mode, I found that GEMM performance is superior under the MPS implementation, while SDPA performs best with the Metal implementation.

I've introduced an environment variable, PYTORCH_MPS_GEMM_PREFER_FAST_MATH, that can be enabled alongside MPS fast math and the Metal kernels to tell torch that GEMM should prefer the MPS path.
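
For reference, a minimal way to opt in from Python (assuming, as with the existing MPS flags, that these are read from the environment when the backend initializes) would look like:

import os

# Opt-in flags; set them before importing torch so the MPS backend sees them.
os.environ["PYTORCH_MPS_FAST_MATH"] = "1"                # MPS fast-math mode
os.environ["PYTORCH_MPS_PREFER_METAL"] = "1"             # Metal kernels (best SDPA here)
os.environ["PYTORCH_MPS_GEMM_PREFER_FAST_MATH"] = "1"    # this PR: keep GEMM on the MPS path

import torch  # noqa: E402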


A small benchmark:

#!/usr/bin/env python
"""Microbenchmark GEMM and SDPA performance on MPS/Metal backends."""

from __future__ import annotations

import argparse
import os
import statistics
import time
from dataclasses import dataclass
from typing import Callable, Iterable, List, Tuple

import torch
import torch.nn.functional as F


def _synchronize() -> None:
    if hasattr(torch, "mps") and torch.backends.mps.is_available():  # type: ignore[attr-defined]
        torch.mps.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()


def _benchmark(op: Callable[[], torch.Tensor], *, warmup: int, iterations: int) -> float:
    # Warm up first (kernel compilation/caching), then time each iteration with a
    # device synchronize so asynchronous MPS work is included in the measurement.
    for _ in range(warmup):
        _ = op()
    _synchronize()

    samples: List[float] = []
    for _ in range(iterations):
        start = time.perf_counter()
        _ = op()
        _synchronize()
        samples.append(time.perf_counter() - start)
    return statistics.mean(samples)


@dataclass(frozen=True)
class GemmShape:
    m: int
    n: int
    k: int


@dataclass(frozen=True)
class SdpaConfig:
    batch: int
    heads: int
    seq_len: int
    head_dim: int


def run_gemm_benchmarks(device: torch.device, dtype: torch.dtype, shapes: Iterable[GemmShape]) -> List[Tuple[GemmShape, float]]:
    results: List[Tuple[GemmShape, float]] = []
    for shape in shapes:
        a = torch.randn(shape.m, shape.k, device=device, dtype=dtype)
        b = torch.randn(shape.k, shape.n, device=device, dtype=dtype)

        def op() -> torch.Tensor:
            return a @ b

        avg_time = _benchmark(op, warmup=3, iterations=8)
        results.append((shape, avg_time))
    return results


def run_sdpa_benchmarks(
    device: torch.device,
    dtype: torch.dtype,
    configs: Iterable[SdpaConfig],
) -> List[Tuple[SdpaConfig, float]]:
    results: List[Tuple[SdpaConfig, float]] = []
    for config in configs:
        q = torch.randn(config.batch, config.heads, config.seq_len, config.head_dim, device=device, dtype=dtype)
        k = torch.randn_like(q)
        v = torch.randn_like(q)

        def op() -> torch.Tensor:
            return F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

        avg_time = _benchmark(op, warmup=2, iterations=6)
        results.append((config, avg_time))
    return results


def format_env_summary() -> str:
    tracked = (
        "PYTORCH_MPS_FAST_MATH",
        "PYTORCH_MPS_PREFER_METAL",
        "PYTORCH_MPS_GEMM_PREFER_FAST_MATH",  # introduced in this PR
        "PYTORCH_ENABLE_MPS_FALLBACK",
    )
    parts = []
    for key in tracked:
        value = os.environ.get(key, "")
        parts.append(f"{key}={value or 'unset'}")
    return ", ".join(parts)


def main() -> None:
    parser = argparse.ArgumentParser(description="Benchmark GEMM and SDPA performance on MPS/Metal.")
    parser.add_argument("--dtype", default="float16", choices=["float16", "bfloat16", "float32"], help="Tensor dtype to benchmark.")
    parser.add_argument(
        "--benchmarks",
        nargs="+",
        choices=["gemm", "sdpa"],
        default=["gemm", "sdpa"],
        help="Subset of benchmarks to run.",
    )
    args = parser.parse_args()

    if not torch.backends.mps.is_available():  # type: ignore[attr-defined]
        raise SystemExit("MPS backend is unavailable on this system.")

    dtype = getattr(torch, args.dtype)
    device = torch.device("mps")
    torch.manual_seed(0)

    gemm_shapes = [
        GemmShape(512, 512, 512),
        GemmShape(1024, 1024, 1024),
        GemmShape(2048, 2048, 2048),
    ]
    sdpa_configs = [
        SdpaConfig(1, 8, 512, 64),
        SdpaConfig(2, 16, 1024, 64),
    ]

    print(f"[mps_microbench] Device: {device}, dtype={dtype}")
    print(f"[mps_microbench] Env: {format_env_summary()}")

    if "gemm" in args.benchmarks:
        gemm_results = run_gemm_benchmarks(device, dtype, gemm_shapes)
        print("\nGEMM timings (average seconds):")
        for shape, seconds in gemm_results:
            tflops = (2 * shape.m * shape.n * shape.k) / (seconds * 1e12)
            print(f"  {shape.m}x{shape.k} @ {shape.k}x{shape.n}: {seconds:.5f}s ({tflops:.2f} TFLOP/s)")

    if "sdpa" in args.benchmarks:
        sdpa_results = run_sdpa_benchmarks(device, dtype, sdpa_configs)
        print("\nSDPA timings (average seconds):")
        for config, seconds in sdpa_results:
            tokens = config.batch * config.heads * config.seq_len * config.head_dim
            print(f"  B{config.batch} H{config.heads} L{config.seq_len} D{config.head_dim}: {seconds:.5f}s ({tokens/seconds/1e9:.2f} GTokens/s)")


if __name__ == "__main__":
    main()

Fast-math GEMM run (env PYTORCH_MPS_FAST_MATH=1 python3 scripts/mps_microbench.py --benchmarks gemm):
• 512³: 0.00093 s (0.29 TFLOP/s)
• 1024³: 0.00118 s (1.82 TFLOP/s)
• 2048³: 0.00317 s (5.42 TFLOP/s)

Metal-only SDPA run (env PYTORCH_MPS_PREFER_METAL=1 python3 scripts/mps_microbench.py --benchmarks sdpa):
• B1/H8/L512/D64: 0.00092 s (0.28 GTokens/s)
• B2/H16/L1024/D64: 0.00380 s (0.55 GTokens/s)

Measured averages (float16, 8 iterations GEMM / 6 SDPA):
• Vanilla: GEMM 2048³ ≈ 0.00254 s (6.76 TFLOP/s); SDPA B2/H16/L1024/D64 ≈ 0.00317 s (0.66 GTokens/s).
• Fast Math: GEMM 2048³ ≈ 0.00240 s (7.16 TFLOP/s); SDPA ≈ 0.00180 s (1.16 GTokens/s).
• Prefer Metal: GEMM 2048³ ≈ 0.00973 s (1.77 TFLOP/s); SDPA ≈ 0.00172 s (1.22 GTokens/s).
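
As a sanity check on the throughput figures, the script's TFLOP/s number is simply 2·m·n·k divided by the measured time; for the fast-math 2048³ case, 2 * 2048^3 ≈ 1.72e10 FLOPs, and 1.72e10 / 0.00240 s ≈ 7.16e12 FLOP/s ≈ 7.16 TFLOP/s, matching the figure above.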

…owing combination of MPS Fast mode and Metal SDPA kernels for optimal performance balance
pytorch-bot bot commented Nov 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167424

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6c820d9 with merge base b91a2ab:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the release notes: mps label (Release notes category) Nov 9, 2025
linux-foundation-easycla bot: CLA Missing ID / CLA Not Signed

bghira (Author) commented Nov 9, 2025

This provides an overall 3-5x speedup on the 2048x2048 shape when all three variables are enabled with this PR built, since the MPS path is used for GEMM and the Metal kernels are used for SDPA.
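
For reference, the combined run referred to here would be something like (assuming the same microbench script and flag names as above): env PYTORCH_MPS_FAST_MATH=1 PYTORCH_MPS_PREFER_METAL=1 PYTORCH_MPS_GEMM_PREFER_FAST_MATH=1 python3 scripts/mps_microbench.py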

bdhirsh added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Nov 11, 2025