From faec97aa0e30f95e0db4b6ec34322b4d1e305e6f Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 21 Oct 2025 14:21:56 -0700
Subject: [PATCH] fix mxfp8 matmul benchmark

Summary:

Routes the mxfp8/nvfp4 scales through `to_blocked`, which pads them to the
blocked layout cuBLAS expects, to properly support shapes where M % 128 != 0.

Test Plan:

```
python benchmarks/float8/bench_matmul.py --shape_gen_name custom --recipe mxfp8_cublas --M 17 --K 32 --N 16
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 benchmarks/float8/bench_matmul.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py
index c6499e692d..fd0489dbc1 100644
--- a/benchmarks/float8/bench_matmul.py
+++ b/benchmarks/float8/bench_matmul.py
@@ -17,6 +17,7 @@
 from torchao.ops import mx_fp4_bf16
 from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.testing.training.roofline_utils import get_specs
 from torchao.utils import is_MI300
@@ -125,10 +126,16 @@ def run(
     elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        # convert scales to the blocked layout cuBLAS expects, padding if needed
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe == "nvfp4":
         # Use the blockwise scales from nvfp4_quantize
         scale_a = A_scales.view(torch.float8_e4m3fn)
         scale_b = B_scales.view(torch.float8_e4m3fn)
+        # convert scales to the blocked layout cuBLAS expects, padding if needed
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     else:
         assert False, f"unknown recipe {recipe}"
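
For context, `to_blocked` rearranges a `(rows, cols)` scale matrix into the
block-scaling layout cuBLAS consumes, padding rows up to a multiple of 128
and columns up to a multiple of 4 before the per-tile swizzle. Below is a
minimal sketch of just the padding step, assuming the 128x4 tile granularity
of the cuBLAS block-scaling layout; `pad_scales_for_blocked_layout` is a
hypothetical helper, and the swizzle itself is omitted:

```python
import torch


def pad_scales_for_blocked_layout(scales: torch.Tensor) -> torch.Tensor:
    """Pad a (rows, cols) scale matrix up to multiples of (128, 4).

    Hypothetical helper illustrating only the padding half of to_blocked;
    the per-tile rearrangement that follows is not shown.
    """
    rows, cols = scales.shape
    padded_rows = ((rows + 127) // 128) * 128  # round rows up to a multiple of 128
    padded_cols = ((cols + 3) // 4) * 4        # round cols up to a multiple of 4
    padded = torch.zeros(padded_rows, padded_cols, dtype=scales.dtype, device=scales.device)
    padded[:rows, :cols] = scales
    return padded


# For the test-plan shape (M=17, K=32), scale_a is (17, 32 // 32) == (17, 1).
# float32 is used here for portability; the benchmark uses torch.float8_e8m0fnu.
scale_a = torch.ones(17, 1)
print(pad_scales_for_blocked_layout(scale_a).shape)  # torch.Size([128, 4])
```

With M = 17, the unpadded (17, 1) scale tensor does not satisfy the blocked
layout's size requirements, which is why the benchmark previously failed for
shapes where M % 128 != 0.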