
Conversation

@agolajko (Contributor) commented Nov 14, 2025

Summary

As discussed with @vkuzo in #3290

Replaced torch._scaled_mm with torch.nn.functional.scaled_mm and ran the two benchmark scripts (bench_1x128_128x1_gemms.py and bench_1x128_128x128_gemms.py) from here.
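
For context, a minimal sketch of the call-site change, using tensorwise scales for brevity (the benchmarks themselves use 1x128 / 128x128 blockwise scales); the keyword names accepted by torch.nn.functional.scaled_mm are assumed here to mirror the private op and may differ in the nightly used below.

# Hedged sketch of the API swap, not the benchmark code itself.
import torch

M, K, N = 16640, 8192, 5120
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn).t()  # fp8 gemm expects a column-major second operand
scale_a = torch.tensor(1.0, device="cuda")  # placeholder float32 scales
scale_b = torch.tensor(1.0, device="cuda")

# before: private op
out_old = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)

# after: public functional API (argument names assumed to match the private op)
out_new = torch.nn.functional.scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)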

Results on an H100 with the following setup:

Torchao: 0.15.0+git1fbc5f6a5
Python: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0]
PyTorch: 2.10.0.dev20251113+cu129
CUDA: 12.9
CuDNN: 91002

[OS]
OS: Linux 6.8.0-60-generic
Distribution: Ubuntu 24.04.3 LTS
GPU: NVIDIA H100 PCIe, driver 570.133.20, compute capability 9.0

# python bench_1x128_128x1_gemms.py
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:02<00:00, 31.40s/it]
    M     N     K  out_dtype         bf16_mm_us    fp8_triton_us    fp8_scaled_mm_us    bf16 tflops/sec    triton tflops/sec    scaled_mm tflops/sec
-----  ----  ----  --------------  ------------  ---------------  ------------------  -----------------  -------------------  ----------------------
16640  5120  8192  torch.bfloat16       3223.73          4511.23             2405.09            432.997              309.42                  580.38
16640  8192  5120  torch.bfloat16       3243.3           4708.93             2404.93            430.385              296.429                 580.418

# python bench_1x128_128x128_gemms.py
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:16<00:00,  8.44s/it]
    M     N     K  out_dtype         bf16_mm_us    fp8_triton_us    fp8_scaled_mm_us    bf16 tflops/sec    triton tflops/sec    scaled_mm tflops/sec
-----  ----  ----  --------------  ------------  ---------------  ------------------  -----------------  -------------------  ----------------------
16640  5120  8192  torch.bfloat16       3351.36          4665.82             2170.48            416.507              299.168                 643.113
16640  8192  5120  torch.bfloat16       3466.82          4681.14             2286.5             402.636              298.189                 610.482

pytorch-bot bot commented Nov 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3342

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Nov 14, 2025
@agolajko mentioned this pull request Nov 14, 2025
@agolajko (author) commented:
@vkuzo lmk if you have comments or suggestions for the above changes

@agolajko (author) commented:
Re-linted.

@danielvegamyhre self-requested a review Nov 23, 2025
@danielvegamyhre (Contributor) commented:
thanks for updating this, running CI now

@agolajko force-pushed the feat/bench_1x128_scaled_mm_fix branch from d4052c4 to 0cd581e Nov 24, 2025
@danielvegamyhre added the topic: not user facing label Nov 24, 2025
@danielvegamyhre merged commit 4c16ab8 into pytorch:main Nov 24, 2025 (16 of 20 checks passed)
namgyu-youn pushed a commit to namgyu-youn/ao that referenced this pull request Dec 19, 2025
….nn.functional.scaled_mm (pytorch#3342)

* replaced torch._scaled_mm with torch.nn.functional.scaled_mm

* lint
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
