
Add CuTeDSL kernel for 3D tensor quantization to MXFP8#4090

Merged
danielvegamyhre merged 1 commit into main from asamardzic/cutedsl-quantize-3d
Mar 20, 2026

Conversation

@alexsamardzic
Collaborator

No description provided.

@pytorch-bot

pytorch-bot bot commented Mar 16, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4090

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 920db9c with merge base e40a817:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 16, 2026
@alexsamardzic
Collaborator Author

alexsamardzic commented Mar 16, 2026

The kernel works on everything Hopper and above. It supports both floor and rceil rounding.

To test:

pytest test/prototype/moe_training/test_kernels.py -k cuda_mx_dim1_3d_numerics

To benchmark:

for shape in \
  "8 5120 5120" \
  "8 7168 5120" \
  "8 7168 7168" \
  "8 8192 8192"
do
  read -r E N K <<< "$shape"
  python benchmarks/mx_formats/mxfp8_backend_bench.py \
    --dtype bf16 \
    --scaling-mode floor \
    --backends cuda,cutedsl \
    --check-results \
    --stage-count 2 \
    --E "$E" --N "$N" --K "$K"
done

Benchmarking results:

GPU: NVIDIA B200
CC: (10, 0)
shape=(E,N,K)=(8,5120,5120) dtype=torch.bfloat16 scaling_mode=floor stage_count=2
[cuda      ]    0.131 ms    4.847 TB/s   y_stride=(26214400, 1, 5120) s_shape=(8, 160, 5120)
[cutedsl   ]    0.123 ms    5.165 TB/s   y_stride=(26214400, 1, 5120) s_shape=(8, 160, 5120)
diff(cuda vs cutedsl): y_max_abs=0.0 s_max_abs=0.0
check(cuda vs cutedsl): PASS (atol=0.0)
GPU: NVIDIA B200
CC: (10, 0)
shape=(E,N,K)=(8,7168,5120) dtype=torch.bfloat16 scaling_mode=floor stage_count=2
[cuda      ]    0.180 ms    4.934 TB/s   y_stride=(36700160, 1, 7168) s_shape=(8, 224, 5120)
[cutedsl   ]    0.170 ms    5.232 TB/s   y_stride=(36700160, 1, 7168) s_shape=(8, 224, 5120)
diff(cuda vs cutedsl): y_max_abs=0.0 s_max_abs=0.0
check(cuda vs cutedsl): PASS (atol=0.0)
GPU: NVIDIA B200
CC: (10, 0)
shape=(E,N,K)=(8,7168,7168) dtype=torch.bfloat16 scaling_mode=floor stage_count=2
[cuda      ]    0.250 ms    4.983 TB/s   y_stride=(51380224, 1, 7168) s_shape=(8, 224, 7168)
[cutedsl   ]    0.236 ms    5.283 TB/s   y_stride=(51380224, 1, 7168) s_shape=(8, 224, 7168)
diff(cuda vs cutedsl): y_max_abs=0.0 s_max_abs=0.0
check(cuda vs cutedsl): PASS (atol=0.0)
GPU: NVIDIA B200
CC: (10, 0)
shape=(E,N,K)=(8,8192,8192) dtype=torch.bfloat16 scaling_mode=floor stage_count=2
[cuda      ]    0.326 ms    4.998 TB/s   y_stride=(67108864, 1, 8192) s_shape=(8, 256, 8192)
[cutedsl   ]    0.307 ms    5.304 TB/s   y_stride=(67108864, 1, 8192) s_shape=(8, 256, 8192)
diff(cuda vs cutedsl): y_max_abs=0.0 s_max_abs=0.0
check(cuda vs cutedsl): PASS (atol=0.0)

Note: nvidia-cutlass-dsl and apache-tvm-ffi packages are now dependencies.

@danielvegamyhre danielvegamyhre added this to the MXFP8 Training milestone Mar 16, 2026
@danielvegamyhre danielvegamyhre added module: training quantize_ api training flow mx moe labels Mar 16, 2026
@alexsamardzic alexsamardzic force-pushed the asamardzic/cutedsl-quantize-3d branch 3 times, most recently from c9d0f28 to 8412e87 Compare March 16, 2026 20:48
Contributor

@danielvegamyhre danielvegamyhre left a comment

This is looking great @alexsamardzic! I think the design makes sense, and appreciate your comments on what else you tried that didn't work to help me understand the reasoning behind the design.

I added some suggestions and questions. Also, if you wouldn't mind, please check out this short blog on a CuTeDSL MXFP8 quantizer that they claim sustains 6+ TB/s for their shapes, and see if any ideas from it can be applied here: https://blog.fal.ai/chasing-6-tb-s-an-mxfp8-quantizer-on-blackwell/

q_fp8_vals4.store(q_fp8_vec4)

for j in range(chunk_vec):
    sOUT_tile[0, sout_base + j, k_rel] = q_fp8_vals4[j]
Contributor

(new to cutedsl, apologies if this comment is off base, let me know):

this will do 4 vectorized 4 byte rmem->smem writes, right? from what i can see this would appear to be simple row major layout, is that right?

we need the smem layout to be in ((32,4),4) blocked layout for tcgen05.mma usage, and write directly to that layout in smem before we do the TMA copy smem->gmem. if we store in simple row major format we pay the tax of an extra kernel dispatch doing this layout transformation on the scales, which is what we are currently doing and would like to avoid.

probably need to do scattered / uncoalesced 4 byte vectorized stores to smem for this, just make sure to avoid single byte shared stores / STS.U8

if you aren't familiar with this layout i added some diagrams to our recent MXFP8 training blog that help visualize it: https://pytorch.org/blog/mxfp8-training-for-moes-1-3x-training-speedup-vs-bf16-for-llama4-scout-on-gb200-cluster-using-torchao-and-torchtitan/
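For a concrete picture of the mapping, here is a plain-Python sketch of one 128x4 scale-factor tile, assuming the ((32,4),4) shape with strides ((16,4),1); the strides are an assumption here, worth double-checking against the blog diagrams:

```python
def blocked_scale_offset(r: int, c: int) -> int:
    """Map (row, col) within one 128x4 scale-factor tile to its byte offset,
    assuming the ((32,4),4) layout with strides ((16,4),1)."""
    assert 0 <= r < 128 and 0 <= c < 4
    return (r % 32) * 16 + (r // 32) * 4 + c

# Every (r, c) in the tile hits a unique offset in [0, 512): one 512-byte tile.
offsets = sorted(blocked_scale_offset(r, c) for r in range(128) for c in range(4))
assert offsets == list(range(512))
```

Writing scales through a mapping like this, instead of row major, is what would let the kernel skip the separate layout-transform dispatch.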

Contributor

quack code pointer that may be useful here:

Collaborator Author

Updated to write one uint32 instead of four fp8, another 4% performance improvement here.

For the output layout update, see below.
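The uint32 trick is plain byte packing; a minimal sketch, assuming little-endian byte order (as on NVIDIA GPUs) and hypothetical e4m3 bit patterns:

```python
import struct

def pack4_fp8(b0: int, b1: int, b2: int, b3: int) -> int:
    """Pack four fp8 byte patterns into one little-endian uint32, so a single
    4-byte store replaces four 1-byte stores (avoiding STS.U8)."""
    return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)

vals = [0x3C, 0x40, 0xB8, 0x44]  # hypothetical e4m3 bit patterns
packed = pack4_fp8(*vals)
assert struct.pack("<I", packed) == bytes(vals)  # packed word preserves byte order
```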

@alexsamardzic alexsamardzic force-pushed the asamardzic/cutedsl-quantize-3d branch 8 times, most recently from cedb175 to 6c4d33a Compare March 17, 2026 18:08
@alexsamardzic
Collaborator Author

alexsamardzic commented Mar 17, 2026

@danielvegamyhre: Thanks for all the comments, in particular for pointing to possible performance improvements. I've made updates according to all of your suggestions, and we're now getting pretty close to the 6.4 TB/s - here are the latest rceil benchmarking results on B200 (note that these are for the default output layout, the new variant with blocked output layout is significantly slower):

shape=(E,N,K)=(8,5120,5120) dtype=torch.bfloat16 scaling_mode=rceil stage_count=2
[cuda      ]    0.117 ms    5.431 TB/s   y_stride=(26214400, 1, 5120) s_shape=(8, 160, 5120)
[cutedsl   ]    0.105 ms    6.066 TB/s   y_stride=(26214400, 1, 5120) s_shape=(8, 160, 5120)
diff(cuda vs cutedsl): y_max_abs=0.0 s_max_abs=0.0
check(cuda vs cutedsl): PASS (atol=0.0)

shape=(E,N,K)=(8,7168,5120) dtype=torch.bfloat16 scaling_mode=rceil stage_count=2
[cuda      ]    0.162 ms    5.494 TB/s   y_stride=(36700160, 1, 7168) s_shape=(8, 224, 5120)
[cutedsl   ]    0.144 ms    6.183 TB/s   y_stride=(36700160, 1, 7168) s_shape=(8, 224, 5120)
diff(cuda vs cutedsl): y_max_abs=0.0 s_max_abs=0.0
check(cuda vs cutedsl): PASS (atol=0.0)

shape=(E,N,K)=(8,7168,7168) dtype=torch.bfloat16 scaling_mode=rceil stage_count=2
[cuda      ]    0.224 ms    5.565 TB/s   y_stride=(51380224, 1, 7168) s_shape=(8, 224, 7168)
[cutedsl   ]    0.201 ms    6.197 TB/s   y_stride=(51380224, 1, 7168) s_shape=(8, 224, 7168)
diff(cuda vs cutedsl): y_max_abs=0.0 s_max_abs=0.0
check(cuda vs cutedsl): PASS (atol=0.0)

shape=(E,N,K)=(8,8192,8192) dtype=torch.bfloat16 scaling_mode=rceil stage_count=2
[cuda      ]    0.291 ms    5.592 TB/s   y_stride=(67108864, 1, 8192) s_shape=(8, 256, 8192)
[cutedsl   ]    0.262 ms    6.202 TB/s   y_stride=(67108864, 1, 8192) s_shape=(8, 256, 8192)
diff(cuda vs cutedsl): y_max_abs=0.0 s_max_abs=0.0
check(cuda vs cutedsl): PASS (atol=0.0)

I'll check the blog you pointed to for further improvement ideas.

@alexsamardzic alexsamardzic force-pushed the asamardzic/cutedsl-quantize-3d branch from 6c4d33a to 168c66d Compare March 17, 2026 18:16
@drisspg
Contributor

drisspg commented Mar 17, 2026

@claude also do a PR review pass here

@claude

claude bot commented Mar 17, 2026

Claude finished @drisspg's task in 2m 56s


PR Review: Add CuTeDSL kernel for 3D tensor quantization to MXFP8

  • Read all changed files
  • Review the main kernel implementation (cutedsl_quantize_3d.py)
  • Review the integration layer (quant.py)
  • Review benchmarks and tests
  • Post comprehensive review feedback

Overall

This PR adds a CuTeDSL-based kernel for quantizing 3D tensors to MXFP8 format, targeting Hopper (SM90) and above. The kernel uses warp-specialized TMA with a producer/consumer model and achieves solid throughput numbers (up to ~6.2 TB/s on B200 with rceil). The architecture is well-documented with inline comments explaining design decisions that were tried and rejected.


Issues

1. Bug: Benchmark uses wrong backend (already flagged by @danielvegamyhre)

benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py:137-148 — The warmup call on lines 137-142 runs the function once before timing and discards its results, which should absorb JIT compilation; if it did not, the first timed call would include JIT overhead. As for the backend mix-up @danielvegamyhre flagged (backend="cuda" used where backend="cutedsl" was intended), the code as reviewed passes backend="cutedsl" correctly at line 147, so the observation may apply to an earlier revision; deferring to @danielvegamyhre on that.


2. Unreachable code path in mxfp8_quantize_cuda_3d

torchao/prototype/moe_training/kernels/mxfp8/quant.py:1229-1232 — The final raise NotImplementedError after the if backend == "cutedsl" block is unreachable because the assert on line 1204 already constrains backend to ("cuda", "cutedsl"), and both are handled by if/if blocks above. This dead code has a misleading error message referencing "SM100+" twice (should mention SM90+ for CuTeDSL).

raise NotImplementedError(
    "mxfp8_quantize_cuda_3d is not implemented on this device. "
    "Expected SM100+ for CUDA C++ kernel, or SM100+ with CuTeDSL runtime."
    #                                        ^^^^^ should be SM90+
)


3. _mxfp8_cuda_kernels_available removed from __init__.py re-exports but still used

The test diff removes _mxfp8_cuda_kernels_available from the from torchao.prototype.moe_training.kernels.mxfp8 import block and moves it to import from .quant directly. However, looking at the current __init__.py (line 2), _mxfp8_cuda_kernels_available is still exported there. The test change is fine (importing from the canonical location), but this means there are now two import paths for it. Not a bug, just a potential consistency concern — downstream consumers may use either path.

4. _cutedsl_runtime_available unconditionally imported at module load

torchao/prototype/moe_training/kernels/mxfp8/quant.py:13-16 — The top-level import from torchao.prototype.moe_training.kernels.mxfp8.cutedsl_quantize_3d import (_cutedsl_runtime_available, mxfp8_quantize_cutedsl_3d) is unconditional. This means that even on systems where CuTeDSL is not needed (e.g., SM100+ with the CUDA C++ backend), importing quant.py will import the cutedsl_quantize_3d module. Currently cutedsl_quantize_3d.py itself only imports torch and torchao.utils at the top level (heavy deps like cutlass are deferred), so the overhead is minimal. Just noting this is intentional and works because the heavy imports are properly deferred.

5. blocked_scale_output not gated by SM100+ for Hopper

torchao/prototype/moe_training/kernels/mxfp8/cutedsl_quantize_3d.py:351 — The _store_scale method with BLOCKED_SCALE_OUTPUT computes the ((32,4),4) layout for tcgen05.mma usage. This blocked layout is specifically for Blackwell's tcgen05 MMA instructions. On Hopper (SM90), tcgen05 isn't available. The code currently allows blocked_scale_output=True on SM90 targets, which would produce a layout the hardware can't consume directly. Consider either:

  • Documenting that blocked output is only useful on SM100+
  • Gating blocked_scale_output=True behind is_blackwell

6. functools.cache on kernel compilation could accumulate entries

torchao/prototype/moe_training/kernels/mxfp8/cutedsl_quantize_3d.py:81 — @functools.cache is used on _compile_mxfp8_quantize_3d_cutedsl which takes 10 parameters. In practice the parameter space is small (a few dtype/mode/variant combos), so this is fine. Just noting that unlike @functools.lru_cache(maxsize=N), @functools.cache has unbounded growth. Not a real concern here given the parameter space.


Suggestions (non-blocking)

7. Test only covers ScaleCalculationMode.FLOOR

test/prototype/moe_training/test_kernels.py:381 — The test_cuda_mx_dim1_3d_numerics test only parametrizes scaling_mode with ScaleCalculationMode.FLOOR. Given that the rceil path has meaningfully different code (inline PTX for cvt.rp.satfinite.ue8m0x2.f32 on Blackwell, manual bit manipulation on Hopper), adding ScaleCalculationMode.RCEIL to the test matrix would improve coverage of the rceil scaling logic.
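For intuition on why the rceil path differs, here is an illustrative sketch of floor vs. round-up e8m0 exponent extraction via float bit manipulation; this is not the kernel's actual code, and it ignores subnormals, NaN/Inf, and saturation details:

```python
import struct

def e8m0_floor(x: float) -> int:
    """Biased e8m0 exponent of a positive float, floor rounding:
    truncate the mantissa and keep the IEEE-754 exponent bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 23) & 0xFF

def e8m0_rceil(x: float) -> int:
    """Same, but round toward +inf (rceil): bump the exponent whenever
    x is not an exact power of two."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    exp = (bits >> 23) & 0xFF
    if bits & 0x7FFFFF:  # nonzero mantissa -> not a power of two
        exp += 1
    return min(exp, 254)

assert e8m0_floor(1.5) == 127  # floor(log2 1.5) = 0 -> biased 127
assert e8m0_rceil(1.5) == 128  # ceil(log2 1.5)  = 1 -> biased 128
assert e8m0_rceil(2.0) == 128  # exact power of two: no bump
```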

8. Consider validating N % block_size == 0 in mxfp8_quantize_cuda_3d dispatcher

torchao/prototype/moe_training/kernels/mxfp8/quant.py:1189-1232 — The top-level mxfp8_quantize_cuda_3d dispatcher doesn't validate that N % block_size == 0 before dispatching. The CuTeDSL path validates this at cutedsl_quantize_3d.py:872, but the CUDA C++ path (_mxfp8_quantize_cuda_3d_cpp) doesn't have a visible check. Moving the validation up to the dispatcher would give a clearer error message regardless of backend.
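A sketch of what the suggested dispatcher-level check could look like (function and parameter names here are hypothetical, not the actual torchao API):

```python
def validate_quantize_3d_args(shape, block_size=32, backend="cutedsl"):
    """Hypothetical dispatcher-level validation: fail fast with a clear
    message regardless of which backend will run."""
    if backend not in ("cuda", "cutedsl"):
        raise ValueError(f"unsupported backend={backend!r}")
    if len(shape) != 3:
        raise ValueError(f"expected a 3D tensor, got ndim={len(shape)}")
    E, N, K = shape
    if N % block_size != 0:
        raise ValueError(f"N={N} must be divisible by block_size={block_size}")
    return E, N, K

assert validate_quantize_3d_args((8, 5120, 5120)) == (8, 5120, 5120)
```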

9. bench_quantize_3d.py will fail hard if CuTeDSL is unavailable

benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py:136-148 — Unlike mxfp8_backend_bench.py which has try/except around each backend, bench_quantize_3d.py calls the CuTeDSL backend unconditionally. If run on a system without CuTeDSL packages, the whole benchmark script will crash rather than gracefully skipping. Consider wrapping the CuTeDSL section in a try/except or checking _mxfp8_cutedsl_kernels_available first.

10. The mxfp8_backend_bench.py lambda captures loop variable

benchmarks/mx_formats/mxfp8_backend_bench.py:79 — The lambda fn = lambda: mxfp8_quantize_cuda_3d(...) captures b by reference. Since b is the loop variable, this works correctly here because fn is consumed within the same iteration. But it's a common Python gotcha. Using a default argument (lambda b=b: ...) would make the capture explicit.
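The gotcha in minimal form:

```python
# Late binding: all three closures share the loop variable and see its final value.
fns = [lambda: b for b in range(3)]
assert [f() for f in fns] == [2, 2, 2]

# A default argument captures the value at definition time, once per iteration.
fns = [lambda b=b: b for b in range(3)]
assert [f() for f in fns] == [0, 1, 2]
```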

11. Env-var override for variant selection lacks documentation

torchao/prototype/moe_training/kernels/mxfp8/cutedsl_quantize_3d.py:63 — TORCHAO_MXFP8_CUTEDSL_VARIANT is a useful knob for benchmarking, but there's no mention of it in comments at the top of the file or in the PR description. Adding a brief note in the module docstring would help discoverability.


Code Quality

  • Well-structured warp specialization: The producer (warp 0) / consumer (warps 1..N) design is cleanly separated with clear comments explaining why a split load/store warp design was rejected.
  • Good use of cutlass.const_expr: Compile-time branching is used effectively throughout to eliminate dead code paths for each specialized variant.
  • Multi-stage pipelining: The 2-stage TMA pipeline is correctly implemented with proper barrier management (init, expect_tx, wait, phase tracking).
  • Proper functools.cache usage: JIT compilation results are cached, avoiding redundant compilation for the same parameter combination.

Architecture

The refactoring of quant.py to move lib.define() calls outside the if _mxfp8_cuda_kernels_available guard and to rename the internal CUDA function to _mxfp8_quantize_cuda_3d_cpp while making the public mxfp8_quantize_cuda_3d a dispatcher is a clean design that allows backend selection without breaking existing callers.



@drisspg
Contributor

drisspg commented Mar 18, 2026

Can you add the benchmark numbers to the summary?

@alexsamardzic alexsamardzic force-pushed the asamardzic/cutedsl-quantize-3d branch 3 times, most recently from 078ad43 to 7d20418 Compare March 19, 2026 16:18
@alexsamardzic
Collaborator Author

Current performance results:

| shape | cuda us | plain us | blocked us | cuda TB/s | plain TB/s | blocked TB/s |
|---|---:|---:|---:|---:|---:|---:|
| (8, 5120, 5120) | 117.0 | 105.0 | 107.0 | 5.436 | 6.068 | 5.951 |
| (8, 7168, 5120) | 162.0 | 144.0 | 147.0 | 5.496 | 6.187 | 6.038 |
| (8, 7168, 7168) | 224.0 | 201.0 | 205.0 | 5.573 | 6.197 | 6.077 |
| (8, 8192, 8192) | 291.0 | 262.0 | 267.0 | 5.594 | 6.212 | 6.102 |
Script to generate the results above
import re
import subprocess

SHAPES = [
    (8, 5120, 5120),
    (8, 7168, 5120),
    (8, 7168, 7168),
    (8, 8192, 8192),
]

COMMON_ARGS = [
    "python",
    "benchmarks/mx_formats/mxfp8_backend_bench.py",
    "--dtype",
    "bf16",
    "--scaling-mode",
    "rceil",
    "--backends",
    "cuda,cutedsl",
    "--check-results",
    "--stage-count",
    "2",
]

BLOCKED_ARGS = [
    "python",
    "benchmarks/mx_formats/mxfp8_backend_bench.py",
    "--dtype",
    "bf16",
    "--scaling-mode",
    "rceil",
    "--backends",
    "cuda,cutedsl",
    "--check-results",
    "--stage-count",
    "2",
    "--blocked-scale-output",
    "--blocked-k-blocks-per-tb",
    "1",
    "--blocked-swizzle",
    "5,2,5",
]

LINE_RE = re.compile(r"^\[(cuda|cutedsl)\s+\]\s+([0-9.]+) ms\s+([0-9.]+) TB/s")


def run_bench(args, e, n, k):
    cmd = args + ["--E", str(e), "--N", str(n), "--K", str(k)]
    out = subprocess.check_output(cmd, text=True, cwd="/data/openteams/scratch/ao")
    rows = {}
    for line in out.splitlines():
        m = LINE_RE.match(line.strip())
        if m:
            backend, ms, tbps = m.groups()
            rows[backend] = (float(ms), float(tbps))
    return rows


print("| shape | cuda us | plain us | blocked us | cuda TB/s | plain TB/s | blocked TB/s |")
print("|---|---:|---:|---:|---:|---:|---:|")

for e, n, k in SHAPES:
    plain = run_bench(COMMON_ARGS, e, n, k)
    blocked = run_bench(BLOCKED_ARGS, e, n, k)

    cuda_ms, cuda_tbps = plain["cuda"]
    plain_ms, plain_tbps = plain["cutedsl"]
    blocked_ms, blocked_tbps = blocked["cutedsl"]

    print(
        f"| ({e}, {n}, {k}) | "
        f"{cuda_ms * 1000:.1f} | "
        f"{plain_ms * 1000:.1f} | "
        f"{blocked_ms * 1000:.1f} | "
        f"{cuda_tbps:.3f} | "
        f"{plain_tbps:.3f} | "
        f"{blocked_tbps:.3f} |"
    )

if tidx == 0:
    cpasync.prefetch_descriptor(tma_atom_in)
    cpasync.prefetch_descriptor(tma_atom_out)
    cute.arch.mbarrier_init(tma_mbar_ptr0, 1)
Contributor

could we use a pipeline object instead of raw mbar?

Collaborator Author

Tried that initially, and it was slower. I reconstructed the patch, still about 2% slower.

if cutlass.const_expr(BLOCKED_SCALE_OUTPUT):
    padded_scale_cols = (
        (n_blocks + cutlass.Int64(3)) // cutlass.Int64(4)
    ) * cutlass.Int64(4)
Contributor

I saw your comment. I'm not sure that I totally understand what it means, but basically this is the modulus/floor-division math. This is the exact type of thing that I would expect you to be able to construct a layout for, maybe even on CPU, and have it handle the indexing. I think Daniel linked the one kernel where I set up the output scale layouts, assuming row major construction.

Collaborator Author

Updated. It is a tiny bit faster this way too.

My comment was about the fact that your kernel does the "shuffle" in SMEM. I tried to do that here, but the performance was better with doing it on store. TBH, I'm not yet fully confident with CuTeDSL layouts and it took some experimentation to get this right, but indeed it should be more readable this way too.

@alexsamardzic alexsamardzic force-pushed the asamardzic/cutedsl-quantize-3d branch 2 times, most recently from 9efac51 to 73506e6 Compare March 20, 2026 14:28
Contributor

@drisspg drisspg left a comment

Looks great, awesome job here!

@alexsamardzic
Collaborator Author

alexsamardzic commented Mar 20, 2026

Here are some questions and comments regarding the remaining changes before the PR is eventually merged:

  1. Do we make nvidia-cutlass-dsl and apache-tvm-ffi dependencies (i.e. do we add them to dev-requirements.txt), or is it the user's responsibility to install them if the CuTeDSL functionality is wanted?
  2. Do we want to remove the C++ kernel altogether, or do we want to keep the ability to switch between "cuda"/"cutedsl" backends for now?
  3. Do we want to keep the ability to save to "plain" layout, or do we want to keep only the code saving to "blocked" layout?
  4. The tests need updates: testing the "blocked" layout is needed, and rceil should be added, as at the moment both "cuda" and "cutedsl" kernels fail some tests if rceil is added.
  5. Shall I remove my benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py benchmark from the PR? It was useful for both profiling and validation throughout the development, but benchmarks/mx_formats/mxfp8_backend_bench.py has CuTeDSL benchmarking added too now.

Edit: crossed out completed changes.

@danielvegamyhre
Contributor

danielvegamyhre commented Mar 20, 2026

Here are some questions and comments regarding the remaining changes before the PR is eventually merged:

  1. Do we make nvidia-cutlass-dsl and apache-tvm-ffi dependencies (i.e. do we add them to dev-requirements.txt), or is it the user's responsibility to install them if the CuTeDSL functionality is wanted?

i think runtime dependency is fine but it would be best to include as a dependency we ship. I will check if anyone on the team has objections to including it in our builds. (@drisspg any thoughts?)

  2. Do we want to remove the C++ kernel altogether, or do we want to keep the ability to switch between "cuda"/"cutedsl" backends for now?

We should delete the CUDA C++ kernel once the CuTeDSL one is done.

  3. Do we want to keep the ability to save to "plain" layout, or do we want to keep only the code saving to "blocked" layout?

We only need the blocked layout, we can delete the support for writing "plain" scales (row major)

  4. The tests need updates: testing the "blocked" layout is needed, and rceil should be added, as at the moment both "cuda" and "cutedsl" kernels fail some tests if rceil is added.

This is surprising, can you share a specific test case that fails so i can also take a look? Did you make sure to update the reference to_mx() call to use RCEIL as well?

  5. Shall I remove my benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py benchmark from the PR? It was useful for both profiling and validation throughout the development, but benchmarks/mx_formats/mxfp8_backend_bench.py has CuTeDSL benchmarking added too now.

Yes let's just keep one benchmark script and delete the other. I prefer keeping "benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py" since the structure is consistent with the rest of kernel benchmark scripts.

@alexsamardzic
Collaborator Author

alexsamardzic commented Mar 20, 2026

  4. The tests need updates: testing the "blocked" layout is needed, and rceil should be added, as at the moment both "cuda" and "cutedsl" kernels fail some tests if rceil is added.

This is surprising, can you share a specific test case that fails so i can also take a look? Did you make sure to update the reference to_mx() call to use RCEIL as well?

Indeed, scaling_mode=scaling_mode should be added to to_mx call. I'll do that, and add ScaleCalculationMode.RCEIL to the parametrization.

Edit: Pushed.

@alexsamardzic
Collaborator Author

alexsamardzic commented Mar 20, 2026

An additional question: Is it important to keep support for anything except CC 10.x? At some point, I had the kernel working on anything CC 9.0 and above, but I would need to retest.

Changed to check for 10.x only.

@alexsamardzic alexsamardzic force-pushed the asamardzic/cutedsl-quantize-3d branch 2 times, most recently from 4175803 to 56e65ce Compare March 20, 2026 17:16
k_block_tiles = cute.ceil_div(K, 128)
n_block_tiles = padded_scale_cols // cutlass.Int64(4)
blocked_scale_layout = cute.make_layout(
    ((32, 4, k_block_tiles), (4, n_block_tiles)),
Contributor

curious, why isn't this something like (rest_n, rest_k, 32, 4, 4)? (just for my own understanding, i am newer to cute layouts)

Contributor

@danielvegamyhre danielvegamyhre Mar 20, 2026

is it because we still need two "modes" for row and col dimensions to easily map the 512 byte scale factor tiles to a 2d coordinate space? so we have the (32,4, k_block_tiles) mode to select "which tile along the row dim", and (4, n_block_tiles) as "which tile along column dim" ?

Collaborator Author

@alexsamardzic alexsamardzic Mar 20, 2026

curious, why isn't this something like (rest_n, rest_k, 32, 4, 4)? (just for my own understanding, i am newer to cute layouts)

These two should be equivalent. I guess it boils down to which one seems easier to understand.

Collaborator Author

is it because we still need two "modes" for row and col dimensions to easily map the 512 byte scale factor tiles to a 2d coordinate space? so we have the (32,4, k_block_tiles) mode to select "which tile along the row dim", and (4, n_block_tiles) as "which tile along column dim" ?

Yes, exactly.
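The grouped-vs-flat equivalence can be checked with a tiny offset function; the strides below are assumed for illustration (one 512-byte scale tile), not taken from the kernel:

```python
def offset(coord, strides):
    """Inner product of a (possibly nested) coordinate with matching strides,
    in the spirit of a CuTe layout's coordinate-to-index computation."""
    total = 0
    for c, s in zip(coord, strides):
        total += offset(c, s) if isinstance(c, tuple) else c * s
    return total

# Grouped ((32,4),4) vs flat (32,4,4) modes with the same strides cover the
# same 512 offsets: the grouping only changes how coordinates are addressed.
grouped = {offset(((r, rb), c), ((16, 4), 1))
           for r in range(32) for rb in range(4) for c in range(4)}
flat = {offset((r, rb, c), (16, 4, 1))
        for r in range(32) for rb in range(4) for c in range(4)}
assert grouped == flat == set(range(512))
```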

@alexsamardzic
Collaborator Author

For reference, here is the latest version of the benchmarking script used throughout the development (the script is removed from the PR):

Benchmarking script
import argparse
from typing import Dict, Tuple

import torch

from torchao.prototype.moe_training.kernels.mxfp8.quant import mxfp8_quantize_cuda_3d
from torchao.prototype.mx_formats.utils import from_blocked


def _parse_backends(s: str) -> list[str]:
    out = [x.strip() for x in s.split(",") if x.strip()]
    for b in out:
        if b not in {"cuda", "cutedsl"}:
            raise ValueError(f"Unsupported backend={b!r}, expected cuda/cutedsl")
    return out


def _dtype_from_str(s: str) -> torch.dtype:
    if s == "bf16":
        return torch.bfloat16
    if s == "fp32":
        return torch.float32
    raise ValueError(f"Unsupported dtype={s}")


def _tbps(num_bytes: int, ms: float) -> float:
    return num_bytes / (ms / 1e3) / 1e12


def _benchmark(fn, warmup: int, iters: int) -> float:
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    return (a.float() - b.float()).abs().max().item()


def _run_3d(args) -> None:
    dtype = _dtype_from_str(args.dtype)
    backends = _parse_backends(args.backends)
    E, N, K = args.E, args.N, args.K

    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    cc = torch.cuda.get_device_capability()
    print(f"GPU: {props.name}")
    print(f"CC: {cc}")
    print(
        f"shape=(E,N,K)=({E},{N},{K}) dtype={dtype} scaling_mode={args.scaling_mode} "
        f"stage_count={args.stage_count}"
    )

    x = torch.randn((E, N, K), device="cuda", dtype=dtype) * 1000
    n_blocks = N // 32
    bytes_moved = (
        x.numel() * x.element_size()  # input
        + x.numel()
        * torch.tensor([], dtype=torch.float8_e4m3fn).element_size()  # output
        + (E * n_blocks * K)
        * torch.tensor([], dtype=torch.float8_e8m0fnu).element_size()  # scale
    )

    outs: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    for b in backends:
        try:
            fn = lambda b=b: mxfp8_quantize_cuda_3d(  # noqa: E731
                x,
                block_size=32,
                scaling_mode=args.scaling_mode,
                backend=b,
                stage_count=args.stage_count,
            )
            ms = _benchmark(fn, args.warmup, args.iters)
            y, s = fn()
            outs[b] = (y, s)
            print(
                f"[{b:<10}]    {ms:.3f} ms    {_tbps(bytes_moved, ms):.3f} TB/s   "
                f"y_stride={tuple(y.stride())} s_shape={tuple(s.shape)}"
            )
        except Exception as e:
            print(f"[{b:<10}] FAILED after 0.00s: {type(e).__name__}: {e}")

    if args.check_results and "cuda" in outs:
        y_ref, s_ref = outs["cuda"]
        for b in backends:
            if b == "cuda" or b not in outs:
                continue
            y, s = outs[b]
            dy = _max_abs_diff(y_ref, y)
            if b == "cutedsl":
                s = torch.stack(
                    [
                        from_blocked(s[e], K, n_blocks).transpose(-2, -1).contiguous()
                        for e in range(E)
                    ],
                    dim=0,
                ).to(s_ref.dtype)
            ds = _max_abs_diff(s_ref, s)
            print(f"diff(cuda vs {b}): y_max_abs={dy} s_max_abs={ds}")
            ok = dy <= args.atol and ds <= args.atol
            print(f"check(cuda vs {b}): {'PASS' if ok else 'FAIL'} (atol={args.atol})")
            if not ok:
                raise RuntimeError(
                    f"Result mismatch for backend={b}: y_diff={dy}, s_diff={ds}, atol={args.atol}"
                )


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", choices=("bf16", "fp32"), default="bf16")
    parser.add_argument("--scaling-mode", choices=("floor", "rceil"), default="floor")
    parser.add_argument("--backends", default="cuda,cutedsl")
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--check-results", action="store_true")
    parser.add_argument("--atol", type=float, default=0.0)
    parser.add_argument("--stage-count", type=int, default=2)

    parser.add_argument("--E", type=int, default=8)
    parser.add_argument("--N", type=int, default=7168)
    parser.add_argument("--K", type=int, default=2048)

    args = parser.parse_args()
    _run_3d(args)


if __name__ == "__main__":
    main()
Command to run the script (if saved as bench.py):

for shape in \
  "8 5120 5120" \
  "8 7168 5120" \
  "8 7168 7168" \
  "8 8192 8192"
do
  read -r E N K <<< "$shape"
  python bench.py \
    --dtype bf16 \
    --scaling-mode rceil \
    --backends cuda,cutedsl \
    --check-results \
    --stage-count 2 \
    --E "$E" --N "$N" --K "$K"
done
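For reference, the two `--scaling-mode` choices differ in how the shared E8M0 scale exponent is derived from a block's absolute maximum. The sketch below is a simplified stdlib-only illustration (the real kernel operates on raw exponent bits rather than `math.log2`, and the `E4M3_MAX`/`E4M3_EMAX` constants are assumptions based on the OCP MX spec, not taken from the kernel source):

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value (assumed element format)
E4M3_EMAX = 8     # exponent of the largest E4M3 power-of-two bucket

def mx_scale_exp(amax: float, mode: str) -> int:
    """Shared scale exponent for a block whose absolute max is `amax`.

    Toy model of the two rounding modes; not the kernel's bit-level path.
    """
    if mode == "floor":
        # OCP MX floor mode: floor(log2(amax)) minus the element emax.
        return math.floor(math.log2(amax)) - E4M3_EMAX
    elif mode == "rceil":
        # Round the ratio amax / E4M3_MAX up to the next power of two.
        return math.ceil(math.log2(amax / E4M3_MAX))
    raise ValueError(f"unknown scaling mode: {mode}")
```

For example, a block whose amax is exactly 448.0 gets exponent 0 under both modes, while smaller amax values diverge between the two roundings.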

@alexsamardzic
Collaborator Author

alexsamardzic commented Mar 20, 2026

  2. Do we want to remove the C++ kernel altogether, or do we want to keep the ability to switch between "cuda"/"cutedsl" backends for now?

We should delete the CUDA C++ kernel once the CuTeDSL one is done.

  3. Do we want to keep the ability to save to "plain" layout, or do we want to keep only the code saving to "blocked" layout?

We only need the blocked layout; we can delete the support for writing "plain" scales (row major).

I made the change for 3., not yet for 2., and the CUDA C++ kernel is still the default. The thing is, the contract for mxfp8_quantize_cuda_3d will change when I do 2. Do we treat this one as for internal torchao use only?

@drisspg
Contributor

drisspg commented Mar 20, 2026

  1. yeah lets make optional for now

Delete C++ and only write swizzled; for testing we can just deswizzle the swizzled output to plain if need be. And then

weight_e4m3, weight_scales = mxfp8_quantize_cuda_3d(
    weight._data if hasattr(weight, "_data") else weight,
    block_size=block_size,
    scaling_mode=scale_calculation_mode.value.lower(),
)
# Transpose scales to align with torch API requirement:
# (E, N//block_size, K) -> (E, K, N//block_size)
weight_scales = weight_scales.transpose(-2, -1)
# Convert scales to blocked format
grad_output_scales_blocked = mx_block_rearrange_2d_M_groups_cuda(
    grad_output_scales, group_end_offsets
)
we just merge into 1 call
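As a toy illustration of what the "blocked format" conversion above does, the round trip below sketches the pad/tile/flatten pattern. The 4x2 tile here is a deliberate simplification, not the real 128x4 swizzle that torchao's from_blocked undoes; only the overall pattern carries over, and these function names are local stand-ins, not the torchao API:

```python
# Simplified tile shape for illustration only (the real layout uses 128x4).
TILE_R, TILE_C = 4, 2

def to_blocked(mat):
    """Pad a row-major matrix to tile multiples, then flatten tile by tile."""
    rows, cols = len(mat), len(mat[0])
    pr = -rows % TILE_R  # pad rows up to a multiple of TILE_R
    pc = -cols % TILE_C  # pad cols up to a multiple of TILE_C
    padded = [row + [0.0] * pc for row in mat] + \
             [[0.0] * (cols + pc) for _ in range(pr)]
    out = []
    for br in range(0, rows + pr, TILE_R):       # walk tiles row-major
        for bc in range(0, cols + pc, TILE_C):
            for r in range(TILE_R):              # flatten each tile
                out.extend(padded[br + r][bc:bc + TILE_C])
    return out

def from_blocked(flat, rows, cols):
    """Invert to_blocked, dropping the padding cells on read-back."""
    pr, pc = -rows % TILE_R, -cols % TILE_C
    mat = [[0.0] * cols for _ in range(rows)]
    i = 0
    for br in range(0, rows + pr, TILE_R):
        for bc in range(0, cols + pc, TILE_C):
            for r in range(TILE_R):
                for c in range(TILE_C):
                    if br + r < rows and bc + c < cols:
                        mat[br + r][bc + c] = flat[i]
                    i += 1
    return mat

m = [[float(r * 5 + c) for c in range(5)] for r in range(3)]
assert from_blocked(to_blocked(m), 3, 5) == m  # round-trip is lossless
```

The benchmark's check path does the real version of this per expert: it de-swizzles the CuTeDSL scales back to the plain layout before comparing against the CUDA reference.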

@alexsamardzic alexsamardzic force-pushed the asamardzic/cutedsl-quantize-3d branch from 56e65ce to 920db9c on March 20, 2026, 19:57
@alexsamardzic
Collaborator Author

  1. yeah lets make optional for now

Updated. The operator will report missing packages in case of failure.

Delete C++ and only write swizzled; for testing we can just deswizzle the swizzled output to plain if need be. And then
[ ... ]
we just merge into 1 call

Done.

@alexsamardzic
Collaborator Author

All updates made, I think this is ready for merge now.

To test/benchmark:

pytest test/prototype/moe_training/test_kernels.py -k cuda_mx_dim1_3d_numerics

pytest test/prototype/moe_training/test_mxfp8_grouped_mm.py -k dq_fwd_bwd

python -m benchmarks.prototype.moe_training.mxfp8.bench_quantize_3d

python -m benchmarks.prototype.moe_training.bench_2d_3d_grouped_gemm

@danielvegamyhre
Contributor

All updates made, I think this is ready for merge now.

[ ... ]

I did some final manual tests and everything looks good, planning to land once CI is done!

@danielvegamyhre
Contributor

landing, test error unrelated:

AttributeError: 'LinearActivationQuantizedTensor' object has no attribute 'tensor_impl'

To execute this test, run the following from the base repo dir:
    python test/dtypes/test_affine_quantized_tensor_parallel.py TestInt8dqAffineQuantizedTensorParallel.test_tp_bfloat16

@danielvegamyhre danielvegamyhre merged commit d7509a6 into main Mar 20, 2026
22 of 23 checks passed
@alexsamardzic alexsamardzic deleted the asamardzic/cutedsl-quantize-3d branch March 21, 2026 14:27