Merged
19 changes: 17 additions & 2 deletions benchmarks/float8/float8_inference_roofline.py
@@ -46,6 +46,7 @@
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
+from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
@@ -134,12 +135,18 @@ def get_gemm_times(
     elif recipe_name == "mxfp8_cublas":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe_name == "mxfp4_cutlass":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe_name == "nvfp4":
         scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
 
     else:
         assert False, "unsupported"
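
(Reviewer note, not part of the diff.) The `to_blocked` calls above convert the row-major per-block scale tensors into the tiled ("swizzled") layout that the cuBLAS/CUTLASS block-scaled GEMM kernels consume, so the roofline benchmark now times the GEMM with the same scale layout the real inference path uses. A minimal sketch of the shape behavior, assuming torchao's prototype `to_blocked` pads to 128-row by 4-column tiles (prototype API, subject to change):

```python
import torch

from torchao.prototype.mx_formats.utils import to_blocked

# e8m0 scales for an mxfp8 GEMM: one scale per 32-element block along K.
M, K = 2048, 4096
scale_a = torch.ones(M, K // 32, dtype=torch.float8_e8m0fnu)

# to_blocked pads (rows, cols) up to multiples of (128, 4) and reorders
# the elements into the tile order the block-scaled kernels expect.
blocked = to_blocked(scale_a)

# Here M=2048 and K//32=128 are already tile-aligned, so only the element
# order changes, not the element count.
assert blocked.numel() == scale_a.numel()
```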
@@ -166,6 +173,9 @@ def run(
     recipe_name: str,
     do_benchmarks: bool = True,
     shape_gen_name: str = "pow2",
+    M: Optional[int] = None,
+    K: Optional[int] = None,
+    N: Optional[int] = None,
     n_limit: Optional[int] = None,
     save_profile_traces: bool = False,
     enable_fusion_modeling: bool = False,
@@ -174,7 +184,8 @@ def run(
     Args:
     * `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom`
+    * `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     # `save_profile_traces (optional)`: if True, saves profiling traces
     # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
@@ -187,9 +198,13 @@ def run(
         ["do_benchmarks", do_benchmarks],
         ["shape_gen_name", shape_gen_name],
         ["enable_fusion_modeling", enable_fusion_modeling],
+        ["MKN", f"{M} {K} {N}"],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))
 
+    # reassign user specified MKN, so we can use them for sympy
+    user_M, user_K, user_N = M, K, N
+
     M, K, N = sympy.symbols("M K N")
 
     fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
@@ -245,7 +260,7 @@ def run(
     ]
     results = []
 
-    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, None, None, None)
+    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, user_M, user_K, user_N)
 
     for idx, (name, (M_val, K_val, N_val)) in enumerate(tqdm.tqdm(name_to_shapes)):
         if n_limit is not None and idx >= n_limit:
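
(Reviewer note, not part of the diff.) With these changes a single user-specified shape can be benchmarked instead of a predefined sweep. A hedged usage sketch, calling `run()` directly with keyword arguments; the module path is assumed, and any required positional arguments not visible in this diff (e.g. an output file) are omitted and may need to be supplied:

```python
from float8_inference_roofline import run  # module path assumed

# shape_gen_name="custom" routes user_M/user_K/user_N into
# get_name_to_shapes_iter (see the last hunk above); M/K/N are
# ignored by the other shape generators.
run(
    recipe_name="mxfp8_cublas",
    shape_gen_name="custom",
    M=4096,
    K=4096,
    N=4096,
)
```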