Add CuTeDSL kernel for 3D tensor quantization to MXFP8 #4090
danielvegamyhre merged 1 commit into main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4090
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 920db9c with merge base e40a817. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The kernel works on Hopper (SM90) and newer GPUs. It supports both floor and rceil rounding modes.

To test:

```bash
pytest test/prototype/moe_training/test_kernels.py -k cuda_mx_dim1_3d_numerics
```

To benchmark:

```bash
for shape in \
  "8 5120 5120" \
  "8 7168 5120" \
  "8 7168 7168" \
  "8 8192 8192"
do
  read -r E N K <<< "$shape"
  python benchmarks/mx_formats/mxfp8_backend_bench.py \
    --dtype bf16 \
    --scaling-mode floor \
    --backends cuda,cutedsl \
    --check-results \
    --stage-count 2 \
    --E "$E" --N "$N" --K "$K"
done
```

Benchmarking results:

Note: the nvidia-cutlass-dsl and apache-tvm-ffi packages are now dependencies.
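As a rough illustration of the two rounding modes the kernel supports (this is not the kernel's actual code; the constant and helper names below are ours), the shared E8M0 scale exponent for a 32-element block can be derived from the block's absolute maximum like this:

```python
import math

F8E4M3_MAX = 448.0  # largest finite float8_e4m3 magnitude

def scale_exponent(amax: float, mode: str) -> int:
    # The per-block scale is 2**e; we want amax / 2**e to fit into e4m3.
    ratio = amax / F8E4M3_MAX
    if mode == "floor":
        # Truncate toward -inf: cheap, but the block max may still
        # exceed 448 after scaling and get saturated on cast.
        return math.floor(math.log2(ratio))
    if mode == "rceil":
        # Round up: guarantees amax / 2**e <= 448, so no overflow.
        return math.ceil(math.log2(ratio))
    raise ValueError(f"unknown mode {mode!r}")
```

For example, with amax = 1000, floor gives e = 1 (1000/2 = 500, which saturates), while rceil gives e = 2 (1000/4 = 250, which fits).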
(force-pushed c9d0f28 to 8412e87)
danielvegamyhre left a comment:
This is looking great @alexsamardzic! I think the design makes sense, and I appreciate your notes on what else you tried that didn't work; they help me understand the reasoning behind the design.
I added some suggestions and questions. Also, if you wouldn't mind, please check out this short blog post on a CuTeDSL MXFP8 quantizer that they claim sustains 6+ TB/s for their shapes, and see if any ideas from it can be applied here: https://blog.fal.ai/chasing-6-tb-s-an-mxfp8-quantizer-on-blackwell/
torchao/prototype/moe_training/kernels/mxfp8/cutedsl_quantize_3d.py (three outdated review threads, resolved)
Code under discussion:

```python
q_fp8_vals4.store(q_fp8_vec4)
```

```python
for j in range(chunk_vec):
    sOUT_tile[0, sout_base + j, k_rel] = q_fp8_vals4[j]
```
(New to CuTeDSL, so apologies if this comment is off base; let me know.)
This will do 4 vectorized 4-byte rmem->smem writes, right? From what I can see this appears to be a simple row-major layout; is that right?
We need the smem layout to be the ((32,4),4) blocked layout for tcgen05.mma usage, and to write directly to that layout in smem before we do the TMA copy smem->gmem. If we store in simple row-major format, we pay the tax of an extra kernel dispatch doing this layout transformation on the scales, which is what we are currently doing and would like to avoid.
We probably need to do scattered / uncoalesced 4-byte vectorized stores to smem for this; just make sure to avoid single-byte shared stores (STS.U8).
If you aren't familiar with this layout, I added some diagrams to our recent MXFP8 training blog that help visualize it: https://pytorch.org/blog/mxfp8-training-for-moes-1-3x-training-speedup-vs-bf16-for-llama4-scout-on-gb200-cluster-using-torchao-and-torchtitan/
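For anyone following along, here is a CPU-side sketch (helper names are ours, not kernel code) of how a blocked scale layout of this shape maps a (row, col) scale coordinate to a linear offset, under the assumption of 128-row by 4-column scale-factor tiles of 512 contiguous E8M0 bytes each, with a ((32, 4), 4) arrangement inside each tile:

```python
def blocked_scale_offset(row: int, col: int, padded_cols: int) -> int:
    # Scales are grouped into 128-row x 4-col tiles; each tile is
    # 512 contiguous bytes, and tiles are laid out row-major.
    tile_row, r = divmod(row, 128)
    tile_col, c = divmod(col, 4)
    tile_index = tile_row * (padded_cols // 4) + tile_col
    # Inside a tile the layout is ((32, 4), 4) with strides ((16, 4), 1):
    # each of the 32 base rows owns 16 bytes, interleaving the four
    # 32-row groups with the four columns.
    within_tile = (r % 32) * 16 + (r // 32) * 4 + c
    return tile_index * 512 + within_tile
```

Writing scales directly in this order in smem is what avoids the extra layout-transform kernel mentioned above.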
quack code pointer that may be useful here:
Updated to write one uint32 instead of four fp8, another 4% performance improvement here.
For the output layout update, see below.
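The uint32-store optimization mentioned above amounts to packing the four quantized fp8 bytes into one 32-bit word, so a single 4-byte store replaces four 1-byte stores. A CPU sketch (plain ints standing in for fp8 bit patterns; the helper name is ours):

```python
def pack4_fp8_to_u32(b0: int, b1: int, b2: int, b3: int) -> int:
    # Little-endian packing: b0 lands in the lowest byte, so one
    # 32-bit store writes the same bytes as four 8-bit stores would.
    return (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24)
```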
(force-pushed cedb175 to 6c4d33a)
@danielvegamyhre: Thanks for all the comments, in particular for pointing out possible performance improvements. I've made updates according to all of your suggestions, and we're now getting pretty close to the 6.4 TB/s; here are the latest rceil benchmarking results on B200 (note that these are for the default output layout; the new variant with blocked output layout is significantly slower). I'll check the blog you pointed to for further improvement ideas.
(force-pushed 6c4d33a to 168c66d)
@claude also do a PR review pass here
Claude finished @drisspg's task in 2m 56s.

PR Review: Add CuTeDSL kernel for 3D tensor quantization to MXFP8

Overall

This PR adds a CuTeDSL-based kernel for quantizing 3D tensors to MXFP8 format, targeting Hopper (SM90) and above. The kernel uses warp-specialized TMA with a producer/consumer model and achieves solid throughput numbers (up to ~6.2 TB/s on B200 with rceil). The architecture is well documented, with inline comments explaining design decisions that were tried and rejected.

Issues

1. Bug: benchmark uses wrong backend (already flagged by @danielvegamyhre). Actually, looking more carefully at the benchmark: the warmup call on lines 137-142 runs the function once before benchmarking. This should trigger JIT compilation. The issue @danielvegamyhre flagged was that
2. Unreachable code path in
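The warmup point generalizes: when timing a JIT-compiled kernel, run it at least once before starting the clock so one-time compilation cost is excluded from the measured average. A minimal host-side sketch (generic wall-clock timing; the PR's benchmark uses CUDA events instead):

```python
import time

def time_fn(fn, warmup: int = 3, iters: int = 10) -> float:
    # Warmup iterations absorb one-time costs (JIT compilation,
    # allocator growth) so they don't skew the measured average.
    for _ in range(warmup):
        fn()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters
```

For GPU kernels this must be paired with device synchronization around the timed region, which is why the benchmark scripts below use torch.cuda.Event and torch.cuda.synchronize.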
Can you add the benchmark numbers to the summary?
(force-pushed 078ad43 to 7d20418)
Current performance results:
Script to generate the results above:

```python
import re
import subprocess

SHAPES = [
    (8, 5120, 5120),
    (8, 7168, 5120),
    (8, 7168, 7168),
    (8, 8192, 8192),
]
COMMON_ARGS = [
    "python",
    "benchmarks/mx_formats/mxfp8_backend_bench.py",
    "--dtype",
    "bf16",
    "--scaling-mode",
    "rceil",
    "--backends",
    "cuda,cutedsl",
    "--check-results",
    "--stage-count",
    "2",
]
BLOCKED_ARGS = COMMON_ARGS + [
    "--blocked-scale-output",
    "--blocked-k-blocks-per-tb",
    "1",
    "--blocked-swizzle",
    "5,2,5",
]
LINE_RE = re.compile(r"^\[(cuda|cutedsl)\s+\]\s+([0-9.]+) ms\s+([0-9.]+) TB/s")


def run_bench(args, e, n, k):
    cmd = args + ["--E", str(e), "--N", str(n), "--K", str(k)]
    out = subprocess.check_output(cmd, text=True, cwd="/data/openteams/scratch/ao")
    rows = {}
    for line in out.splitlines():
        m = LINE_RE.match(line.strip())
        if m:
            backend, ms, tbps = m.groups()
            rows[backend] = (float(ms), float(tbps))
    return rows


print("| shape | cuda us | plain us | blocked us | cuda TB/s | plain TB/s | blocked TB/s |")
print("|---|---:|---:|---:|---:|---:|---:|")
for e, n, k in SHAPES:
    plain = run_bench(COMMON_ARGS, e, n, k)
    blocked = run_bench(BLOCKED_ARGS, e, n, k)
    cuda_ms, cuda_tbps = plain["cuda"]
    plain_ms, plain_tbps = plain["cutedsl"]
    blocked_ms, blocked_tbps = blocked["cutedsl"]
    print(
        f"| ({e}, {n}, {k}) | "
        f"{cuda_ms * 1000:.1f} | "
        f"{plain_ms * 1000:.1f} | "
        f"{blocked_ms * 1000:.1f} | "
        f"{cuda_tbps:.3f} | "
        f"{plain_tbps:.3f} | "
        f"{blocked_tbps:.3f} |"
    )
```
Code under discussion:

```python
if tidx == 0:
    cpasync.prefetch_descriptor(tma_atom_in)
    cpasync.prefetch_descriptor(tma_atom_out)
    cute.arch.mbarrier_init(tma_mbar_ptr0, 1)
```
Could we use a pipeline object instead of a raw mbarrier?
Tried that initially, and it was slower. I reconstructed the patch; it is still about 2% slower.
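For context on the raw-mbarrier approach: each stage's mbarrier carries a phase bit that flips every time the barrier completes, and waiters compare against the phase they expect. A CPU sketch of the wait sequence a consumer observes over a round-robin multi-stage pipeline (illustrative only; the kernel uses cute.arch.mbarrier_* primitives):

```python
def consumer_wait_sequence(stage_count: int, iters: int):
    # phase[s] is the phase bit the consumer expects for stage s;
    # it flips after each completed wait on that stage's mbarrier.
    phase = [0] * stage_count
    waits = []
    for i in range(iters):
        s = i % stage_count  # stages are cycled round-robin
        waits.append((s, phase[s]))
        phase[s] ^= 1
    return waits
```

With stage_count=2 the consumer alternates stages, flipping each stage's expected phase every time it revisits it.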
Code under discussion:

```python
if cutlass.const_expr(BLOCKED_SCALE_OUTPUT):
    padded_scale_cols = (
        (n_blocks + cutlass.Int64(3)) // cutlass.Int64(4)
    ) * cutlass.Int64(4)
```
I saw your comment. I'm not sure that I totally understand what it means, but basically this is the usual round-up-to-a-multiple math via floor division. This is exactly the type of thing I would expect you to be able to express as a layout, maybe even constructed on the CPU, and have it handle the indexing. I think Daniel linked the kernel where I set up the output scale layouts, assuming row-major construction.
Updated. It is a tiny bit faster this way too.
My comment was about the fact that your kernel does the "shuffle" in SMEM. I tried doing that here, but performance was better doing it on store. TBH, I'm not yet fully confident with CuTeDSL layouts and it took some experimentation to get this right, but indeed it should be more readable this way too.
(force-pushed 9efac51 to 73506e6)
drisspg left a comment:
Looks great, awesome job here!
Here are some questions and comments regarding the remaining changes before the PR is eventually merged:
Edit: crossed out completed changes.
I think a runtime dependency is fine, but it would be best to include it as a dependency we ship. I will check whether anyone on the team has objections to including it in our builds (@drisspg any thoughts?).
We should delete the CUDA C++ kernel once the CuTeDSL one is done.
We only need the blocked layout; we can delete the support for writing "plain" (row-major) scales.
This is surprising; can you share a specific test case that fails so I can also take a look? Did you make sure to update the reference
Yes, let's just keep one benchmark script and delete the other. I prefer keeping benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py, since its structure is consistent with the rest of the kernel benchmark scripts.
Indeed,
Edit: Pushed.
Changed to check for 10.x only.
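A minimal sketch of that guard (hypothetical helper name; the real check lives in the operator's dispatch code), under the assumption that the blocked-scale path requires Blackwell-class hardware:

```python
def supports_blocked_scales(cc_major: int) -> bool:
    # tcgen05.mma and its blocked scale-factor layout are a
    # Blackwell (compute capability 10.x) feature, hence the
    # check for major version 10 only.
    return cc_major == 10
```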
(force-pushed 4175803 to 56e65ce)
Code under discussion:

```python
k_block_tiles = cute.ceil_div(K, 128)
n_block_tiles = padded_scale_cols // cutlass.Int64(4)
blocked_scale_layout = cute.make_layout(
    ((32, 4, k_block_tiles), (4, n_block_tiles)),
```
Curious, why isn't this something like (rest_n, rest_k, 32, 4, 4)? (Just for my own understanding; I am newer to cute layouts.)
Is it because we still need two "modes" for the row and col dimensions, to easily map the 512-byte scale factor tiles to a 2D coordinate space? So we have the (32, 4, k_block_tiles) mode to select "which tile along the row dim", and (4, n_block_tiles) as "which tile along the column dim"?
> curious, why isn't this something like (rest_n, rest_k, 32, 4, 4)? (just for my own understanding, i am newer to cute layouts)

These two should be equivalent. I guess it boils down to which one seems easier to understand.
> is it because we still need two "modes" for row and col dimensions to easily map the 512 byte scale factor tiles to a 2d coordinate space? so we have the (32, 4, k_block_tiles) mode to select "which tile along the row dim", and (4, n_block_tiles) as "which tile along column dim"?

Yes, exactly.
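To make the equivalence concrete: a CuTe-style layout maps a coordinate to an index as a dot product of the flattened per-mode coordinates and strides, so grouping modes as ((32, 4, k), (4, n)) versus listing them flat only changes how coordinates are addressed, not the underlying index mapping. A toy sketch (our own helpers, not the CuTeDSL API):

```python
def layout_index(flat_coords, flat_strides):
    # index = sum(coord_i * stride_i) over the flattened modes
    return sum(c * s for c, s in zip(flat_coords, flat_strides))

def flatten(nested):
    # Flatten a nested tuple of modes, e.g. ((32, 4), 4) -> (32, 4, 4)
    out = []
    for x in nested:
        if isinstance(x, tuple):
            out.extend(flatten(x))
        else:
            out.append(x)
    return tuple(out)

# ((32, 4), 4) and (32, 4, 4) flatten to the same mode sequence, so
# with the same strides they define the same coordinate -> index map.
```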
For reference, here is the latest version of the benchmarking script used throughout development (the script itself was removed from the PR):

Benchmarking script:

```python
import argparse
from typing import Dict, Tuple

import torch

from torchao.prototype.moe_training.kernels.mxfp8.quant import mxfp8_quantize_cuda_3d
from torchao.prototype.mx_formats.utils import from_blocked


def _parse_backends(s: str) -> list[str]:
    out = [x.strip() for x in s.split(",") if x.strip()]
    for b in out:
        if b not in {"cuda", "cutedsl"}:
            raise ValueError(f"Unsupported backend={b!r}, expected cuda/cutedsl")
    return out


def _dtype_from_str(s: str) -> torch.dtype:
    if s == "bf16":
        return torch.bfloat16
    if s == "fp32":
        return torch.float32
    raise ValueError(f"Unsupported dtype={s}")


def _tbps(num_bytes: int, ms: float) -> float:
    return num_bytes / (ms / 1e3) / 1e12


def _benchmark(fn, warmup: int, iters: int) -> float:
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    return (a.float() - b.float()).abs().max().item()


def _run_3d(args) -> None:
    dtype = _dtype_from_str(args.dtype)
    backends = _parse_backends(args.backends)
    E, N, K = args.E, args.N, args.K
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    cc = torch.cuda.get_device_capability()
    print(f"GPU: {props.name}")
    print(f"CC: {cc}")
    print(
        f"shape=(E,N,K)=({E},{N},{K}) dtype={dtype} scaling_mode={args.scaling_mode} "
        f"stage_count={args.stage_count}"
    )
    x = torch.randn((E, N, K), device="cuda", dtype=dtype) * 1000
    n_blocks = N // 32
    bytes_moved = (
        x.numel() * x.element_size()  # input
        + x.numel()
        * torch.tensor([], dtype=torch.float8_e4m3fn).element_size()  # output
        + (E * n_blocks * K)
        * torch.tensor([], dtype=torch.float8_e8m0fnu).element_size()  # scale
    )
    outs: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    for b in backends:
        try:
            fn = lambda b=b: mxfp8_quantize_cuda_3d(  # noqa: E731
                x,
                block_size=32,
                scaling_mode=args.scaling_mode,
                backend=b,
                stage_count=args.stage_count,
            )
            ms = _benchmark(fn, args.warmup, args.iters)
            y, s = fn()
            outs[b] = (y, s)
            print(
                f"[{b:<10}] {ms:.3f} ms {_tbps(bytes_moved, ms):.3f} TB/s "
                f"y_stride={tuple(y.stride())} s_shape={tuple(s.shape)}"
            )
        except Exception as e:
            print(f"[{b:<10}] FAILED after 0.00s: {type(e).__name__}: {e}")
    if args.check_results and "cuda" in outs:
        y_ref, s_ref = outs["cuda"]
        for b in backends:
            if b == "cuda" or b not in outs:
                continue
            y, s = outs[b]
            dy = _max_abs_diff(y_ref, y)
            if b == "cutedsl":
                s = torch.stack(
                    [
                        from_blocked(s[e], K, n_blocks).transpose(-2, -1).contiguous()
                        for e in range(E)
                    ],
                    dim=0,
                ).to(s_ref.dtype)
            ds = _max_abs_diff(s_ref, s)
            print(f"diff(cuda vs {b}): y_max_abs={dy} s_max_abs={ds}")
            ok = dy <= args.atol and ds <= args.atol
            print(f"check(cuda vs {b}): {'PASS' if ok else 'FAIL'} (atol={args.atol})")
            if not ok:
                raise RuntimeError(
                    f"Result mismatch for backend={b}: y_diff={dy}, s_diff={ds}, atol={args.atol}"
                )


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", choices=("bf16", "fp32"), default="bf16")
    parser.add_argument("--scaling-mode", choices=("floor", "rceil"), default="floor")
    parser.add_argument("--backends", default="cuda,cutedsl")
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--check-results", action="store_true")
    parser.add_argument("--atol", type=float, default=0.0)
    parser.add_argument("--stage-count", type=int, default=2)
    parser.add_argument("--E", type=int, default=8)
    parser.add_argument("--N", type=int, default=7168)
    parser.add_argument("--K", type=int, default=2048)
    args = parser.parse_args()
    _run_3d(args)


if __name__ == "__main__":
    main()
```

Command to run the script (if saved as bench.py):

```bash
for shape in \
  "8 5120 5120" \
  "8 7168 5120" \
  "8 7168 7168" \
  "8 8192 8192"
do
  read -r E N K <<< "$shape"
  python bench.py \
    --dtype bf16 \
    --scaling-mode rceil \
    --backends cuda,cutedsl \
    --check-results \
    --stage-count 2 \
    --E "$E" --N "$N" --K "$K"
done
```
I made the change for 3., not yet for 2., and the CUDA C++ kernel is still the default. The thing is, the contract for
Delete the C++ kernel and only write swizzled scales; for testing we can just use:
torchao/prototype/moe_training/mxfp8_grouped_mm.py, lines 447 to 460 in cd0e858
(force-pushed 56e65ce to 920db9c)
Updated. The operator will report missing packages in case of failure.
Done.
All updates made; I think this is ready for merge now. To test/benchmark:
I did some final manual tests and everything looks good, planning to land once CI is done! |
Landing; test error unrelated: