
[mxfp8 training] add cutedsl kernel for mxfp8 quantization along dim0 #4156

Merged
danielvegamyhre merged 8 commits into main from cute-and-refactor on Mar 25, 2026
Conversation

@danielvegamyhre (Contributor) commented Mar 24, 2026

Summary

  • Created a 2d version of the 3d weight quantization kernel from @alexsamardzic. It accepts a 2d tensor in row-major layout of shape (M, K), and each CTA iterates along K
  • Added tests and benchmark scripts comparing it to the current 2 stage approach (quant kernel -> blocked layout kernel)
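For intuition about what the kernel computes, here is a minimal pure-Python sketch of mxfp8-style block scaling along K (block size 32, power-of-two e8m0 scales, floor vs. rceil rounding as in the benchmark table below). The function names and exact rounding details are illustrative assumptions, not code from this PR:

```python
import math

FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8 e4m3

def e8m0_block_scale(block, mode="floor"):
    # Power-of-two scale for one block. "floor" rounds the exponent down
    # (scaled values may clip at the fp8 max); "rceil" rounds it up
    # (scaled values are guaranteed to stay in range).
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    exp = math.log2(amax / FP8_E4M3_MAX)
    exp = math.floor(exp) if mode == "floor" else math.ceil(exp)
    return 2.0 ** exp

def quantize_dim0(row, block_size=32, mode="floor"):
    # Scale each block_size chunk along K by its own e8m0 scale.
    # The actual cast to fp8 is omitted; only the scaling math is shown.
    out, scales = [], []
    for i in range(0, len(row), block_size):
        block = row[i:i + block_size]
        s = e8m0_block_scale(block, mode)
        scales.append(s)
        out.extend(v / s for v in block)
    return out, scales
```

The rceil mode trades a little precision for a hard guarantee that no scaled value exceeds the fp8 range, which is why both modes appear in the benchmarks.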

Tests

  • pytest test/prototype/moe_training/test_kernels.py -s -k 2d_numerics
  • pytest test/prototype/moe_training/test_mxfp8_grouped_mm.py -k dq_fwd_bwd
  • pytest test/prototype/moe_training/test_training.py -s

Benchmarks (writing scales directly to blocked layout for tcgen05.mma)

  • 1.29x to 2.88x faster than the existing approach for local batch size = 1 (2.8 to 4.2 TB/s)
  • 1.08x to 1.18x faster than the existing approach for local batch size = 4 (4.4 to 5.8 TB/s)
  • 1.10x to 1.24x faster than the existing approach for local batch size = 16 (5.8 to 6.4 TB/s)
input_shape     scaling_mode      num_groups    cutedsl_blocked_us    triton+rearrange_us  speedup      cutedsl_gbps    triton+rearrange_gbps
--------------  --------------  ------------  --------------------  ---------------------  ---------  --------------  -----------------------
(8192, 2048)    floor                      8                 18.03                  51.14  2.84x              2820.3                    994.5
(8192, 2048)    rceil                      8                 17.47                  50.37  2.88x              2910.7                   1009.7
(8192, 7168)    floor                      8                 42.02                  55.3   1.32x              4236.4                   3219
(8192, 7168)    rceil                      8                 42.02                  54.3   1.29x              4236.4                   3277.8
(32768, 2048)   floor                      8                 46.14                  54.27  1.18x              4408.5                   3748.2
(32768, 2048)   rceil                      8                 46.08                  54.27  1.18x              4414.6                   3748.2
(32768, 7168)   floor                      8                125.95                 136.38  1.08x              5652.8                   5220.4
(32768, 7168)   rceil                      8                121.89                 136.19  1.12x              5841.3                   5227.8
(131072, 2048)  floor                      8                140.32                 154.69  1.10x              5798.9                   5260.2
(131072, 2048)  rceil                      8                136.22                 154.69  1.14x              5973.2                   5260.2
(131072, 7168)  floor                      8                457.76                 556.1   1.21x              6221.5                   5121.3
(131072, 7168)  rceil                      8                443.46                 547.84  1.24x              6422.1                   5198.5

Next steps

  • Fuse per-group padding into the quantization kernel for the "non-HybridEP" and "non-EP" cases. I have been working on this, but it is difficult to get good performance so far.

@pytorch-bot bot commented Mar 24, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4156

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6464bbd with merge base 1b2682d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the "CLA Signed" label Mar 24, 2026
@danielvegamyhre added the "module: training" label Mar 24, 2026
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

if mode == "dim0":
if mode == "memcpy":
Contributor Author

note: added this so we can get a sense of the best achievable bandwidth utilization for small shapes like (4096, 7168), for which memcpy gets 4.5 TB/s. This is useful for cases like seq_len=4096, microbatch_size=1.
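To make the memcpy comparison concrete, a small helper (hypothetical, not part of the benchmark script) for converting bytes moved and elapsed time into effective bandwidth:

```python
def effective_bandwidth_gbps(num_bytes_moved, elapsed_us):
    # GB/s given total bytes moved (reads + writes) and elapsed microseconds.
    # Timing a plain memcpy of the same tensor with this gives an upper bound
    # to compare a quantization kernel's bandwidth against.
    return num_bytes_moved / (elapsed_us * 1e-6) / 1e9
```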

@drisspg (Contributor) commented Mar 24, 2026

Can we reuse some of the code between the two?

@danielvegamyhre (Contributor, Author)

Can we reuse some of the code between the two?

Yeah, will do. Will ping when it's done

# Config format:
# (compute_warps, tile_m, tile_k, k_tiles_per_cta)
_CUTEDSL_CONFIGS = {
"bf16_default": (4, 128, 32, 4),
Contributor Author

note: got better perf with slightly different compute_warps and tiles_per_cta for 2d vs 3d

@danielvegamyhre danielvegamyhre force-pushed the cute-and-refactor branch 2 times, most recently from dcc909c to 1cfb7b1 Compare March 24, 2026 16:28
@danielvegamyhre
Copy link
Copy Markdown
Contributor Author

@drisspg ok i extracted out various shared helpers into cute_utils.py and import them in both 2d/3d quant kernels

Comment on lines +652 to +658
# SM >= 100 (Blackwell and beyond, including consumer SM12x and
# SM13x): use tcgen05.CtaGroup.ONE for the optimised single-CTA
# Blackwell TMA load path.
if cutlass.const_expr(IS_BLACKWELL_VALUE):
g2s_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
else:
g2s_op = cpasync.CopyBulkTensorTileG2SOp()
Collaborator

In the end, I enabled the 3D kernel for Blackwell only (for easier maintenance, and because the GEMM kernels that these kernels prepare data for are Blackwell-only), so we can remove IS_BLACKWELL_VALUE and keep only the branches where this value is true, for both kernels.

@alexsamardzic (Collaborator)

i extracted out various shared helpers into cute_utils.py and import them in both 2d/3d quant kernels

I'm wondering, is it possible to extract more? For example, _issue_tma_store could maybe take a constexpr arg IS_INPUT_2D, with the part of the body that differs between the two coded like:

if cutlass.const_expr(IS_INPUT_2D):
    sOUT_for_tma_partition = cute.group_modes(sOUT_tile, 0, 1)
    gOUT_for_tma_partition = cute.group_modes(gOUT_tile, 0, 1)
else:
    sOUT_for_tma_partition = cute.group_modes(sOUT_tile, 0, 2)
    gOUT_for_tma_partition = cute.group_modes(gOUT_tile, 0, 2)

and so on. Maybe the kernel itself should be split into more @cute.jit functions, so that there is less duplication between the two. (I know it's boring to do this, but it may make implementing FP4 kernels simpler too.)

@danielvegamyhre (Contributor, Author) commented Mar 24, 2026

i extracted out various shared helpers into cute_utils.py and import them in both 2d/3d quant kernels

I'm wondering, is it possible to extract more? For example, _issue_tma_store could maybe take a constexpr arg IS_INPUT_2D, with the part of the body that differs between the two coded like:

if cutlass.const_expr(IS_INPUT_2D):
    sOUT_for_tma_partition = cute.group_modes(sOUT_tile, 0, 1)
    gOUT_for_tma_partition = cute.group_modes(gOUT_tile, 0, 1)
else:
    sOUT_for_tma_partition = cute.group_modes(sOUT_tile, 0, 2)
    gOUT_for_tma_partition = cute.group_modes(gOUT_tile, 0, 2)

and so on. Maybe the kernel itself should be split into more @cute.jit functions, so that there is less duplication between the two. (I know it's boring to do this, but it may make implementing FP4 kernels simpler too.)

Yeah I considered extracting the tma load/store but as you noted there are some 2d/3d specific differences and I didn't want to "over-generify / over-abstract" the code... Either way could be reasonable I guess

@drisspg (Contributor) commented Mar 24, 2026

can you explain the differences between 2d and 3d impls such that we can't have these be the same kernel but with different layouts?

genuine Q;

like is one for the packaged inputs the other for the 3d weights, does outer batch dim + different layout not handle this

@danielvegamyhre (Contributor, Author) commented Mar 24, 2026

can you explain the differences between 2d and 3d impls such that we can't have these be the same kernel but with different layouts?

genuine Q;

like is one for the packaged inputs the other for the 3d weights, does outer batch dim + different layout not handle this

The main reason is that the 2d case has to (1) handle dynamic group sizes and (2) perform per-group padding. This isn't done in the kernel yet; I am breaking it into logical steps: first get correct numerics and good perf for the simple 2d case, then, once we have that solid starting point, work on adding this, rather than trying to tackle it all at once.
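As a sketch of what per-group padding involves, here is a hypothetical helper that rounds each group size up to an alignment so every group's rows start at an aligned offset (the real kernel would fuse this into the quantization pass; the name and signature are illustrative, not from the PR):

```python
def pad_group_offsets(group_sizes, alignment=128):
    # Cumulative start offsets with each group size rounded up to the
    # alignment. -(-size // alignment) is ceiling division in plain Python.
    offsets = [0]
    for size in group_sizes:
        padded = -(-size // alignment) * alignment
        offsets.append(offsets[-1] + padded)
    return offsets
```

The benchmark sidesteps this by generating group sizes that are already multiples of 128, which is why the comparison against the arbitrary-group Triton path is noted as slightly apples-to-oranges.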

@drisspg (Contributor) commented Mar 24, 2026

@claude review this pr please

@claude bot commented Mar 24, 2026

Claude finished @drisspg's task in 3m 14s


  • Read all changed files and understand the changes
  • Review cute_utils.py (shared utilities)
  • Review cutedsl_quantize_2d.py (new 2D kernel)
  • Review changes to cutedsl_quantize_3d.py (refactored 3D kernel)
  • Review quant.py integration
  • Review tests and benchmarks
  • Post review feedback

Review Summary

Overall this is a solid contribution that adds a 2D quantization kernel with good benchmark results and test coverage. The shared utility extraction into cute_utils.py is clean. Below are issues I found, ordered by severity.


Potential Bug: Missing fence_view_async_shared() in 2D kernel

The 3D kernel (cutedsl_quantize_3d.py:506-512) calls cute.arch.fence_view_async_shared() after mbarrier_wait in both the multi-stage and single-stage paths:

# 3D kernel pattern (lines 506-512):
if warp_idx >= 1 and warp_idx <= compute_warps:
    cute.arch.mbarrier_wait(tma_mbar_ptr, tma_phase)
    cute.arch.fence_view_async_shared()  # <-- present in 3D

The 2D kernel (cutedsl_quantize_2d.py:521-522) omits this fence:

# 2D kernel (lines 521-522):
if warp_idx >= 1 and warp_idx <= compute_warps:
    cute.arch.mbarrier_wait(tma_mbar_ptr, tma_phase)
    # fence_view_async_shared() is MISSING

This fence ensures that shared memory writes from TMA are visible to the consumer warps. Its absence could lead to stale reads from shared memory in certain timing conditions. The fact that tests pass doesn't guarantee this is safe -- the hazard may only manifest under specific timing/load conditions. I'd recommend adding this fence to match the 3D kernel's pattern.

Additionally, the 3D kernel has a structural difference: the mbarrier_wait is placed inside the const_expr(STAGE_COUNT > 1 ...) / else blocks (lines 506-512), while the 2D kernel places it unconditionally after the prefetch block (line 521). This means in the 2D kernel's multi-stage path, the consumer warp does the mbarrier_wait outside the prefetch constexpr block rather than inside it. This may work functionally but diverges from the 3D pattern in a way that could cause issues. Worth aligning for consistency.



Agree with @alexsamardzic: Remove IS_BLACKWELL_VALUE branches

As noted in the review comment, since the 3D kernel was simplified to Blackwell-only, the same should be done for the 2D kernel. Both kernels still carry IS_BLACKWELL_VALUE as a compile-time parameter and branch on it in multiple places (e.g., cutedsl_quantize_2d.py:103,655,703 and cutedsl_quantize_3d.py:103,624,676), but the host-side code at quant.py:805-806 already restricts to torch.cuda.get_device_capability()[0] == 10. Removing the non-Blackwell branches would simplify the code and reduce compilation variants.


Code Duplication Opportunities

The 2D and 3D kernels share a significant amount of nearly identical code that goes beyond what was extracted into cute_utils.py. Specifically:

  1. _quantize_store_chunk -- Both kernels have identical logic: vals_chunk.load() * inv_scale, clamp if not RCEIL, convert to FP8, store to shared memory. The only difference is the shared memory store call signature (_store_q_fp8_chunk takes different args for 2D vs 3D indexing).

  2. _load_vals_chunk_full / _load_vals_chunk_tail -- Identical between 2D and 3D.

  3. _quantize_store_full / _quantize_store_tail -- Same loop structure, just calling the above differently.

  4. _issue_tma_load / _issue_tma_store -- Differ only in group_modes(sOUT_tile, 0, 1) vs group_modes(sOUT_tile, 0, 2), which could be parameterized as @alexsamardzic suggested.

  5. Pipeline infrastructure -- The stage management, mbarrier init, staged layout setup, and tile_step loop structure are nearly line-for-line identical.

I understand the concern about over-abstraction, but the current level of duplication means any bug fix (like the missing fence above) needs to be applied in two places. At minimum, _load_vals_chunk_full, _load_vals_chunk_tail, and _quantize_store_chunk could be moved into cute_utils.py.


Unrelated Default Change in mx_formats/kernels.py

The PR changes the default scaling_mode from "floor" to "rceil" in mxfp8_quantize_cuda() (torchao/prototype/mx_formats/kernels.py:1027,1167). This is a behavior change to an existing API that is unrelated to the 2D kernel addition. If intentional, it should be called out in the PR description. If not, it should be reverted to keep the PR focused.


Minor Issues

  1. _select_cutedsl_config has an unused K parameter (cutedsl_quantize_2d.py:43-46): The parameter is explicitly del-ed. If it's intended for future tuning, a comment would help; otherwise, just remove it.

  2. _store_scales_vectorized fallback with runtime num_scales (cutedsl_quantize_2d.py:198): In the tail path, num_scales is a runtime cutlass.Int32 value passed to range(). This should work in CuTeDSL (compiles to a loop), but it's worth verifying that the compiler handles this efficiently vs. the constexpr-unrolled paths.

  3. Test parametrization is extensive but could miss edge cases: The 2D numerics test (test_kernels.py:440-488) tests M and K values (32, 1536, 5120, 7168, 8192) but doesn't include cases where K is not divisible by tile_k * k_tiles_per_cta (e.g., K=96 with tile_k=32, k_tiles_per_cta=4 would trigger the non-full-tiles path with is_full_k_tiles=False). Consider adding a test case like M=128, K=96 to exercise the tail handling.

  4. Benchmark uses generate_jagged_offs with multiple_of=128 (bench_cutedsl_quantize_2d.py:98): The TODO comment explains this is because per-group padding isn't supported yet, but the benchmark comparison with triton_to_mxfp8_dim0 + mx_block_rearrange_2d_M_groups_cuda is slightly apples-to-oranges since the Triton path handles arbitrary groups while the CuTeDSL path requires 128-aligned groups.

  5. cast_bench.py new modes: The variables are named y_d1, s_d1 in the dim0_mxfp8_cutedsl_2d_* modes (cast_bench.py:479,503) which is misleading since these are dim0 operations. Consider renaming to y_d0, s_d0 for consistency with the mode name.


CI Failures

The 8 CI failures all show RuntimeError: Command docker exec ... failed with exit code 2, which look like infrastructure/Docker issues rather than code problems. These don't appear to be related to the PR changes.


Summary

The kernel implementation and benchmarks look good. The main concern is the missing fence_view_async_shared() which should be addressed before merge. The other items (removing IS_BLACKWELL branches, reducing duplication) are improvements worth making but less urgent.

@danielvegamyhre danielvegamyhre force-pushed the cute-and-refactor branch 2 times, most recently from 98bf017 to 203a0e7 Compare March 24, 2026 22:35
@danielvegamyhre (Contributor, Author) commented Mar 24, 2026

ok @drisspg @alexsamardzic latest changes:

  • claude's first comment/"bug" is wrong in this case, so I ignored it (an explicit fence isn't necessary after the mbarrier completes; it is implicit). I think it was looking at an old version of the 3d kernel that had that pattern and got confused, not sure.
  • removed IS_BLACKWELL branching and only kept Blackwell path, we don't need to support other architectures, we have emulated mode for that
  • unified more helper functions in cute_utils.py and import to use in 2d/3d kernels
  • removed unused K param from _select_cutedsl_config. also updated to use actual torch.dtype instead of a string str(torch.dtype) - was a bit odd
  • renamed functions to make data movement more obvious/clear (register -> smem, reg -> gmem etc)
  • added docstrings to functions to clarify inputs/outputs and functionality
  • fixed var names in cast_bench.py to clarify we quant along dim0

@danielvegamyhre danielvegamyhre force-pushed the cute-and-refactor branch 5 times, most recently from 18cb148 to f6887bf Compare March 25, 2026 00:24
@danielvegamyhre (Contributor, Author)

@drisspg @alexsamardzic I also confirmed the auto-vectorization is working: the code as written has loops of 4-byte stores, but the SASS shows 16-byte shared load/stores (LDS.128/STS.128).

We do have bank conflicts on both smem reads and writes that could potentially be addressed in a follow-up (e.g. for stores, write to smem with a manual swizzle and do the TMA store with that swizzle set). Not sure it's worth the lift though; sometimes it ends up being as fast or faster to just let the accesses serialize into sequential order...
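For reference, the kind of manual swizzle alluded to can be sketched as a simple XOR permutation of the column index (a simplified illustration; real shared-memory swizzles operate on byte addresses in bank-width units, and the kernel's actual layout may differ):

```python
def xor_swizzle(row, col, banks=32):
    # Permute the column index by the row index so that a column of
    # accesses across consecutive rows lands in distinct banks instead
    # of all hitting the same one.
    return col ^ (row % banks)
```

Without the swizzle, 32 threads reading column 0 of 32 consecutive rows all hit bank 0; with it, they spread across all 32 banks.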

@alexsamardzic (Collaborator) left a comment

LGTM 👍


Labels

CLA Signed, module: training, moe, mx
