Build mxfp4 kernel for sm120a #2285

Merged · 12 commits merged into pytorch:main on Jun 21, 2025

Conversation

@gau-nernst (Collaborator) commented on May 31, 2025

Update (2025/06/14)

Added a benchmark script, courtesy of @drisspg. Perhaps you can run it for B200 as well.

At 575W (default):

[benchmark plot: 575w_mx_8192_8192_20250614_110313]

At 400W - this is just for me since I'm running this card at 400W:

[benchmark plot: 400w_mx_8192_8192_20250614_110032]

For reference (5090 is last column):

[image]

At the stock power limit, the card can comfortably reach speed of light. The interesting bit is MX-FP8: it looks like the kernel used by torch._scaled_mm() (I guess it's from cuDNN/cuBLAS?) is using FP16 accumulation 🤔

Update (2025/06/11)

I narrowed the issue down to templates: if the kernel is inside a templated function, I get the runtime error below (cudaFuncSetAttribute() returned error: invalid resource handle), even if I don't use any template arguments. It might be an issue with CUTLASS or my environment (nvcc version, compiler, ...).

Hence, the solution is to create a separate source file for sm120a, without any templated functions. When we support NVFP4 in the future, we can either manually duplicate the code again, use a macro, or have a Python script to codegen the CUTLASS kernel instantiation.

Other details of this PR:

  • Add sm120a extension
  • Modify the torch library loading logic: for CUTLASS kernels with architecture-specific targets (smXYa), only the libraries matching the current GPU's compute capability are loaded. The limitation of this code is that it won't work correctly for a multi-GPU setup with different compute capabilities.

Other alternatives that I have considered for the torch library loading logic:

  • Basically we need a runtime check to select the sm100a or sm120a kernel.
  • Due to setuptools.Extension's limitation, the sm100a and sm120a kernels must stay in separate shared library files. This rules out doing the runtime check in C++.
  • Hence, the runtime check must be in Python. I can think of 2 options:
      1. Name 2 different ops (e.g. mx_fp4_bf16_sm100a and mx_fp4_bf16_sm120a), and dispatch to the correct op in Python
      2. Use the same op name, but only load the torch library corresponding to the current GPU at startup -> I went with this approach (a rough sketch follows below)
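
For illustration, here is a minimal sketch of what this startup-time selection could look like. The function name, file layout, and library naming are my own assumptions for the sketch, not the actual torchao loading code:

import glob
import re

import torch


def load_cutlass_libraries(lib_dir: str) -> None:
    # Load only the shared libraries whose smXYa tag matches the current GPU;
    # libraries without an smXYa tag in their name are always loaded.
    major, minor = torch.cuda.get_device_capability()
    current_arch = f"sm{major}{minor}a"  # e.g. "sm120a" on a 5090, "sm100a" on a B200
    for path in sorted(glob.glob(f"{lib_dir}/*.so")):
        match = re.search(r"sm\d+a", path)
        if match is not None and match.group(0) != current_arch:
            continue  # arch-specific library built for a different GPU
        torch.ops.load_library(path)

With this, an sm100a and an sm120a library can both register the same op name, and only the one matching the current GPU is ever loaded in a given process; as noted above, this breaks down on a multi-GPU host with mixed compute capabilities.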

Original (2025/05/31)

Just making some quick changes here to see if I can build mxfp4 kernel on 5090 (sm120). Eventually this will be put under torchao._C_cutlass_120a?

Setting -DCUTLASS_DEBUG_TRACE_LEVEL=1 so I can see debug trace.

To build (using torch==2.8.0.dev20250530+cu128)

TORCH_CUDA_ARCH_LIST=12.0a uv pip install -e . -v --no-build-isolation

Running pytest test/prototype/mx_formats/test_mx_mm.py -v gives:

/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:244    workspace_bytes: 0
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:312  GemmUniversal::initialize() - workspace 0, stream: null
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:201  to_underlying_arguments():
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:214    WARNING: Arguments do not include a valid SM count.
  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:218  to_underlying_arguments(): Setting persistent grid SM count to 170
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:224    WARNING: Arguments do not include a valid max cluster count.
  For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters.
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:336    Setting smem size to 101376
/home/thien/code/ao/third_party/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:343    cudaFuncSetAttribute() returned error: invalid resource handle

cudaFuncSetAttribute() returned error: invalid resource handle means that the function is invalid? https://github.com/NVIDIA/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/device/gemm_universal_adapter.h#L338, which is quite strange...

For reference, I can build and run the example from CUTLASS here: https://github.com/NVIDIA/cutlass/blob/v3.9.2/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu. The changes in this PR have been taken from this example. When building with CUTLASS_DEBUG_TRACE_LEVEL=1, there are also warnings in sm90_gemm_tma_warpspecialized_cooperative.hpp, so that is probably not the issue.

@drisspg

cc @alexsamardzic in case you have faced this error with CUTLASS before

pytorch-bot commented on May 31, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2285

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit ae5f55a with merge base eb86177:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on May 31, 2025
@drisspg (Contributor) commented on May 31, 2025

The first thing that comes to mind is that the example is doing NVFP4 whereas all our recipes are doing MXFP4, e.g. https://github.com/pytorch/ao/pull/2285/files#diff-e155558499c3b1fbab1b5d3b60f032bf1e636908a8ef50a1de33bff518107019R240-R241 needs to change as well. For inference we have MXFP8 and MXFP4 support; I am planning to add an NVFP4 scaling recipe next. That being said, I would imagine that MXFP4 is supported on the 5090.

cc @syed-ahmed

@gau-nernst (Collaborator, Author) commented

I noticed that as well

  • Changing the torchao kernel to nvfp4 results in the same error
  • Changing the cutlass example to mxfp4 still works

😭

@syed-ahmed (Contributor) commented

Per the CUTLASS docs, I believe MXFP4 is supported on the 5090: https://github.com/NVIDIA/cutlass/blob/9d165a3b8ef446a7ff3db198413f82bcb83f46fe/media/docs/cpp/blackwell_functionality.md#blackwell-sm120-gemms

However, note the section that talks about the differences from sm100, so it's possible we need more changes to the kernel in torchao. Also, what CUDA version are you using? I'd assume you'd need a fairly recent one. I'll try to guide more next week.

@gau-nernst (Collaborator, Author) commented

@syed-ahmed I'm using CUDA 12.9

The strange thing is that the cutlass example works, but the one in torchao doesn't. I carefully compared the two, and I don't spot any difference in the template arguments.

@syed-ahmed (Contributor) commented

How about the test? Are the inputs similar to the cutlass example?

@gau-nernst force-pushed the sm120 branch 2 times, most recently from 563fc7c to 0f2f3af on June 11, 2025 15:22
@gau-nernst added the topic: improvement label on Jun 11, 2025
@gau-nernst marked this pull request as ready for review on June 11, 2025 15:43
@gau-nernst requested a review from drisspg on June 11, 2025 22:33
@drisspg (Contributor) left a review

Overall looks really good. Could you add a test to test/prototype/mx_formats/test_mx_mm.py even if it wouldn't be exercised in CI? Also, if you have any perf numbers, that would be great.

@gau-nernst force-pushed the sm120 branch 2 times, most recently from 2da21d1 to e8738bc on June 14, 2025 03:01
@gau-nernst requested a review from drisspg on June 14, 2025 03:36
Inline review comment on the new benchmark script, around this context:

plot_tflops_comparison(df, save_path)


if __name__ == "__main__":
A contributor commented:

can we unify this code with https://github.com/pytorch/ao/blob/main/benchmarks/float8/bench_matmul.py instead? I know the path says float8 but it would be good to have it all in one place.

@gau-nernst (Collaborator, Author) replied:

I have added mxfp4 to that script. You can review it. Thank you

The contributor replied:

thank you!

@gau-nernst (Collaborator, Author) replied:

Any changes you want to make? If not, I will merge 🙏

Another inline review comment, on the benchmark input setup:

A = torch.zeros(M, K, device=device, dtype=d1)
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
if use_fp4:
    A = torch.zeros(M, K // 2, device=device, dtype=torch.int8).view(
A contributor commented:

Just now looking at this, but the performance on NVIDIA with zeros vs non-zero-filled data is very, very different.

I think in a follow-up PR we should make zero-filled an option vs randn-distributed data.

Another contributor replied:

It can even be in this PR; I'd vote for randn as the default.

@gau-nernst (Collaborator, Author) replied:

  1. Should there be an option to choose between randn and zeros? Or should we just use randn for everything?
  2. For FP4, should I do randn in FP32/BF16 and then convert to FP4, or just do randint(0, 255, dtype=uint8).view(float4_e2m1fn_x2)?

A contributor replied:

Re (1): I think the option for zeros can come later.

Re (2): I would do the initial randn in higher precision, then cast (see the sketch below).
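
For illustration, a minimal sketch of that randn-then-cast approach; the quantization helper and the nibble-packing convention are my own assumptions for the sketch, not the code that ended up in bench_matmul.py:

import torch

# Representable e2m1 magnitudes; the 3-bit code is the index into this grid,
# and the sign is assumed to occupy bit 3 of each 4-bit value.
_E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def bf16_to_packed_fp4(x: torch.Tensor) -> torch.Tensor:
    # Quantize to the nearest FP4 (e2m1) grid point and pack two values per byte.
    assert x.shape[-1] % 2 == 0, "last dim must be even to pack two FP4 values per byte"
    grid = _E2M1_GRID.to(x.device)
    mag = x.abs().float().clamp(max=6.0)
    code = (mag.unsqueeze(-1) - grid).abs().argmin(dim=-1).to(torch.uint8)
    sign = (x < 0).to(torch.uint8) << 3
    nibbles = sign | code
    packed = nibbles[..., 0::2] | (nibbles[..., 1::2] << 4)  # even index -> low nibble (assumed)
    return packed.view(torch.float4_e2m1fn_x2)


# Benchmark-style usage: draw randn in BF16, then cast.
A_hp = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
A = bf16_to_packed_fp4(A_hp)

Unlike filling raw bytes with randint, this keeps the FP4 values distributed roughly like randn data, which matters given the zeros vs non-zero performance gap discussed above.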

@gau-nernst (Collaborator, Author) replied:

Do you mind taking a final look and then merging? The failing tests seem unrelated.

@drisspg merged commit e73a142 into pytorch:main on Jun 21, 2025
33 of 35 checks passed
@gau-nernst deleted the sm120 branch on June 22, 2025 01:33
@jerryzh168 (Contributor) commented

Sorry, this breaks internal CI; we'll revert first to unblock the diff train.

clang++: warning: argument unused during compilation: '-pie' [-Wunused-command-line-argument]
ld.lld: error: duplicate symbol: torchao::mx_fp4_bf16(at::Tensor, at::Tensor, at::Tensor, at::Tensor)
>>> defined at __stripped__/torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels_sm100a.cu.pic.stripped.o:(torchao::mx_fp4_bf16(at::Tensor, at::Tensor, at::Tensor, at::Tensor)) in archive buck-out/v2/gen/fbcode/38bba46f0633e4ed/pytorch/ao/___C__/lib_C.stripped.pic.a
>>> defined at __stripped__/torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels_sm120a.cu.pic.stripped.o:(.text.unlikely._ZN7torchao11mx_fp4_bf16EN2at6TensorES1_S1_S1_+0x0) in archive buck-out/v2/gen/fbcode/38bba46f0633e4ed/pytorch/ao/___C__/lib_C.stripped.pic.a
clang++: error: linker command failed with exit code 1 (use -v to see invocation)

jerryzh168 added a commit that referenced this pull request Jun 24, 2025
Gasoonjia pushed a commit that referenced this pull request Jun 24, 2025
Revert "Build mxfp4 kernel for sm120a (#2285)"

This reverts commit e73a142.
@gau-nernst (Collaborator, Author) commented

Yeah, I guess we need to do something like this if we don't use CMake to build each source file with different flags:

https://github.com/pytorch/pytorch/blob/40a785103cf94a1dbc3e0e43d1ed6c41fb60bedb/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L4

xiaowangintel pushed a commit to xiaowangintel/ao that referenced this pull request Jun 24, 2025
xiaowangintel pushed a commit to xiaowangintel/ao that referenced this pull request Jun 24, 2025
Revert "Build mxfp4 kernel for sm120a (pytorch#2285)"

This reverts commit e73a142.
Labels: CLA Signed · topic: improvement

6 participants