From 821bd2b7985f26743ef7644a60e7380cb16e8c26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 07:41:27 -0700 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- benchmarks/float8/float8_roofline.py | 22 ++++++++++-- torchao/testing/training/roofline_utils.py | 41 +++++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 4bf54538df..547b0a40e4 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -180,7 +180,7 @@ def get_gemm_times( scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) else: - assert False, "TODO add cutlass mx gemm here" + assert False, f"unsupported {float8_recipe_name=} {mx_recipe_name=}" def do_matmul(A, B): return torch._scaled_mm( @@ -233,6 +233,20 @@ def run( print(f"mx_recipe_name: {mx_recipe_name}") print(f"enable_fusion_modeling: {enable_fusion_modeling}") + assert mx_recipe_name in ( + # real mxfp8_cublas recipe + "mxfp8_cublas", + # real mxfp8_cublas_rceil recipe + "mxfp8_cublas_rceil", + # modeling of what mxfp8 with 32x32 block size and without gemm + # operand layout restrictions would look like + "mxfp8_32x32_flexible_gemm_layout", + # modeling of what mxfp8 with 32x32 block size for weight + "mxfp8_32x32_weight", + # real mxfp4_cutlass recipe + "mxfp4_cutlass", + ), f"unsupported {mx_recipe_name=}" + M, K, N = sympy.symbols("M K N") fp8_ovhd_time_sympy = get_float8_mem_sympy( @@ -309,7 +323,11 @@ def run( rb_fp8_gemm_ratio = -1 if do_benchmarks: - assert mx_recipe_name != "mxfp4_cutlass", "unsupported" + assert mx_recipe_name not in ( + "mxfp4_cutlass", + "mxfp8_32x32_flexible_gemm_layout", + "mxfp8_32x32_weight", + ), f"do_benchmarks unsupported with {mx_recipe_name=}" # TODO(future): make the bf16 gemm times exactly match the e2e # benchmarks, there is a slight deviation, probably related to gemm diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index f57705333a..6610654bf1 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -187,13 +187,52 @@ def get_tensor_memory_traffic_ovhd_s( else: assert False, "unsupported" + elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout": + # modeling the following: + # 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense + # across dim0 and dim1 + # 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in + # PyTorch right now) + # x_bf16 = ... + # kernel 1: x_bf16 -> x_mxfp8_dim0 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw] + + elif mx_recipe_name == "mxfp8_32x32_weight": + # modeling the following: + # 1. mxfp8 scaling with 32x32 weights, so the format makes sense + # across dim0 and dim1. input and grad_output still 1x32. 
+ + if tensor_role in ("input", "grad_output"): + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_bf16 -> x_mxfp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + + elif tensor_role == "weight": + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1 + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2 + + else: + assert False, "unsupported" + + res_bytes = [kernel_1_rw, kernel_2_rw] + else: assert mx_recipe_name in ( "mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil", "mxfp4_cutlass", - ), "unsupported" + ), f"unsupported {mx_recipe_name=}" # For now, assume that we can't profitably fuse kernel 1 and kernel 2 # x_bf16 = ... # kernel 1: x_bf16 -> x_mxfp8_dim0 From 5bd4e3b4ff6617d6bb7eec8b13f6be99b1aeb40d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 13:32:59 -0700 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- torchao/testing/training/roofline_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index 6610654bf1..e391a4d44b 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -207,6 +207,7 @@ def get_tensor_memory_traffic_ovhd_s( # across dim0 and dim1. input and grad_output still 1x32. if tensor_role in ("input", "grad_output"): + # TODO(future): update all of the mx rooflines to just read once # kernel 1: x_bf16 -> x_mxfp8_dim0 # kernel 2: x_bf16 -> x_mxfp8_dim1 if fuse_with_prev: From ea2d54f578ef0fb39d0556699429598419ce8927 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 14:09:19 -0700 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 106 +++++++++++++----- 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index fbfead161a..6c8113e8cb 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -38,6 +38,14 @@ ) import torchao +from torchao.prototype.mx_formats.config import ( + MXGemmKernelChoice, +) +from torchao.prototype.mx_formats.inference_workflow import ( + MXFPInferenceConfig, + NVFP4InferenceConfig, + NVFP4MMConfig, +) from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, PerRow, @@ -80,40 +88,67 @@ def get_gemm_times( fast_accum: bool, recipe_name: Optional[str], ): - assert recipe_name in {"rowwise"}, ( - "Only support real benchmarks for 'rowwise' recipe for now" - ) device = torch.device("cuda") # bf16 time x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) - # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t() w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16) - e4m3_dtype = torch.float8_e4m3fn - if torch.version.hip and torch.cuda.is_available() and is_MI300(): - e4m3_dtype = torch.float8_e4m3fnuz - d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 - A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1) - B = ( - torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8) - .view(d2) - .t() - .contiguous() - .t() - ) + if 
recipe_name in ("mxfp4_cutlass", "nvfp4"): + d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16 + A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view( + d1 + ) + B = ( + torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8) + .t() + .contiguous() + .t() + .view(d2) + ) + else: + e4m3_dtype = torch.float8_e4m3fn + if torch.version.hip and torch.cuda.is_available() and is_MI300(): + e4m3_dtype = torch.float8_e4m3fnuz + d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 + A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1) + B = ( + torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8) + .view(d2) + .t() + .contiguous() + .t() + ) + if recipe_name == "rowwise": scale_a = torch.ones(M, 1, device=device) scale_b = torch.ones(1, N, device=device) + elif recipe_name == "mxfp8_cublas": + scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) + scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) + elif recipe_name == "mxfp4_cutlass": + scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) + scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) + elif recipe_name == "nvfp4": + scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn) + scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn) + else: assert False, "unsupported" def do_matmul(A, B): - return torch._scaled_mm( - A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum - ) + if recipe_name == "mxfp4_cutlass": + return torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b) + if recipe_name == "nvfp4": + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False + ) + else: + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum + ) f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) @@ -259,12 +294,33 @@ def run( # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() - config = Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - # for now, use TORCH. In the future might be interesting - # to benchmark AUTO and FBGEMM. - kernel_preference=KernelPreference.TORCH, - ) + if recipe_name == "rowwise": + config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + # for now, use TORCH. In the future might be interesting + # to benchmark AUTO and FBGEMM. + kernel_preference=KernelPreference.TORCH, + ) + elif recipe_name == "mxfp8_cublas": + config = MXFPInferenceConfig( + activation_dtype=torch.float8_e4m3fn, + weight_dtype=torch.float8_e4m3fn, + gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + ) + elif recipe_name == "mxfp4_cutlass": + config = MXFPInferenceConfig( + activation_dtype=torch.float4_e2m1fn_x2, + weight_dtype=torch.float4_e2m1fn_x2, + gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + ) + elif recipe_name == "nvfp4": + config = NVFP4InferenceConfig( + mm_config=NVFP4MMConfig.DYNAMIC, + use_dynamic_per_tensor_scale=False, + ) + else: + assert False, "unsupported" + m_fp8_dyn = copy.deepcopy(m_orig) quantize_(m_fp8_dyn, config)
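
Note on the byte accounting added in roofline_utils.py: for the modeled "mxfp8_32x32_weight" recipe, each tensor role maps to one or two memory-bound cast kernels whose read/write bytes are summed. The sketch below is a simplified, self-contained restatement of that logic, not the torchao API: the function name and the `PEAK_MEM_BW_BYTES_PER_S` value are illustrative assumptions, and dividing directly by peak bandwidth stands in for the fuller bandwidth/utilization model in roofline_utils.py (scale bytes are ignored, as in the patch).

```python
# Simplified sketch of the "mxfp8_32x32_weight" byte-traffic model from
# torchao/testing/training/roofline_utils.py. Element sizes are real dtype
# sizes; the bandwidth constant and function name are assumptions.

BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT8 = 1

# assumed peak memory bandwidth in bytes/s (placeholder, tune per GPU)
PEAK_MEM_BW_BYTES_PER_S = 8.0e12


def mxfp8_32x32_weight_ovhd_s(numel: int, tensor_role: str, fuse_with_prev: bool) -> float:
    """Estimate quantization-overhead time (s) for a single tensor."""
    if tensor_role in ("input", "grad_output"):
        # kernel 1: x_bf16 -> x_mxfp8_dim0; the bf16 read is free if fused
        # with the producer op, so only the fp8 write is counted
        kernel_1_rw = (0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel) \
            + BYTES_PER_EL_FLOAT8 * numel
        # kernel 2: x_bf16 -> x_mxfp8_dim1; read bf16, write fp8
        kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
    elif tensor_role == "weight":
        # kernel 1: x_bf16 -> x_mxfp8_dim0; read bf16, write fp8
        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
        # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1; read fp8, write fp8
        kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
    else:
        raise ValueError(f"unsupported {tensor_role=}")
    # each kernel is modeled as purely memory-bandwidth-bound
    return (kernel_1_rw + kernel_2_rw) / PEAK_MEM_BW_BYTES_PER_S


# example: 4096x4096 input activation, cast not fused with the producer
print(mxfp8_32x32_weight_ovhd_s(4096 * 4096, "input", fuse_with_prev=False))
```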
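
The recipe-to-config mapping added in float8_inference_roofline.py can also be exercised outside the benchmark. Below is a minimal usage sketch for the mxfp8_cublas path, assuming a recent torchao build and a GPU whose cuBLAS exposes mxfp8 block-scaled GEMMs (e.g. SM100-class); the model, shapes, and batch size are arbitrary placeholders, while the config arguments are the ones used in the patch.

```python
# Minimal usage sketch of the mxfp8 inference path exercised by the
# benchmark above. Hardware/build support for mxfp8 is assumed.
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization.quant_api import quantize_

# placeholder model; the benchmark sweeps real M, K, N shapes
m_orig = nn.Sequential(nn.Linear(4096, 4096, bias=False)).cuda().bfloat16()

config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)

# quantize_ swaps weights in place, mirroring the benchmark's
# `m_fp8_dyn = copy.deepcopy(m_orig); quantize_(m_fp8_dyn, config)`
m_mxfp8 = copy.deepcopy(m_orig)
quantize_(m_mxfp8, config)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = m_mxfp8(x)
```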