benchmarks/float8/float8_roofline.py (22 changes: 20 additions & 2 deletions)
@@ -180,7 +180,7 @@ def get_gemm_times(
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
else:
-assert False, "TODO add cutlass mx gemm here"
+assert False, f"unsupported {float8_recipe_name=} {mx_recipe_name=}"

def do_matmul(A, B):
return torch._scaled_mm(
@@ -233,6 +233,20 @@ def run(
print(f"mx_recipe_name: {mx_recipe_name}")
print(f"enable_fusion_modeling: {enable_fusion_modeling}")

assert mx_recipe_name in (
# real mxfp8_cublas recipe
"mxfp8_cublas",
# real mxfp8_cublas_rceil recipe
"mxfp8_cublas_rceil",
# modeling of what mxfp8 with 32x32 block size and without gemm
# operand layout restrictions would look like
"mxfp8_32x32_flexible_gemm_layout",
# modeling of what mxfp8 with 32x32 block size for weight
"mxfp8_32x32_weight",
# real mxfp4_cutlass recipe
"mxfp4_cutlass",
), f"unsupported {mx_recipe_name=}"

M, K, N = sympy.symbols("M K N")

fp8_ovhd_time_sympy = get_float8_mem_sympy(
@@ -309,7 +323,11 @@ def run(
rb_fp8_gemm_ratio = -1

if do_benchmarks:
-assert mx_recipe_name != "mxfp4_cutlass", "unsupported"
+assert mx_recipe_name not in (
+"mxfp4_cutlass",
+"mxfp8_32x32_flexible_gemm_layout",
+"mxfp8_32x32_weight",
+), f"do_benchmarks unsupported with {mx_recipe_name=}"

# TODO(future): make the bf16 gemm times exactly match the e2e
# benchmarks, there is a slight deviation, probably related to gemm
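As context for the gemm-time estimates this file produces, below is a minimal sketch of the roofline arithmetic for a single matmul. The peak-throughput numbers and the ideal_gemm_time_s helper are illustrative assumptions made for this note, not values or code from the benchmark itself.

def ideal_gemm_time_s(M: int, K: int, N: int, peak_flops: float) -> float:
    # a single (M, K) @ (K, N) matmul performs 2 * M * K * N FLOPs
    return 2 * M * K * N / peak_flops

# Illustrative peak throughputs only; real numbers would come from the GPU's datasheet.
BF16_PEAK_FLOPS = 1.0e15
FP8_PEAK_FLOPS = 2.0e15

M, K, N = 4096, 4096, 4096
bf16_s = ideal_gemm_time_s(M, K, N, BF16_PEAK_FLOPS)
fp8_s = ideal_gemm_time_s(M, K, N, FP8_PEAK_FLOPS)
print(f"ideal fp8/bf16 gemm time ratio: {fp8_s / bf16_s:.2f}")  # 0.50 with these numbers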
torchao/testing/training/roofline_utils.py (42 changes: 41 additions & 1 deletion)
@@ -187,13 +187,53 @@ def get_tensor_memory_traffic_ovhd_s(
else:
assert False, "unsupported"

elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout":
# modeling the following:
# 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense
# across dim0 and dim1
# 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in
# PyTorch right now)
# x_bf16 = ...
# kernel 1: x_bf16 -> x_mxfp8_dim0
if fuse_with_prev:
kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
else:
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
res_bytes = [kernel_1_rw]

elif mx_recipe_name == "mxfp8_32x32_weight":
# modeling the following:
# 1. mxfp8 scaling with 32x32 weights, so the format makes sense
# across dim0 and dim1. input and grad_output still 1x32.

if tensor_role in ("input", "grad_output"):
# TODO(future): update all of the mx rooflines to just read once
# kernel 1: x_bf16 -> x_mxfp8_dim0
# kernel 2: x_bf16 -> x_mxfp8_dim1
if fuse_with_prev:
kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
else:
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel

elif tensor_role == "weight":
# kernel 1: x_bf16 -> x_mxfp8_dim0
# kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1
kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2

else:
assert False, "unsupported"

res_bytes = [kernel_1_rw, kernel_2_rw]

else:
assert mx_recipe_name in (
"mxfp8_emulated",
"mxfp8_cublas",
"mxfp8_cublas_rceil",
"mxfp4_cutlass",
-), "unsupported"
+), f"unsupported {mx_recipe_name=}"
# For now, assume that we can't profitably fuse kernel 1 and kernel 2
# x_bf16 = ...
# kernel 1: x_bf16 -> x_mxfp8_dim0
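To make the byte accounting in the new branches concrete, the sketch below mirrors their read/write counts using the BYTES_PER_EL_BF16 = 2 and BYTES_PER_EL_FLOAT8 = 1 constants from roofline_utils. The mx_cast_traffic_bytes helper name and the peak-bandwidth figure are assumptions made for this illustration, not part of the PR.

BYTES_PER_EL_BF16 = 2
BYTES_PER_EL_FLOAT8 = 1

def mx_cast_traffic_bytes(numel, mx_recipe_name, tensor_role, fuse_with_prev=False):
    # Per-kernel read/write bytes for casting one bf16 tensor to mxfp8,
    # following the branches added to get_tensor_memory_traffic_ovhd_s.
    if mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout":
        # one cast kernel: read bf16 (skipped if fused with the producer), write fp8
        read_bytes = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
        return [read_bytes + BYTES_PER_EL_FLOAT8 * numel]
    if mx_recipe_name == "mxfp8_32x32_weight":
        if tensor_role in ("input", "grad_output"):
            # two cast kernels (dim0 and dim1), each reading bf16 and writing fp8
            read_bytes = 0 if fuse_with_prev else BYTES_PER_EL_BF16 * numel
            kernel_1_rw = read_bytes + BYTES_PER_EL_FLOAT8 * numel
            kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
        elif tensor_role == "weight":
            # the dim1 cast re-reads the fp8 output of the dim0 cast instead of bf16
            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
            kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
        else:
            raise ValueError(f"unsupported {tensor_role=}")
        return [kernel_1_rw, kernel_2_rw]
    raise ValueError(f"unsupported {mx_recipe_name=}")

# Illustrative only: a 4096x4096 weight and an assumed ~3.35e12 B/s peak bandwidth.
numel = 4096 * 4096
bytes_per_kernel = mx_cast_traffic_bytes(numel, "mxfp8_32x32_weight", "weight")
ovhd_s = sum(bytes_per_kernel) / 3.35e12
print(bytes_per_kernel, ovhd_s)  # [50331648, 33554432] -> ~2.5e-05 s

The weight path's second kernel reads the fp8 output of the first cast rather than re-reading bf16, which is the traffic saving this recipe models relative to the existing mx branches.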