From b3a1be99c3a3315a3c5dc119f11902d54bb2287c Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 21 Oct 2025 14:58:56 -0700
Subject: [PATCH] mxfp8 inference roofline: add fusion to observed

Summary:

Adds an option to benchmark with relu -> linear, to capture the impact of
fusing the activation into the quant kernel.

Test Plan:

```bash
(pt_nightly_312_2) [vasiliy@devgpu023.atn1 ~/local/ao (20251021_inference_fusion_modeling)]$ python benchmarks/float8/float8_inference_roofline.py ~/local/tmp/test.csv --recipe_name mxfp8_cublas --shape_gen_name pow2_extended --enable_fusion_modeling True
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 benchmarks/float8/float8_inference_roofline.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py
index 3365fba923..271fbbf530 100644
--- a/benchmarks/float8/float8_inference_roofline.py
+++ b/benchmarks/float8/float8_inference_roofline.py
@@ -168,6 +168,7 @@ def run(
     shape_gen_name: str = "pow2",
     n_limit: Optional[int] = None,
     save_profile_traces: bool = False,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
@@ -176,6 +177,7 @@ def run(
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     # `save_profile_traces (optional)`: if True, saves profiling traces
+    # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
     """
     config_table = [
         ["GPU", torch.cuda.get_device_name(0)],
@@ -184,6 +186,7 @@ def run(
         ["recipe_name", recipe_name],
         ["do_benchmarks", do_benchmarks],
         ["shape_gen_name", shape_gen_name],
+        ["enable_fusion_modeling", enable_fusion_modeling],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))

@@ -194,6 +197,7 @@ def run(
         K,
         N,
         recipe_name,
+        # TODO(future): also enable fusion modeling here
     )
     bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)

@@ -287,9 +291,11 @@ def run(
         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = (
-                nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
-            )
+            if not enable_fusion_modeling:
+                m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
+            else:
+                m_orig = nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False))
+            m_orig = m_orig.cuda().bfloat16()
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()
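
As a companion to the diff above, here is a minimal standalone sketch of what `--enable_fusion_modeling True` changes in the observed benchmark: the timed module becomes relu -> linear rather than a bare linear, so a compiled run can fuse the activation into the adjacent kernel. The shape values and the `torch.compile` warmup flow are illustrative assumptions, not taken from the script; the real script additionally quantizes the model (e.g. with the mxfp8 recipe) before timing the fp8 path.

```python
# Illustrative sketch only; shapes and the compile/warmup flow are assumptions.
import torch
import torch.nn as nn

M_val, K_val, N_val = 2048, 4096, 4096  # hypothetical shape, not from the script
enable_fusion_modeling = True

# mirror the patched model-creation logic
if not enable_fusion_modeling:
    m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
else:
    m_orig = nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False))
m_orig = m_orig.cuda().bfloat16()

x = torch.randn(M_val, K_val, dtype=torch.bfloat16, device="cuda")

# under torch.compile, the relu can be fused with the neighboring
# quant/gemm kernels, which is the effect the new flag is meant to observe
m_c = torch.compile(m_orig)
with torch.no_grad():
    y = m_c(x)  # first call compiles; time subsequent calls
```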