diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py
index 271fbbf530..ea28d3236e 100644
--- a/benchmarks/float8/float8_inference_roofline.py
+++ b/benchmarks/float8/float8_inference_roofline.py
@@ -46,6 +46,7 @@
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
+from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
@@ -134,12 +135,18 @@ def get_gemm_times(
     elif recipe_name == "mxfp8_cublas":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe_name == "mxfp4_cutlass":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe_name == "nvfp4":
         scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     else:
         assert False, "unsupported"
 
@@ -166,6 +173,9 @@ def run(
     recipe_name: str,
     do_benchmarks: bool = True,
     shape_gen_name: str = "pow2",
+    M: Optional[int] = None,
+    K: Optional[int] = None,
+    N: Optional[int] = None,
     n_limit: Optional[int] = None,
     save_profile_traces: bool = False,
     enable_fusion_modeling: bool = False,
@@ -174,7 +184,8 @@
     Args:
     * `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom`
+    * `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     # `save_profile_traces (optional)`: if True, saves profiling traces
     # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
@@ -187,9 +198,13 @@
         ["do_benchmarks", do_benchmarks],
         ["shape_gen_name", shape_gen_name],
         ["enable_fusion_modeling", enable_fusion_modeling],
+        ["MKN", f"{M} {K} {N}"],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))
 
+    # reassign user specified MKN, so we can use them for sympy
+    user_M, user_K, user_N = M, K, N
+
     M, K, N = sympy.symbols("M K N")
 
     fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
@@ -245,7 +260,7 @@
     ]
 
     results = []
-    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, None, None, None)
+    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, user_M, user_K, user_N)
     for idx, (name, (M_val, K_val, N_val)) in enumerate(tqdm.tqdm(name_to_shapes)):
         if n_limit is not None and idx >= n_limit:
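
For context on the first half of the change: the block-scaled gemm kernels (cuBLAS for mxfp8, CUTLASS for mxfp4/nvfp4) consume per-block scales in a swizzled, 128x4-tiled layout, not the row-major `(M, K // block_size)` layout in which `get_gemm_times` creates them, and `to_blocked` performs that rearrangement (padding up to tile boundaries as needed). Below is a minimal sketch of the conversion on one operand; the shapes are illustrative and a CUDA device is assumed, as in the benchmark itself:

```python
import torch

from torchao.prototype.mx_formats.utils import to_blocked

M, K = 4096, 8192

# Row-major mx scales: one e8m0 scale per 32-element block along K,
# exactly as get_gemm_times builds them.
scale_a = torch.ones(M, K // 32, device="cuda", dtype=torch.float8_e8m0fnu)

# Swizzle into the blocked layout the kernels expect; the exact output
# shape (tiles padded, then rearranged) is an implementation detail of
# to_blocked.
scale_a_blocked = to_blocked(scale_a)
print(scale_a.shape, scale_a_blocked.shape)
```

Before this change the benchmarked gemms were presumably fed row-major scales, so the measured mxfp8/mxfp4/nvfp4 gemm times did not reflect the scale layout the real inference path uses.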
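The second half of the change adds a `custom` shape mode so a single user-specified `(M, K, N)` can be benchmarked. Note the `user_M, user_K, user_N = M, K, N` stash: the names `M`, `K`, `N` are rebound to sympy symbols immediately afterwards for the roofline model, so the user-supplied ints are saved first and later forwarded to `get_name_to_shapes_iter` in place of the previous hard-coded `None, None, None`. A hypothetical direct call of the updated `run` entry point (parameter names are taken from the diff; in practice the script is normally driven from the command line):

```python
# Sketch; assumes float8_inference_roofline.py is importable from the
# current working directory.
from float8_inference_roofline import run

run(
    recipe_name="mxfp8_cublas",
    shape_gen_name="custom",
    M=4096,
    K=4096,
    N=4096,  # M/K/N are only consulted when shape_gen_name == "custom"
)
```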