diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py
index 3365fba923..271fbbf530 100644
--- a/benchmarks/float8/float8_inference_roofline.py
+++ b/benchmarks/float8/float8_inference_roofline.py
@@ -168,6 +168,7 @@ def run(
     shape_gen_name: str = "pow2",
     n_limit: Optional[int] = None,
     save_profile_traces: bool = False,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
@@ -176,6 +177,7 @@ def run(
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     # `save_profile_traces (optional)`: if True, saves profiling traces
+    # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
     """
     config_table = [
         ["GPU", torch.cuda.get_device_name(0)],
@@ -184,6 +186,7 @@ def run(
         ["recipe_name", recipe_name],
         ["do_benchmarks", do_benchmarks],
         ["shape_gen_name", shape_gen_name],
+        ["enable_fusion_modeling", enable_fusion_modeling],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))
 
@@ -194,6 +197,7 @@ def run(
         K,
         N,
         recipe_name,
+        # TODO(future): also enable fusion modeling here
     )
     bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)
 
@@ -287,9 +291,11 @@ def run(
         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = (
-                nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
-            )
+            if not enable_fusion_modeling:
+                m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
+            else:
+                m_orig = nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False))
+            m_orig = m_orig.cuda().bfloat16()
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()
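
A minimal sketch (not part of the patch) of what the new `enable_fusion_modeling` flag toggles: with it off, the benchmarked module is a single gemm; with it on, a pointwise ReLU feeds the gemm, so the benchmark can observe whether the activation gets fused with the surrounding kernels. `build_model` is a hypothetical helper; `K_val`/`N_val` stand in for the shapes the benchmark sweeps.

    # sketch only: the two module variants selected by enable_fusion_modeling
    import torch
    import torch.nn as nn

    def build_model(K_val: int, N_val: int, enable_fusion_modeling: bool) -> nn.Module:
        if not enable_fusion_modeling:
            # gemm only: a single Linear layer
            m = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
        else:
            # activation -> gemm: models the fusion opportunity
            m = nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False))
        return m.cuda().bfloat16()

    # example: model a fused activation + 4096x4096 gemm
    m_orig = build_model(4096, 4096, enable_fusion_modeling=True)

If the script exposes `run` through a fire.Fire-style entry point (as the sibling float8 benchmarks do; an assumption here), the new option would be passed on the command line as `--enable_fusion_modeling True`.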