diff --git a/.gitignore b/.gitignore
index d8c3199a1e..876ed78130 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,6 +34,7 @@ aten/build/
 aten/src/ATen/Config.h
 aten/src/ATen/cuda/CUDAConfig.h
 benchmarks/.data
+benchmarks/data
 caffe2/cpp_test/
 dist/
 docs/build/
diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py
index 121b9fc7d3..fbfead161a 100644
--- a/benchmarks/float8/float8_inference_roofline.py
+++ b/benchmarks/float8/float8_inference_roofline.py
@@ -29,6 +29,7 @@
 import torch
 import torch.nn as nn
 import tqdm
+from tabulate import tabulate
 from torch.profiler import ProfilerActivity, profile
 from utils import (
     get_gpu_kernel_gemm_time_s,
@@ -77,8 +78,11 @@ def get_gemm_times(
     K: int,
     N: int,
     fast_accum: bool,
-    float8_recipe_name: Optional[str],
+    recipe_name: Optional[str],
 ):
+    assert recipe_name in {"rowwise"}, (
+        "Only support real benchmarks for 'rowwise' recipe for now"
+    )
     device = torch.device("cuda")

     # bf16 time
@@ -100,7 +104,7 @@
         .contiguous()
         .t()
     )
-    if float8_recipe_name in ("rowwise"):
+    if recipe_name == "rowwise":
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
     else:
@@ -118,26 +122,27 @@ def do_matmul(A, B):

 def run(
     outfile: str,
+    recipe_name: str,
     do_benchmarks: bool = True,
     shape_gen_name: str = "pow2",
     n_limit: Optional[int] = None,
-    float8_recipe_name: Optional[str] = None,
 ):
     """
     Args:
+    * `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     """
-
-    assert float8_recipe_name is not None, "unsupported"
-
-    print(f"GPU: {torch.cuda.get_device_name(0)}")
-    print(f"torch version: {torch.__version__}")
-    print(f"torchao version: {torchao.__version__}")
-    print(f"do_benchmarks: {do_benchmarks}")
-    print(f"shape_gen_name: {shape_gen_name}")
-    print(f"float8_recipe_name: {float8_recipe_name}")
+    config_table = [
+        ["GPU", torch.cuda.get_device_name(0)],
+        ["torch version", torch.__version__],
+        ["torchao version", torchao.__version__],
+        ["recipe_name", recipe_name],
+        ["do_benchmarks", do_benchmarks],
+        ["shape_gen_name", shape_gen_name],
+    ]
+    print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))

     M, K, N = sympy.symbols("M K N")

@@ -145,14 +150,19 @@ def run(
         M,
         K,
         N,
-        float8_recipe_name,
-    )
-    bf16_gemm_time_sympy = get_inference_gemm_time_sympy(
-        M, K, N, torch.bfloat16, None, None
-    )
-    fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
-        M, K, N, torch.float8_e4m3fn, float8_recipe_name, None
+        recipe_name,
     )
+    bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+
+    if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")):
+        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+            M, K, N, torch.float4_e2m1fn_x2, recipe_name
+        )
+    else:
+        gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None
+        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+            M, K, N, torch.float8_e4m3fn, gemm_recipe_name
+        )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
     print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
@@ -219,7 +229,7 @@ def run(
                     K_val,
                     N_val,
                     True,
-                    float8_recipe_name,
+                    recipe_name,
                 )
                 b_bf16_gemm_time_s = bf16_g1
                 b_fp8_gemm_time_s = f8_g1
@@ -261,6 +271,8 @@ def run(
             m_fp8_dyn = torch.compile(m_fp8_dyn)
             b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)

+        r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
+
         results.append(
             [
                 M_val,
@@ -273,7 +285,7 @@
                 r_fp8_ovhd_time_s,
                 # roofline - gemm + overhead, and speedup
                 r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
-                r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
+                r_speedup,
                 # benchmarks - gemm
                 b_bf16_gemm_time_s,
                 b_fp8_gemm_time_s,
diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py
index ec8dd54239..4bf54538df 100644
--- a/benchmarks/float8/float8_roofline.py
+++ b/benchmarks/float8/float8_roofline.py
@@ -214,6 +214,7 @@ def run(
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `mx_recipe_name (optional)`: MX format recipe
     * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid
         and models the fusion of float8 overhead
     """
diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py
index d91f25fedb..f57705333a 100644
--- a/torchao/testing/training/roofline_utils.py
+++ b/torchao/testing/training/roofline_utils.py
@@ -12,6 +12,7 @@
 BYTES_PER_EL_FLOAT4 = 0.5
 BYTES_PER_EL_FLOAT8 = 1
 BYTES_PER_EL_BF16 = 2
+BYTES_PER_EL_FLOAT32 = 4

 gpu_name_to_specs = {
     "NVIDIA H100": {
@@ -228,7 +229,7 @@ def get_individual_gemm_time_sympy(
     K: sympy.Symbol,
     N: sympy.Symbol,
     dtype,
-    mx_recipe_name,
+    mx_recipe_name: Optional[str],
     gpu_name: Optional[str] = None,
 ) -> sympy.Symbol:
     # compute bound
@@ -241,7 +242,7 @@
     elif dtype is torch.float4_e2m1fn_x2:
         peak_tops = specs["fp4_peak_tops"]
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]

     # memory bound
@@ -249,19 +250,16 @@
     num_writes = M * N

     if mx_recipe_name is not None:
-        assert mx_recipe_name in (
-            "mxfp8_emulated",
-            "mxfp8_cublas",
-            "mxfp8_cublas_rceil",
-            "mxfp4_cutlass",
-        ), "unsupported"
+        assert mx_recipe_name.startswith(("mxfp8", "mxfp4", "nvfp4")), (
+            f"Unsupported recipe {mx_recipe_name}"
+        )
         assert dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
             torch.float4_e2m1fn_x2,
         ), "unsupported"
         # adjust reads for MX scaling
-        block_size = 32
+        block_size = 32 if mx_recipe_name.startswith("mx") else 16
         num_scale_reads = num_reads // block_size
         # note: e8m0 bytes per element is the same as for e4m3|e5m2
         num_reads = num_reads + num_scale_reads
@@ -274,7 +272,7 @@
     elif dtype is torch.float4_e2m1fn_x2:
         bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     mem_gemm_time_s = (
         bytes_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
     )
@@ -375,28 +373,68 @@ def get_inference_tensor_memory_traffic_ovhd_s(
     dim0,
     dim1,
     tensor_role: str,
-    float8_recipe_name: Optional[str],
+    recipe_name: Optional[str],
     fuse_with_prev=False,
 ) -> List[Union[sympy.Symbol, float]]:
     """
     Inference version of `get_tensor_memory_traffic_ovhd_s`.
     The only thing happening here is we quantize the activation.
""" - assert float8_recipe_name == "rowwise", "unsupported" assert fuse_with_prev is False, "unsupported" + assert tensor_role == "input", "inference only quantizes input activations" # assumes input bf16, output f8 numel = dim0 * dim1 res_bytes = None - assert tensor_role == "input" - # x_bf16 = ... - # kernel 1: x_bf16 -> x_fp8 - kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel - res_bytes = [ - kernel_1_rw, - ] + allowed_recipes = {"tensorwise", "rowwise", "mxfp8*", "nvfp4*", "mxfp4*"} + + match recipe_name: + case "tensorwise": + # x_bf16 = ... + # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp + # kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs + # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 + # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + # kernel 3: read in bf16, write in float8 + kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw, kernel_3_rw] + + case "rowwise": + # x_bf16 = ... + # kernel 1: x_bf16 -> x_fp8 (with per-row scaling) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + # add in the bytes for scale writes + kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0 + res_bytes = [kernel_1_rw] + + case name if name and name.startswith("mxfp8"): + # x_bf16 = ... + # kernel 1: x_bf16 -> x_mxfp8 (block-wise scaling for inference) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + # add in the bytes for scale writes in E8M0 format + kernel_1_rw += BYTES_PER_EL_FLOAT8 * dim0 * (dim1 // 32) + res_bytes = [kernel_1_rw] + + case name if name and (name.startswith("mxfp4") or name.startswith("nvfp4")): + # For NVFP4, assume minimal overhead since it's primarily a compute format + # x_bf16 = ... + # kernel 1: x_bf16 -> x_nvfp4 (per-tensor scaling for inference) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel + if name.startswith("nvfp4"): + kernel_1_rw += BYTES_PER_EL_FLOAT32 # single scale factor + # add in the bytes for scale writes in E4M3 | E8M0 + block_size = 32 if name.startswith("mxfp4") else 16 + kernel_1_rw += BYTES_PER_EL_FLOAT8 * dim0 * (dim1 // block_size) + res_bytes = [kernel_1_rw] + + case _: + raise ValueError( + f"Unknown recipe name: {recipe_name}. " + f"Allowed recipes: {allowed_recipes}" + ) # convert from bytes to seconds res_s = [ @@ -414,7 +452,7 @@ def get_inference_float8_mem_sympy( M, K, N, - float8_recipe_name: Optional[str], + recipe_name: Optional[str], gpu_name: Optional[str] = None, ): specs = get_specs(gpu_name) @@ -425,7 +463,7 @@ def get_inference_float8_mem_sympy( M, K, tensor_role="input", - float8_recipe_name=float8_recipe_name, + recipe_name=recipe_name, fuse_with_prev=False, ) res = sum([*fwd_fp8_input_mem]) @@ -437,11 +475,12 @@ def get_inference_gemm_time_sympy( K: sympy.Symbol, N: sympy.Symbol, dtype, - float8_recipe_name: Optional[str], - gpu_name: Optional[str], + recipe_name: Optional[str], + gpu_name: Optional[str] = None, ): - assert float8_recipe_name == "rowwise" or float8_recipe_name is None, "unsupported" # note: this function is currently not super accurate for small shapes: # when M,K,N <= 1k,1k,1k it undercounts by around 2x - gemm_output_time_s = get_individual_gemm_time_sympy(M, K, N, dtype, None, gpu_name) + gemm_output_time_s = get_individual_gemm_time_sympy( + M, K, N, dtype, recipe_name, gpu_name + ) return gemm_output_time_s