[mxfp8 training] add cutedsl kernel for mxfp8 quantization along dim0 #4156
danielvegamyhre merged 8 commits into main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4156
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 6464bbd with merge base 1b2682d. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
force-pushed from f0e59ff to 396af61
```python
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
```

```python
if mode == "dim0":
if mode == "memcpy":
```
note: added this so we can get a sense of the best achievable bandwidth utilization for small shapes like (4096, 7168), where memcpy gets 4.5 TB/s. This is useful for cases like seq_len=4096, microbatch_size=1.
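For context on where the 4.5 TB/s figure comes from, here is a hedged, stdlib-only sketch of the arithmetic behind a memcpy bandwidth baseline: a pure copy reads and writes every byte exactly once, so achieved bandwidth is twice the tensor size divided by elapsed time. The elapsed time used below is an assumed figure, chosen only to illustrate how ~4.5 TB/s falls out for this shape.

```python
# Illustrative sketch (not code from the PR): bandwidth accounting for a
# memcpy-mode baseline. A copy moves 2 * tensor_bytes (one global read
# plus one global write).
def memcpy_bandwidth_tbps(m, k, dtype_bytes, elapsed_s):
    tensor_bytes = m * k * dtype_bytes
    bytes_moved = 2 * tensor_bytes  # one read + one write of every byte
    return bytes_moved / elapsed_s / 1e12

# (4096, 7168) bf16 tensor (2 bytes/elem); ~26 us is an assumed timing
bw = memcpy_bandwidth_tbps(4096, 7168, 2, 26e-6)  # ~4.5 TB/s
```

This is the same accounting a benchmark harness would apply to a measured kernel time to report bandwidth utilization.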
Can we reuse some of the code between the two?

Yeah, will do. Will ping when it's done.
```python
# Config format:
# (compute_warps, tile_m, tile_k, k_tiles_per_cta)
_CUTEDSL_CONFIGS = {
    "bf16_default": (4, 128, 32, 4),
```
note: got better perf with slightly different compute_warps and k_tiles_per_cta for 2D vs 3D.
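To make the config tuple format concrete, here is a small sketch of how such a table might be consumed. The key names and the hypothetical "bf16_3d" entry are illustrative assumptions, not the PR's actual table; only the "bf16_default" tuple and the field order come from the snippet above.

```python
# Sketch: unpacking a (compute_warps, tile_m, tile_k, k_tiles_per_cta)
# config tuple. "bf16_3d" is a hypothetical entry for illustration.
_CUTEDSL_CONFIGS = {
    "bf16_default": (4, 128, 32, 4),
    "bf16_3d": (8, 128, 32, 2),  # assumed values, not from the PR
}

def get_config(key):
    compute_warps, tile_m, tile_k, k_tiles_per_cta = _CUTEDSL_CONFIGS[key]
    # each CTA covers a (tile_m, tile_k * k_tiles_per_cta) slab of the input
    return {
        "compute_warps": compute_warps,
        "tile_m": tile_m,
        "tile_k": tile_k,
        "cta_k": tile_k * k_tiles_per_cta,
    }
```

With the default config, each CTA would cover a 128 x 128 slab (tile_k=32 times k_tiles_per_cta=4), which conveniently spans four 32-element mxfp8 scaling blocks along K.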
force-pushed from dcc909c to 1cfb7b1
@drisspg ok, I extracted various shared helpers into cute_utils.py and import them in both the 2D/3D quant kernels.
```python
# SM >= 100 (Blackwell and beyond, including consumer SM12x and
# SM13x): use tcgen05.CtaGroup.ONE for the optimised single-CTA
# Blackwell TMA load path.
if cutlass.const_expr(IS_BLACKWELL_VALUE):
    g2s_op = cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)
else:
    g2s_op = cpasync.CopyBulkTensorTileG2SOp()
```
At the end, I enabled the 3D kernel for Blackwell only (for easier maintenance, and because the GEMM kernels that these kernels prepare data for are Blackwell-only), so we can remove IS_BLACKWELL_VALUE and keep only the branches where this value is true, for both kernels.

I'm wondering, is it possible to extract more? For example:

```python
if cutlass.const_expr(IS_INPUT_2D):
    sOUT_for_tma_partition = cute.group_modes(sOUT_tile, 0, 1)
    gOUT_for_tma_partition = cute.group_modes(gOUT_tile, 0, 1)
else:
    sOUT_for_tma_partition = cute.group_modes(sOUT_tile, 0, 2)
    gOUT_for_tma_partition = cute.group_modes(gOUT_tile, 0, 2)
```

and the like. Maybe the kernel itself should be split up more.
Yeah, I considered extracting the TMA load/store, but as you noted there are some 2D/3D-specific differences, and I didn't want to over-generify / over-abstract the code... Either way could be reasonable, I guess.

Can you explain the differences between the 2D and 3D impls, such that we can't have these be the same kernel with different layouts? Genuine question: is one for the packed inputs and the other for the 3D weights? Does an outer batch dim + a different layout not handle this?
The main reason is that the 2D case has to handle (1) dynamic group sizes and (2) per-group padding. This isn't done in the kernel yet; I am breaking it into the logical steps of "get correct numerics and good perf for the simple 2D case," then, once we have that solid starting point, adding this, rather than trying to tackle it all at once.
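To make "dynamic group sizes and per-group padding" concrete, here is a plain-Python sketch of the bookkeeping involved: each group's row count is rounded up to an alignment boundary before computing where the next group starts. The function name and the 128-row alignment are illustrative assumptions, not the PR's actual values.

```python
# Sketch (assumed helper, not from the PR) of per-group padding along dim0:
# groups of dynamic size are each rounded up to a multiple of `align` so
# every group starts at an aligned offset in the padded output.
def padded_group_offsets(group_sizes, align=128):
    offsets = [0]
    for size in group_sizes:
        padded = -(-size // align) * align  # ceil-divide, then re-multiply
        offsets.append(offsets[-1] + padded)
    return offsets
```

A 3D kernel never needs this: its groups all have the same (static) size, so offsets are a simple stride, which is one reason the two implementations diverge.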
@claude review this pr please
Claude finished @drisspg's task in 3m 14s.

Review Summary: Overall this is a solid contribution that adds a 2D quantization kernel with good benchmark results and test coverage. The shared utility extraction into cute_utils.py is a welcome cleanup.

Potential Bug: Missing
force-pushed from 1cfb7b1 to 2de6463
force-pushed from 2de6463 to b1f7445
force-pushed from 98bf017 to 203a0e7
ok @drisspg @alexsamardzic latest changes:
force-pushed from 18cb148 to f6887bf
force-pushed from f6887bf to 416fa3e
@drisspg @alexsamardzic I also confirmed the auto-vectorization is working: the code as written has loops of 4-byte stores etc., but the SASS shows 16-byte shared loads/stores. We do have bank conflicts on both smem reads and writes that could potentially be addressed in a follow-up (e.g., for stores, write to smem with a manual swizzle and do the TMA store with this swizzle set). Not sure if it's worth the lift, though; sometimes it ends up being as fast or faster to just let the serialization into sequential accesses happen.
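For readers unfamiliar with the bank-conflict mechanics, here is a stdlib-only sketch of the address arithmetic: shared memory is striped across 32 four-byte banks, so two threads conflict when their addresses map to the same bank in the same transaction. The (128, 32) bf16 row-major tile below is an assumed shape for illustration, not necessarily the PR's exact layout.

```python
# Illustration of smem bank mapping: bank = (byte_addr // 4) % 32.
BANKS, BANK_BYTES = 32, 4

def banks_hit(addrs):
    """Set of distinct banks touched by a list of byte addresses."""
    return {(a // BANK_BYTES) % BANKS for a in addrs}

# 32 threads of a warp each read element (t, 0) of a row-major tile whose
# rows are 32 bf16 values (64 bytes) wide: addresses stride by 64 bytes.
row_bytes = 32 * 2
column_read = banks_hit(t * row_bytes for t in range(32))
# only 2 distinct banks are hit by 32 threads -> a 16-way conflict
```

A swizzle (XOR-permuting the column index by bits of the row index) spreads those accesses across all 32 banks, which is the fix hinted at above; whether the extra complexity beats letting the hardware serialize is exactly the open question.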
Summary

Tests

```
pytest test/prototype/moe_training/test_kernels.py -s -k 2d_numerics
pytest test/prototype/moe_training/test_mxfp8_grouped_mm.py -k dq_fwd_bwd
pytest test/prototype/moe_training/test_training.py -s
```

Benchmarks (writing scales directly to blocked layout for tcgen05.mma)

Next steps
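For readers new to mxfp8, a simplified pure-Python sketch of the per-block numerics the kernel implements on-GPU: each run of 32 elements along the quantization dim gets one power-of-two scale chosen so the block max fits in float8 e4m3 range. This is a reference illustration only; the kernel follows the MX spec's exact scale computation, which can differ in rounding details from the floor-of-log2 formulation used here.

```python
import math

F8E4M3_MAX = 448.0  # max normal value representable in float8 e4m3
BLOCK_SIZE = 32     # mxfp8: one shared scale per 32 contiguous elements

def block_scale(block):
    """Power-of-two scale so block max maps near e4m3's max value."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    return 2.0 ** math.floor(math.log2(amax / F8E4M3_MAX))

def quantize_block(block):
    """Return (scale, scaled-and-clamped values) for one 32-elem block."""
    s = block_scale(block)
    return s, [max(-F8E4M3_MAX, min(F8E4M3_MAX, v / s)) for v in block]
```

Quantizing "along dim0" means these 32-element blocks run down the columns of the 2D input, which is why the kernel's memory-access pattern (and the blocked scale layout expected by tcgen05.mma) is the interesting part.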