[rocm] scaled_grouped_mm support gfx942 fp8 data type #3802
xiaobochen-amd wants to merge 11 commits into pytorch:main
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3802
❌ 1 New Failure as of commit 0bd013d with merge base 01d3a2d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@xiaobochen-amd can you share some perf numbers/benchmarks comparing to bf16 baseline? Microbenchmarks or e2e training in torchtitan?
Hello, here is some benchmarking data from torchtitan using @xiaobochen-amd's PR. Testing e2e performance, FP8 grouped GEMM is about ~10% behind on TPS compared to FP16, but that is something we are looking into.
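For context on what's being benchmarked: a scaled grouped GEMM multiplies per-group row slices of a stacked activation matrix against a 3-D stack of (e.g. per-expert) weights, applying FP8 dequantization scales to each group's output. Below is a minimal CPU reference of those semantics in NumPy; the function name, per-tensor scale granularity, and offsets layout are illustrative assumptions, not the actual kernel API:

```python
import numpy as np

def grouped_mm_scaled_ref(a, b, scale_a, scale_b, offs):
    """Reference semantics: out[start:end] = (a[start:end] @ b[g].T) * scale_a * scale_b[g].

    a:       (total_tokens, K)  low-precision activations (fp8 simulated as fp32 here)
    b:       (G, N, K)          low-precision weights, one matrix per group/expert
    scale_a: scalar             dequant scale for `a` (assumed per-tensor)
    scale_b: (G,)               dequant scale per group (assumed per-tensor)
    offs:    (G,)               exclusive end offset of each group's rows in `a`
    """
    outs, start = [], 0
    for g, end in enumerate(offs):
        outs.append((a[start:end] @ b[g].T) * scale_a * scale_b[g])
        start = end
    return np.concatenate(outs, axis=0)

# Tiny demo: 2 groups owning 3 and 2 tokens, K=4, N=5
rng = np.random.default_rng(0)
a = rng.standard_normal((5, 4)).astype(np.float32)
b = rng.standard_normal((2, 5, 4)).astype(np.float32)
out = grouped_mm_scaled_ref(a, b, 0.5, np.array([2.0, 4.0]), offs=[3, 5])
print(out.shape)  # (5, 5)
```

The fused kernel does the per-group matmuls and dequantization in one launch on fp8 inputs, which is where the perf comparison against a bf16/fp16 grouped GEMM comes from.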
@alex-minooka @xiaobochen-amd FYI this part of the codebase is going through a substantial refactor, so we need to pause landing this until that is complete (in 1-2 days or so). Sorry for the inconvenience. #3862
hey @xiaobochen-amd go ahead and rebase and land if you want, that refactor PR is landed now |
@danielvegamyhre Ok, I'll handle this PR soon.

Old PR: #3540